#include <bits/stdc++.h>
using namespace std;
// Loop helpers: REP is inclusive on both ends, FORI runs 1..n, FOR runs 0..n-1.
#define REP(i,a,b) for(int i=(a);i<=(b);++i)
#define FORI(i,n) REP(i,1,n)
#define FOR(i,n) REP(i,0,int(n)-1)
// Short aliases for frequently used standard-library names.
#define mp make_pair
#define pb push_back
#define pii pair<int,int>
#define vi vector<int>
#define ll long long
#define SZ(x) int((x).size())
// Prints "expression = value" to stderr.
#define DBG(v) cerr << #v << " = " << (v) << endl;
// Iterator loop over a container. Uses standard decltype instead of the
// GNU-only typeof, which is not a keyword under strict -std=c++NN modes.
#define FOREACH(i,t) for (decltype(t.begin()) i=t.begin(); i!=t.end(); i++)
#define SORT(X) sort(X.begin(),X.end())
#define fi first
#define se second
// Adjacency lists of the tree, indexed by 1-based vertex id.
std::vector<int> V[500500];
// Lower end of the value interval that calc() computes for each vertex.
int D[500500];
// Upper end of the value interval that calc() computes for each vertex.
int G[500500];
// Per-vertex scratch list of the children's interval endpoints, filled by calc().
std::vector<int> Mg[500500];
// Prescribed value of vertex v, read for 1 <= v <= m.
int K[500500];
// Number of vertices with a prescribed value (vertices 1..m).
int m;
// Non-zero once calc() has visited the vertex.
int Ocalc[500500];
// calc(u): root the tree at u and fill D[x] <= G[x] bottom-up for every vertex
// x reachable from u, marking each visited vertex in Ocalc[].
//  * Vertices 1..m are fixed: D[x] = G[x] = K[x], and their other neighbours
//    are not explored (presumably they are leaves; the input contract is not
//    visible here).
//  * Any other vertex gathers both endpoints of every child's [D, G] into
//    Mg[x] and takes the lower and upper median of those 2k numbers. A child's
//    best cost, as a function of x's value, is a constant plus the distance to
//    the child's interval, and a sum of such distances is minimised exactly
//    between the two medians.
// An explicit stack replaces recursion: a path-shaped tree with ~5e5 vertices
// would otherwise need one native stack frame per vertex.
// NOTE(review): a non-fixed vertex with no unvisited neighbour leaves Mg[x]
// empty and indexing it is undefined; the input is assumed to never contain one.
void calc(int u){
    // (vertex, parent) pairs in the order the DFS pops them (pre-order).
    std::vector<std::pair<int,int> > order;
    std::vector<std::pair<int,int> > todo;
    todo.push_back(std::make_pair(u, 0));
    Ocalc[u] = 1;
    while(!todo.empty()){
        const std::pair<int,int> cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        const int x = cur.first;
        if(x <= m) continue;  // fixed vertices are not expanded further
        for(int v : V[x]){
            if(!Ocalc[v]){
                Ocalc[v] = 1;
                todo.push_back(std::make_pair(v, x));
            }
        }
    }
    // A child is always popped after its parent, so walking `order` backwards
    // finishes every child before the parent that consumes its interval.
    for(int i = int(order.size()) - 1; i >= 0; --i){
        const int x = order[i].first;
        if(x <= m){
            D[x] = G[x] = K[x];
        }else{
            std::vector<int>& pts = Mg[x];
            const int cnt = int(pts.size());
            // cnt is even (two endpoints per child): take the two middle order statistics.
            std::nth_element(pts.begin(), pts.begin() + (cnt - 1) / 2, pts.end());
            D[x] = pts[(cnt - 1) / 2];
            std::nth_element(pts.begin(), pts.begin() + cnt / 2, pts.end());
            G[x] = pts[cnt / 2];
        }
        if(x != u){
            const int parent = order[i].second;
            Mg[parent].push_back(D[x]);
            Mg[parent].push_back(G[x]);
        }
    }
}
// Running total of |value difference| over the tree edges, accumulated by sett().
long long ans = 0;
// Non-zero once sett() has visited the vertex.
int Os[500500] = {};
// sett(u, w): walk the tree from u without entering vertices already marked in
// Os[], and give each vertex x the point of [D[x], G[x]] closest to the value
// chosen for its parent (w plays that role for u itself). Each step adds
// |chosen value - parent value| to the global ans.
// An explicit stack replaces recursion so very deep trees cannot overflow the
// native stack; the visiting order does not affect the accumulated sum.
void sett(int u, int w){
    // (vertex, value chosen for its parent) still waiting to be processed.
    std::vector<std::pair<int,int> > todo;
    todo.push_back(std::make_pair(u, w));
    Os[u] = 1;
    while(!todo.empty()){
        const int x = todo.back().first;
        const int up = todo.back().second;
        todo.pop_back();
        // Clamp the parent's value into [D[x], G[x]] (calc() guarantees D <= G).
        const int val = std::min(std::max(up, D[x]), G[x]);
        // Subtract in 64 bits so large opposite-sign values cannot overflow int.
        ans += std::abs(static_cast<long long>(val) - up);
        for(int v : V[x]){
            if(!Os[v]){
                Os[v] = 1;
                todo.push_back(std::make_pair(v, val));
            }
        }
    }
}
int main () {
int n;
scanf("%d%d",&n,&m);
FOR(i,n-1){
int a,b;
scanf("%d%d",&a,&b);
V[a].pb(b);
V[b].pb(a);
}
FORI(i,m) scanf("%d", &K[i]);
calc(m+1);
sett(m+1,0);
printf("%lld\n", ans - D[m+1]);
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | #include <bits/stdc++.h> using namespace std; #define REP(i,a,b) for(int i=(a);i<=(b);++i) #define FORI(i,n) REP(i,1,n) #define FOR(i,n) REP(i,0,int(n)-1) #define mp make_pair #define pb push_back #define pii pair<int,int> #define vi vector<int> #define ll long long #define SZ(x) int((x).size()) #define DBG(v) cerr << #v << " = " << (v) << endl; #define FOREACH(i,t) for (typeof(t.begin()) i=t.begin(); i!=t.end(); i++) #define SORT(X) sort(X.begin(),X.end()) #define fi first #define se second vector<int> V[500500]; int D[500500],G[500500]; vector<int> Mg[500500]; int K[500500]; int m; int Ocalc[500500]; void calc(int u){ Ocalc[u] = 1; if(u <= m){ D[u] = G[u] = K[u]; return; } for(int v : V[u]){ if(!Ocalc[v]){ calc(v); Mg[u].pb(D[v]); Mg[u].pb(G[v]); } } //SORT(Mg[u]); nth_element(Mg[u].begin(), Mg[u].begin()+(SZ(Mg[u])-1)/2,Mg[u].end()); D[u] = Mg[u][(SZ(Mg[u])-1)/2]; nth_element(Mg[u].begin(),Mg[u].begin()+SZ(Mg[u])/2,Mg[u].end()); G[u] = Mg[u][SZ(Mg[u])/2]; } long long ans; int Os[500500]; void sett(int u, int w){ int my = -123; if(D[u] <= w && w <= G[u]) my = w; if(w < D[u]) my = D[u]; if(G[u] < w) my = G[u]; ans += abs(my-w); Os[u] = 1; for(int v : V[u]){ if(Os[v] == 0){ sett(v,my); } } } int main () { int n; scanf("%d%d",&n,&m); FOR(i,n-1){ int a,b; scanf("%d%d",&a,&b); V[a].pb(b); V[b].pb(a); } FORI(i,m) scanf("%d", &K[i]); calc(m+1); sett(m+1,0); printf("%lld\n", ans - D[m+1]); } |
English