#include <iostream> #include <vector> #include <algorithm> using namespace std; long long int T[500001][4]; int main() { int m,n; cin >> n >> m; vector < vector <int> >sasiedzi(n+1,vector<int>() ); int x,y; for (int i=0; i<n-1; i++){ cin >> x >> y; sasiedzi[x].push_back(y); sasiedzi[y].push_back(x); T[x][2]++; T[y][2]++; } int r; for (int i=0; i<m; i++){ cin >> r; T[i+1][3] = 1; T[i+1][0] = r; T[i+1][1] = r; } vector <int>liscie; vector <int>rodz; vector <int>rodz2; vector <int>rodz3; for (int i=1; i<n+1; i++){ if (T[i][2] == 1 && T[sasiedzi[i][0]][2] > 1){ rodz2.push_back(sasiedzi[i][0]); T[sasiedzi[i][0]][2]--; T[i][2] = 0; } } for (int i=0; i<rodz2.size(); i++){ if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){ rodz.push_back(rodz2[i]); T[rodz2[i]][3] = 1; } } long long int suma = 0; rodz3 = rodz; int liczba = n-m; while (liczba > 0){ vector <int> dzie; rodz2.clear(); rodz3 = rodz; for (int i=0; i<rodz.size(); i++){ for (int j=0; j<sasiedzi[rodz[i]].size(); j++){ if (T[sasiedzi[rodz[i]][j]][2] == 0){ dzie.push_back(sasiedzi[rodz[i]][j]); } } int pocz[dzie.size()]; int kon[dzie.size()]; int all[2*dzie.size()]; for (int j=0; j<dzie.size(); j++){ pocz[j] = T[dzie[j]][0]; kon[j] = T[dzie[j]][1]; all[2*j] = T[dzie[j]][0]; all[2*j+1] = T[dzie[j]][1]; } sort(&pocz[0],&pocz[dzie.size()]); sort(&kon[0],&kon[dzie.size()]); sort(&all[0],&all[2*dzie.size()]); int p,k; int mn[2*dzie.size()][2]; int wsp = 0; int wsk = 0; for (int j=0; j<2*dzie.size(); j++){ while(wsk < dzie.size() && kon[wsk] <= all[j]) wsk++; mn[j][0] = wsk; } for (int j=2*dzie.size()-1; j>=0; j--){ while (wsp < dzie.size() && pocz[dzie.size()-1-wsp] >= all[j]) wsp++; mn[j][1] = wsp; } p = 0; k = 2*dzie.size()-1; while(k-p>1){ if (mn[p][0] > mn[k][1]) k--; if (mn[p][0] < mn[k][1]) p++; if (mn[p][0] == mn[k][1]){ k--; p++; } } if (mn[p][0] > mn[k][1]) k--; if (mn[p][0] > mn[k][1]) p++; T[rodz[i]][0] = all[p]; T[rodz[i]][1] = all[k]; T[rodz[i]][2] = 1; for (int j=0; j<dzie.size(); j++){ if (!(all[p] >= T[dzie[j]][0] && all[p] <= T[dzie[j]][1])){ suma += min(abs(all[p] - T[dzie[j]][0]),abs(all[p] - T[dzie[j]][1])); } } dzie.clear(); } liczba -= rodz.size(); for (int i=0; i<rodz.size(); i++){ for (int j=0; j<sasiedzi[rodz[i]].size(); j++){ if (T[sasiedzi[rodz[i]][j]][2] > 1){ rodz2.push_back(sasiedzi[i][0]); T[sasiedzi[rodz[i]][j]][2]--; } } } for (int j=0; j<rodz.size(); j++) T[rodz[j]][2] = 0; rodz.clear(); for (int i=0; i<rodz2.size(); i++){ if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){ rodz.push_back(rodz2[i]); T[rodz2[i]][3] = 1; } } rodz = rodz2; } if (rodz3.size() > 1){ int x1 = rodz3[0]; int x2 = rodz3[1]; if (T[x1][0] > T[x2][1]) suma += T[x1][0] - T[x2][1]; if (T[x2][0] > T[x1][1]) suma += T[x2][0] - T[x1][1]; } cout << suma; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | #include <iostream> #include <vector> #include <algorithm> using namespace std; long long int T[500001][4]; int main() { int m,n; cin >> n >> m; vector < vector <int> >sasiedzi(n+1,vector<int>() ); int x,y; for (int i=0; i<n-1; i++){ cin >> x >> y; sasiedzi[x].push_back(y); sasiedzi[y].push_back(x); T[x][2]++; T[y][2]++; } int r; for (int i=0; i<m; i++){ cin >> r; T[i+1][3] = 1; T[i+1][0] = r; T[i+1][1] = r; } vector <int>liscie; vector <int>rodz; vector <int>rodz2; vector <int>rodz3; for (int i=1; i<n+1; i++){ if (T[i][2] == 1 && T[sasiedzi[i][0]][2] > 1){ rodz2.push_back(sasiedzi[i][0]); T[sasiedzi[i][0]][2]--; T[i][2] = 0; } } for (int i=0; i<rodz2.size(); i++){ if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){ rodz.push_back(rodz2[i]); T[rodz2[i]][3] = 1; } } long long int suma = 0; rodz3 = rodz; int liczba = n-m; while (liczba > 0){ vector <int> dzie; rodz2.clear(); rodz3 = rodz; for (int i=0; i<rodz.size(); i++){ for (int j=0; j<sasiedzi[rodz[i]].size(); j++){ if (T[sasiedzi[rodz[i]][j]][2] == 0){ dzie.push_back(sasiedzi[rodz[i]][j]); } } int pocz[dzie.size()]; int kon[dzie.size()]; int all[2*dzie.size()]; for (int j=0; j<dzie.size(); j++){ pocz[j] = T[dzie[j]][0]; kon[j] = T[dzie[j]][1]; all[2*j] = T[dzie[j]][0]; all[2*j+1] = T[dzie[j]][1]; } sort(&pocz[0],&pocz[dzie.size()]); sort(&kon[0],&kon[dzie.size()]); sort(&all[0],&all[2*dzie.size()]); int p,k; int mn[2*dzie.size()][2]; int wsp = 0; int wsk = 0; for (int j=0; j<2*dzie.size(); j++){ while(wsk < dzie.size() && kon[wsk] <= all[j]) wsk++; mn[j][0] = wsk; } for (int j=2*dzie.size()-1; j>=0; j--){ while (wsp < dzie.size() && pocz[dzie.size()-1-wsp] >= all[j]) wsp++; mn[j][1] = wsp; } p = 0; k = 2*dzie.size()-1; while(k-p>1){ if (mn[p][0] > mn[k][1]) k--; if (mn[p][0] < mn[k][1]) p++; if (mn[p][0] == mn[k][1]){ k--; p++; } } if (mn[p][0] > mn[k][1]) k--; if (mn[p][0] > mn[k][1]) p++; T[rodz[i]][0] = all[p]; T[rodz[i]][1] = all[k]; T[rodz[i]][2] = 1; for (int j=0; j<dzie.size(); j++){ if (!(all[p] >= T[dzie[j]][0] && all[p] <= T[dzie[j]][1])){ suma += min(abs(all[p] - T[dzie[j]][0]),abs(all[p] - T[dzie[j]][1])); } } dzie.clear(); } liczba -= rodz.size(); for (int i=0; i<rodz.size(); i++){ for (int j=0; j<sasiedzi[rodz[i]].size(); j++){ if (T[sasiedzi[rodz[i]][j]][2] > 1){ rodz2.push_back(sasiedzi[i][0]); T[sasiedzi[rodz[i]][j]][2]--; } } } for (int j=0; j<rodz.size(); j++) T[rodz[j]][2] = 0; rodz.clear(); for (int i=0; i<rodz2.size(); i++){ if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){ rodz.push_back(rodz2[i]); T[rodz2[i]][3] = 1; } } rodz = rodz2; } if (rodz3.size() > 1){ int x1 = rodz3[0]; int x2 = rodz3[1]; if (T[x1][0] > T[x2][1]) suma += T[x1][0] - T[x2][1]; if (T[x2][0] > T[x1][1]) suma += T[x2][0] - T[x1][1]; } cout << suma; } |