// Tree value assignment.
//
// Input: n m, then n-1 edges "a b", then m integers K[1..m].
// Vertices 1..m have fixed values K[i]; every vertex m+1..n may take any
// integer. Prints the minimum of sum over edges (a, b) of |value(a) - value(b)|.
//
// Method (same as the original recursive solution, made iterative so a deep,
// path-shaped tree cannot overflow the call stack):
//   1. Root the tree at vertex m+1. Bottom-up, give every vertex an interval
//      [lo, hi]: a fixed vertex gets [K, K]; an internal vertex gets the lower
//      and upper middle elements of the multiset of its children's interval
//      endpoints.
//   2. Top-down, the root takes lo[root] and every other vertex takes its
//      parent's value clamped into its own interval; edge costs are summed.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2) return 0;

    std::vector<std::vector<int>> adj(n + 1);
    std::vector<std::pair<int, int>> edges;  // only needed when nothing is free
    edges.reserve(n > 0 ? n - 1 : 0);
    for (int i = 0; i + 1 < n; ++i) {
        int a = 0;
        int b = 0;
        std::scanf("%d%d", &a, &b);
        adj[a].push_back(b);
        adj[b].push_back(a);
        edges.emplace_back(a, b);
    }

    std::vector<int> K(std::max(n, m) + 1, 0);
    for (int i = 1; i <= m; ++i) std::scanf("%d", &K[i]);

    // Edge cost in 64-bit: values of opposite sign can overflow an int difference.
    auto absdiff = [](long long x, long long y) { return x < y ? y - x : x - y; };

    // No free vertex (e.g. n == 2, m == 2): every value is fixed, so the answer
    // is just the sum of edge costs. The original read an element of an empty
    // vector here, because vertex m+1 does not exist.
    if (m >= n) {
        long long total = 0;
        for (const auto& e : edges) total += absdiff(K[e.first], K[e.second]);
        std::printf("%lld\n", total);
        return 0;
    }

    // Iterative DFS from the root. `order` lists every reached vertex after its
    // parent. Fixed vertices (<= m) are not expanded, matching the original,
    // which returned from them without recursing.
    const int root = m + 1;
    std::vector<int> parent(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack{root};
    seen[root] = 1;
    while (!stack.empty()) {
        const int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        if (u <= m) continue;
        for (int v : adj[u]) {
            if (!seen[v]) {
                seen[v] = 1;
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }

    // Bottom-up pass (reverse of `order`, so children are done before parents).
    std::vector<int> lo(n + 1, 0);
    std::vector<int> hi(n + 1, 0);
    std::vector<int> ends;  // reused buffer of children's interval endpoints
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        if (u <= m) {
            lo[u] = hi[u] = K[u];
            continue;
        }
        ends.clear();
        for (int v : adj[u]) {
            if (parent[v] == u) {
                ends.push_back(lo[v]);
                ends.push_back(hi[v]);
            }
        }
        if (ends.empty()) continue;  // isolated root (n == 1): interval stays [0, 0]
        const int sz = static_cast<int>(ends.size());
        const int lowMid = (sz - 1) / 2;
        const int highMid = sz / 2;
        // Only the two order statistics are needed, so partial selection
        // suffices; a full sort is unnecessary.
        std::nth_element(ends.begin(), ends.begin() + lowMid, ends.end());
        lo[u] = ends[lowMid];
        std::nth_element(ends.begin(), ends.begin() + highMid, ends.end());
        hi[u] = ends[highMid];
    }

    // Top-down pass. The original seeded the root with a virtual parent value 0
    // and subtracted lo[root] afterwards. That equals starting at lo[root] only
    // when lo[root] >= 0, so start there directly: the root contributes no cost.
    long long ans = 0;
    std::vector<int> val(n + 1, 0);
    val[root] = lo[root];
    for (int u : order) {
        if (u == root) continue;
        const int w = val[parent[u]];
        const int chosen = std::min(std::max(w, lo[u]), hi[u]);  // clamp w into [lo, hi]
        val[u] = chosen;
        ans += absdiff(chosen, w);
    }
    std::printf("%lld\n", ans);
    return 0;
}