// Reads a tree and fixed values for some of its vertices, then prints one
// 64-bit total cost.
//
// Input:  "n k", then n-1 edges "a b" (vertices are 1..n), then k integers
//         a_1..a_k giving the value of vertices 1..k.
// Output: the total cost computed below, as a single integer.
//
// NOTE(review): the task statement is not part of this file. From the
// algorithm it appears to be: every vertex of degree 1 carries a fixed value,
// every other vertex may take any value, and the goal is to minimise the sum
// of |value(u) - value(v)| over all edges. Confirm against the original task.
//
// Method (bottom-up). Each settled vertex c has an interval [lo_c, hi_c] of
// values that are optimal for it. For a non-leaf vertex v with children c,
// choosing x at v costs f(x) = sum_c dist(x, [lo_c, hi_c]). f is convex and
// piecewise linear. Its minimisers are exactly the points between the m-th
// and (m+1)-th smallest of the 2m endpoints (m = number of children). The
// minimum of f is added to the answer, and that minimiser interval becomes
// v's own interval.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

namespace {

// Closed interval [lo, hi] of values that are optimal for a vertex once its
// subtree has been settled. A leaf with value a is the single point [a, a];
// vertices never given a value default to [0, 0].
struct Range {
    int lo = 0;
    int hi = 0;
};

// Distance from x to the closed interval r (0 when x lies inside it).
long long distanceToRange(long long x, const Range& r) {
    if (x < r.lo) return static_cast<long long>(r.lo) - x;
    if (x > r.hi) return x - r.hi;
    return 0;
}

// Minimises f(x) = sum over `children` of distanceToRange(x, child).
// Returns min f and writes the interval of minimisers to *best.
// With no children f is identically 0: returns 0 and leaves *best untouched.
// `scratch` is a caller-owned buffer reused across calls to avoid one
// allocation per vertex.
long long mergeChildren(const std::vector<Range>& children,
                        std::vector<int>& scratch, Range* best) {
    const std::size_t m = children.size();
    if (m == 0) return 0;

    scratch.clear();
    for (const Range& r : children) {
        scratch.push_back(r.lo);
        scratch.push_back(r.hi);
    }
    std::sort(scratch.begin(), scratch.end());

    // f has slope -m to the left of every endpoint and gains +1 at each
    // endpoint. It is therefore flat, and minimal, exactly on the interval
    // [m-th smallest, (m+1)-th smallest] endpoint.
    best->lo = scratch[m - 1];
    best->hi = scratch[m];

    long long cost = 0;
    for (const Range& r : children) cost += distanceToRange(best->lo, r);
    return cost;
}

// Runs the bottom-up merge over the whole tree and returns the total cost.
// On entry `range` holds the leaf intervals; entries of non-leaf vertices are
// overwritten. The traversal is iterative so that path-like trees with ~10^6
// vertices cannot overflow the call stack.
long long solve(int n, const std::vector<std::vector<int>>& adj,
                std::vector<Range>& range) {
    // Root at n, unless n is itself a degree-1 (fixed-value) vertex. In that
    // case use any vertex of higher degree so the root is never a leaf.
    int root = n;
    if (adj[root].size() == 1) {
        for (int v = 1; v <= n; ++v) {
            if (adj[v].size() > 1) {
                root = v;
                break;
            }
        }
    }

    // Pre-order listing. parent[root] stays 0, which is not a vertex id.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int u : adj[v]) {
            if (u != parent[v]) {
                parent[u] = v;
                pending.push_back(u);
            }
        }
    }

    // Reverse pre-order visits every vertex after all of its descendants.
    long long total = 0;
    std::vector<Range> children;
    std::vector<int> scratch;
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        if (adj[v].size() == 1) continue;  // leaf: keeps its fixed interval
        children.clear();
        for (int u : adj[v]) {
            if (u != parent[v]) children.push_back(range[u]);
        }
        total += mergeChildren(children, scratch, &range[v]);
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    int k = 0;
    if (std::scanf("%d%d", &n, &k) != 2 || n < 1) return 1;

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2) return 1;
        if (a < 1 || a > n || b < 1 || b > n) return 1;  // reject bad vertex ids
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Sized for both n and k so out-of-contract input cannot write past the end.
    std::vector<Range> range(std::max(n, k) + 1);
    for (int i = 1; i <= k; ++i) {
        int a = 0;
        if (std::scanf("%d", &a) != 1) return 1;
        range[i] = Range{a, a};
    }

    if (n == 2) {
        // A single edge joining two degree-1 vertices: nothing to merge. The
        // difference is taken in 64 bits so it cannot overflow int.
        const long long cost =
            k < 2 ? 0
                  : std::llabs(static_cast<long long>(range[2].lo) - range[1].lo);
        std::printf("%lld\n", cost);
        return 0;
    }

    std::printf("%lld\n", solve(n, adj, range));
    return 0;
}