// Tree leaf-value solver.
//
// Input (stdin): "n m", then n-1 edges "a b" (nodes 1..n), then m integers
// val[1..m]. Nodes of degree 1 are leaves and take the value val[node].
// NOTE(review): this assumes the leaves are exactly nodes 1..m; a leaf with a
// larger index reads 0, as in the original.
//
// Every internal node chooses the integer that minimises the sum of distances
// to its children's optimal value ranges. The program prints the sum of those
// minima. Presumably this is the minimum total |difference| over all edges;
// confirm against the problem statement.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

typedef long long ll;
typedef std::pair<int, int> pii;

// Sum over the intervals [first, second] in L of the distance from x to the
// interval (0 when x lies inside it). Computed in 64-bit to avoid overflow.
static ll f(ll x, const std::vector<pii> & L) {
    ll s = 0;
    for (const pii & p : L) {
        if (x <= p.first) s += p.first - x;
        else if (x >= p.second) s += x - p.second;
    }
    return s;
}

// Returns the widest integer interval [start, end] on which f(., L) attains its
// minimum, and adds that minimum to `total`.
//
// f is convex and piecewise linear. It is strictly decreasing left of the
// smallest interval start and strictly increasing right of the largest
// interval end. The search range is therefore taken from the data instead of
// a hard-coded value bound, which keeps values <= 0 or > 500000 correct.
static pii combine(const std::vector<pii> & L, ll & total) {
    if (L.empty()) return pii(0, 0);  // only possible for an isolated root; cost 0
    ll lo = L[0].first;
    ll hi = L[0].second;
    for (const pii & p : L) {
        lo = std::min<ll>(lo, p.first);
        hi = std::max<ll>(hi, p.second);
    }
    const ll upper = hi;

    // Leftmost minimiser: the first x with f(x) <= f(x + 1).
    while (lo < hi) {
        const ll mid = lo + (hi - lo) / 2;
        if (f(mid, L) > f(mid + 1, L)) lo = mid + 1;
        else hi = mid;
    }
    const ll start = lo;
    const ll best = f(start, L);

    // Rightmost minimiser: the last x in [start, upper] with f(x) == best.
    hi = upper;
    while (lo < hi) {
        const ll mid = lo + (hi - lo + 1) / 2;
        if (f(mid, L) == best) lo = mid;
        else hi = mid - 1;
    }

    total += best;
    // Both ends lie between int-valued interval endpoints, so they fit in int.
    return pii(static_cast<int>(start), static_cast<int>(lo));
}

// Computes the answer for the tree in `adj`, rooted at the internal node `root`.
// Nodes are handled children-first through an explicit stack rather than
// recursion, because the tree depth can be close to n and deep recursion can
// overflow the call stack.
static ll solve(int root, const std::vector<std::vector<int> > & adj,
                const std::vector<int> & val) {
    const std::size_t size = adj.size();

    std::vector<int> parent(size, 0);  // 0 = no parent (nodes are 1-indexed)
    std::vector<int> order;            // pre-order: every node appears after its parent
    order.reserve(size);
    std::vector<int> todo(1, root);
    while (!todo.empty()) {
        const int a = todo.back();
        todo.pop_back();
        order.push_back(a);
        for (int b : adj[a]) {
            if (b != parent[a]) {
                parent[b] = a;
                todo.push_back(b);
            }
        }
    }

    ll total = 0;
    std::vector<pii> interval(size);
    std::vector<pii> L;  // reused buffer for the children's intervals
    // Reverse pre-order visits every child before its parent.
    for (std::size_t i = order.size(); i-- > 0;) {
        const int a = order[i];
        if (adj[a].size() == 1) {  // leaf: its value is fixed
            interval[a] = pii(val[a], val[a]);
            continue;
        }
        L.clear();
        for (int b : adj[a])
            if (b != parent[a]) L.push_back(interval[b]);
        interval[a] = combine(L, total);
    }
    return total;
}

int main() {
    int n, m;
    if (std::scanf("%d%d", &n, &m) != 2) return 0;
    if (n <= 1) {  // no edges, so nothing to pay
        std::puts("0");
        return 0;
    }

    const int size = std::max(n, m) + 1;
    std::vector<std::vector<int> > adj(size);
    for (int i = 0; i < n - 1; ++i) {
        int a, b;
        if (std::scanf("%d%d", &a, &b) != 2) return 1;
        if (a < 1 || a > n || b < 1 || b > n) return 1;  // malformed edge
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    std::vector<int> val(size, 0);
    for (int i = 1; i <= m; ++i)
        if (std::scanf("%d", &val[i]) != 1) return 1;

    if (n == 2) {  // a single edge between two leaves: no internal node exists
        const ll d = static_cast<ll>(val[1]) - val[2];
        std::printf("%lld\n", d < 0 ? -d : d);
        return 0;
    }

    // Root at the first non-leaf. With n-1 edges and n >= 3, the degrees sum to
    // 2n-2 > n, so not every node can have degree 1.
    int root = 1;
    while (root <= n && adj[root].size() == 1) ++root;
    if (root > n) {
        std::puts("0");
        return 0;
    }

    std::printf("%lld\n", solve(root, adj, val));
    return 0;
}