#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
long long int T[500001][4];
int main()
{
int m,n;
cin >> n >> m;
vector < vector <int> >sasiedzi(n+1,vector<int>() );
int x,y;
for (int i=0; i<n-1; i++){
cin >> x >> y;
sasiedzi[x].push_back(y);
sasiedzi[y].push_back(x);
T[x][2]++;
T[y][2]++;
}
int r;
for (int i=0; i<m; i++){
cin >> r;
T[i+1][3] = 1;
T[i+1][0] = r;
T[i+1][1] = r;
}
vector <int>liscie;
vector <int>rodz;
vector <int>rodz2;
vector <int>rodz3;
for (int i=1; i<n+1; i++){
if (T[i][2] == 1 && T[sasiedzi[i][0]][2] > 1){
rodz2.push_back(sasiedzi[i][0]);
T[sasiedzi[i][0]][2]--;
T[i][2] = 0;
}
}
for (int i=0; i<rodz2.size(); i++){
if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){
rodz.push_back(rodz2[i]);
T[rodz2[i]][3] = 1;
}
}
long long int suma = 0;
rodz3 = rodz;
int liczba = n-m;
while (liczba > 0){
vector <int> dzie;
rodz2.clear();
rodz3 = rodz;
for (int i=0; i<rodz.size(); i++){
for (int j=0; j<sasiedzi[rodz[i]].size(); j++){
if (T[sasiedzi[rodz[i]][j]][2] == 0){
dzie.push_back(sasiedzi[rodz[i]][j]);
}
}
int pocz[dzie.size()];
int kon[dzie.size()];
int all[2*dzie.size()];
for (int j=0; j<dzie.size(); j++){
pocz[j] = T[dzie[j]][0];
kon[j] = T[dzie[j]][1];
all[2*j] = T[dzie[j]][0];
all[2*j+1] = T[dzie[j]][1];
}
sort(&pocz[0],&pocz[dzie.size()]);
sort(&kon[0],&kon[dzie.size()]);
sort(&all[0],&all[2*dzie.size()]);
int p,k;
int mn[2*dzie.size()][2];
int wsp = 0;
int wsk = 0;
for (int j=0; j<2*dzie.size(); j++){
while(wsk < dzie.size() && kon[wsk] <= all[j]) wsk++;
mn[j][0] = wsk;
}
for (int j=2*dzie.size()-1; j>=0; j--){
while (wsp < dzie.size() && pocz[dzie.size()-1-wsp] >= all[j]) wsp++;
mn[j][1] = wsp;
}
p = 0;
k = 2*dzie.size()-1;
while(k-p>1){
if (mn[p][0] > mn[k][1]) k--;
if (mn[p][0] < mn[k][1]) p++;
if (mn[p][0] == mn[k][1]){
k--;
p++;
}
}
if (mn[p][0] > mn[k][1]) k--;
if (mn[p][0] > mn[k][1]) p++;
T[rodz[i]][0] = all[p];
T[rodz[i]][1] = all[k];
T[rodz[i]][2] = 1;
for (int j=0; j<dzie.size(); j++){
if (!(all[p] >= T[dzie[j]][0] && all[p] <= T[dzie[j]][1])){
suma += min(abs(all[p] - T[dzie[j]][0]),abs(all[p] - T[dzie[j]][1]));
}
}
dzie.clear();
}
liczba -= rodz.size();
for (int i=0; i<rodz.size(); i++){
for (int j=0; j<sasiedzi[rodz[i]].size(); j++){
if (T[sasiedzi[rodz[i]][j]][2] > 1){
rodz2.push_back(sasiedzi[i][0]);
T[sasiedzi[rodz[i]][j]][2]--;
}
}
}
for (int j=0; j<rodz.size(); j++) T[rodz[j]][2] = 0;
rodz.clear();
for (int i=0; i<rodz2.size(); i++){
if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){
rodz.push_back(rodz2[i]);
T[rodz2[i]][3] = 1;
}
}
rodz = rodz2;
}
if (rodz3.size() > 1){
int x1 = rodz3[0];
int x2 = rodz3[1];
if (T[x1][0] > T[x2][1]) suma += T[x1][0] - T[x2][1];
if (T[x2][0] > T[x1][1]) suma += T[x2][0] - T[x1][1];
}
cout << suma;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | #include <iostream> #include <vector> #include <algorithm> using namespace std; long long int T[500001][4]; int main() { int m,n; cin >> n >> m; vector < vector <int> >sasiedzi(n+1,vector<int>() ); int x,y; for (int i=0; i<n-1; i++){ cin >> x >> y; sasiedzi[x].push_back(y); sasiedzi[y].push_back(x); T[x][2]++; T[y][2]++; } int r; for (int i=0; i<m; i++){ cin >> r; T[i+1][3] = 1; T[i+1][0] = r; T[i+1][1] = r; } vector <int>liscie; vector <int>rodz; vector <int>rodz2; vector <int>rodz3; for (int i=1; i<n+1; i++){ if (T[i][2] == 1 && T[sasiedzi[i][0]][2] > 1){ rodz2.push_back(sasiedzi[i][0]); T[sasiedzi[i][0]][2]--; T[i][2] = 0; } } for (int i=0; i<rodz2.size(); i++){ if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){ rodz.push_back(rodz2[i]); T[rodz2[i]][3] = 1; } } long long int suma = 0; rodz3 = rodz; int liczba = n-m; while (liczba > 0){ vector <int> dzie; rodz2.clear(); rodz3 = rodz; for (int i=0; i<rodz.size(); i++){ for (int j=0; j<sasiedzi[rodz[i]].size(); j++){ if (T[sasiedzi[rodz[i]][j]][2] == 0){ dzie.push_back(sasiedzi[rodz[i]][j]); } } int pocz[dzie.size()]; int kon[dzie.size()]; int all[2*dzie.size()]; for (int j=0; j<dzie.size(); j++){ pocz[j] = T[dzie[j]][0]; kon[j] = T[dzie[j]][1]; all[2*j] = T[dzie[j]][0]; all[2*j+1] = T[dzie[j]][1]; } sort(&pocz[0],&pocz[dzie.size()]); sort(&kon[0],&kon[dzie.size()]); sort(&all[0],&all[2*dzie.size()]); int p,k; int mn[2*dzie.size()][2]; int wsp = 0; int wsk = 0; for (int j=0; j<2*dzie.size(); j++){ while(wsk < dzie.size() && kon[wsk] <= all[j]) wsk++; mn[j][0] = wsk; } for (int j=2*dzie.size()-1; j>=0; j--){ while (wsp < dzie.size() && pocz[dzie.size()-1-wsp] >= all[j]) wsp++; mn[j][1] = wsp; } p = 0; k = 2*dzie.size()-1; while(k-p>1){ if (mn[p][0] > mn[k][1]) k--; if (mn[p][0] < mn[k][1]) p++; if (mn[p][0] == mn[k][1]){ k--; p++; } } if (mn[p][0] > mn[k][1]) k--; if (mn[p][0] > mn[k][1]) p++; T[rodz[i]][0] = all[p]; T[rodz[i]][1] = all[k]; T[rodz[i]][2] = 1; for (int j=0; j<dzie.size(); j++){ if (!(all[p] >= T[dzie[j]][0] && all[p] <= T[dzie[j]][1])){ suma += min(abs(all[p] - T[dzie[j]][0]),abs(all[p] - T[dzie[j]][1])); } } dzie.clear(); } liczba -= rodz.size(); for (int i=0; i<rodz.size(); i++){ for (int j=0; j<sasiedzi[rodz[i]].size(); j++){ if (T[sasiedzi[rodz[i]][j]][2] > 1){ rodz2.push_back(sasiedzi[i][0]); T[sasiedzi[rodz[i]][j]][2]--; } } } for (int j=0; j<rodz.size(); j++) T[rodz[j]][2] = 0; rodz.clear(); for (int i=0; i<rodz2.size(); i++){ if (T[rodz2[i]][2] == 1 && T[rodz2[i]][3] == 0){ rodz.push_back(rodz2[i]); T[rodz2[i]][3] = 1; } } rodz = rodz2; } if (rodz3.size() > 1){ int x1 = rodz3[0]; int x2 = rodz3[1]; if (T[x1][0] > T[x2][1]) suma += T[x1][0] - T[x2][1]; if (T[x2][0] > T[x1][1]) suma += T[x2][0] - T[x1][1]; } cout << suma; } |
English