#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros.
// FOR / FORD: inclusive ascending loop over i in [a, b] / descending loop from a down to b.
#define FOR(i,a,b) for(int i = (a); i <= (b); ++i)
#define FORD(i,a,b) for(int i = (a); i >= (b); --i)
// RI: i = 1..n (1-based); REP: i = 0..n-1 (0-based).
#define RI(i,n) FOR(i,1,(n))
#define REP(i,n) FOR(i,0,(n)-1)
// mini / maxi: assign to a the smaller / larger of a and b.
// NOTE(review): arguments are not parenthesized in the expansion; pass simple lvalues only.
#define mini(a,b) a=min(a,b)
#define maxi(a,b) a=max(a,b)
#define mp make_pair
#define pb push_back
// st / nd stand for .first / .second of a pair. These are object-like macros,
// so every later identifier spelled exactly `st` or `nd` is rewritten too.
#define st first
#define nd second
// sz: container size converted to (signed) int.
// NOTE(review): w is not parenthesized, so pass a simple name or member access.
#define sz(w) (int) w.size()
// Type shorthands used throughout the solution.
using vi = std::vector<int>;
using ll = long long;
using ld = long double;
using pii = std::pair<int, int>;

// inf: large sentinel value (1e9 + 5).
// nax: capacity of the per-node arrays below; node ids read from input index
// them directly, so they are expected to stay below this bound.
constexpr int inf = 1000000005;
constexpr int nax = 500005;

// Global state shared by dfs() and main():
//   w[v]   - adjacency list of node v
//   val[v] - value attached to node v (main reads val[1..m])
//   RESULT - running total of the per-node minimum costs added by dfs()
vi w[nax];
int val[nax];
ll RESULT;
// Cost of giving a node the value x when its children are summarized by the
// closed intervals in L. Each interval [lo, hi] contributes the distance from x
// to the interval: lo - x when x <= lo, x - hi when x >= hi, and 0 otherwise.
ll f(int x, const vector<pii> & L) {
    return accumulate(L.begin(), L.end(), 0LL,
        [x](ll acc, const pii & iv) -> ll {
            if (x <= iv.first) return acc + (iv.first - x);
            if (x >= iv.second) return acc + (x - iv.second);
            return acc;  // x lies strictly inside the interval: free
        });
}
// Post-order DFS over the subtree of `a`, entered from `par` (-1 for the root).
//
// Returns the closed range [lo, hi] of integer values that node `a` may take
// while its subtree cost stays minimal:
//   - a degree-1 node reached from its parent is a leaf; its range is [val[a], val[a]];
//   - otherwise the children's ranges are gathered into L, the minimum of f(., L)
//     is added to the global RESULT, and the full set of minimisers of f is
//     returned so the parent can still choose freely within it.
// NOTE(review): recursion depth equals the height of the rooted tree; with up
// to nax nodes this may need more than the default stack — confirm judge limits.
pii dfs(int a, int par = -1) {
    if(sz(w[a]) == 1) return mp(val[a], val[a]);
    vector<pii> L;
    for(int b : w[a]) if(b != par)
        L.pb(dfs(b, a));
    // Only a root with no neighbours can get here with no children: nothing to pay.
    if(L.empty()) return mp(val[a], val[a]);

    // f(., L) is convex and piecewise linear. Pooling all 2k interval endpoints
    // (k = |L|), the slope of f on (x, x+1) equals #(endpoints <= x) - k.
    // Hence the integer minimisers are exactly the closed range
    //   [k-th smallest endpoint, (k+1)-th smallest endpoint],
    // which needs no assumption about the range of leaf values.
    const int k = sz(L);
    vector<int> ends;
    ends.reserve(2 * k);
    for(const pii & p : L) {
        ends.pb(p.st);
        ends.pb(p.nd);
    }
    // After nth_element, ends[k] is the (k+1)-th smallest endpoint and all
    // elements before it are <= it, so their maximum is the k-th smallest.
    nth_element(ends.begin(), ends.begin() + k, ends.end());
    const int hi = ends[k];
    const int lo = *max_element(ends.begin(), ends.begin() + k);

    RESULT += f(lo, L);
    return mp(lo, hi);
}
// Reads "n m", then n - 1 undirected edges "a b", then m values into val[1..m].
// Prints RESULT, the total cost accumulated by dfs() from an internal root.
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int edge = 0; edge < n - 1; ++edge) {
        int u, v;
        scanf("%d%d", &u, &v);
        w[u].push_back(v);
        w[v].push_back(u);
    }
    // Only val[1..m] is read, while dfs() reads val at degree-1 nodes; this
    // presumably relies on the leaves being numbered 1..m — confirm with the statement.
    for (int i = 1; i <= m; ++i) scanf("%d", &val[i]);

    // With two nodes both have degree 1, so there is no internal node to root at.
    if (n == 2) {
        printf("%d\n", abs(val[1] - val[2]));
        return 0;
    }

    // Root the tree at the lowest-numbered node whose degree is not 1.
    int root = 1;
    for (; sz(w[root]) == 1; ++root) {}
    dfs(root);
    printf("%lld\n", RESULT);
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | #include<bits/stdc++.h> using namespace std; #define FOR(i,a,b) for(int i = (a); i <= (b); ++i) #define FORD(i,a,b) for(int i = (a); i >= (b); --i) #define RI(i,n) FOR(i,1,(n)) #define REP(i,n) FOR(i,0,(n)-1) #define mini(a,b) a=min(a,b) #define maxi(a,b) a=max(a,b) #define mp make_pair #define pb push_back #define st first #define nd second #define sz(w) (int) w.size() typedef vector<int> vi; typedef long long ll; typedef long double ld; typedef pair<int,int> pii; const int inf = 1e9 + 5; const int nax = 5e5 + 5; vi w[nax]; int val[nax]; ll RESULT; ll f(int x, const vector<pii> & L) { ll s = 0; for(const pii & p : L) { if(x <= p.st) s += p.st - x; else if(x >= p.nd) s += x - p.nd; } return s; } pii dfs(int a, int par = -1) { if(sz(w[a]) == 1) return mp(val[a], val[a]); vector<pii> L; for(int b : w[a]) if(b != par) L.pb(dfs(b, a)); int low = 1, high = 500 * 1000; while(low < high) { int med = (low + high) / 2; if(f(med, L) > f(med+1, L)) low = med + 1; else high = med; } int start = low; ll MIN = f(start, L); high = 500 * 1000; while(low < high) { int med = (low + high + 1) / 2; if(f(med, L) == MIN) low = med; else high = med - 1; } RESULT += MIN; return mp(start, low); } int main() { int n, m; scanf("%d%d", &n, &m); REP(_, n - 1) { int a, b; scanf("%d%d", &a, &b); w[a].pb(b); w[b].pb(a); } RI(i, m) scanf("%d", &val[i]); if(n == 2) { printf("%d\n", abs(val[1] - val[2])); return 0; } int root = 1; while(sz(w[root]) == 1) ++root; dfs(root); printf("%lld\n", RESULT); return 0; } |
polski