// Tree labeling: vertices 1..M carry fixed values read from input; every other
// vertex may be given any value. The program prints the minimum possible sum
// of |value(u) - value(v)| over all tree edges.
//
// Approach: root the tree at vertex N (required to be free, see main) and work
// bottom-up. For each vertex v let f_v(x) be the minimum cost of v's subtree
// when v takes value x; optimal[v] = [lo, hi] is the set of minimizers of f_v.
// Seen from v's parent, child c contributes  min f_c + dist(x, [lo_c, hi_c]),
// because f_c has integer slopes and therefore |slope| >= 1 outside its
// minimizer interval. Summing dist(x, I_c) over the children is minimized
// exactly between the two middle values of all 2k children's endpoints (a
// median), and the cost paid at v is the children's total distance to any
// point of that interval.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

using LL = long long;

// debug(a, b) prints e.g. "a=1, b=2" to stderr when compiled with -DLOCAL;
// otherwise it expands to a plain parenthesized expression.
template <typename TH>
void debug_vars(const char* data, TH head) {
    std::cerr << data << "=" << head << "\n";
}

template <typename TH, typename... TA>
void debug_vars(const char* data, TH head, TA... tail) {
    while (*data != ',') std::cerr << *data++;
    std::cerr << "=" << head << ",";
    debug_vars(data + 1, tail...);
}

#ifdef LOCAL
#define debug(...) debug_vars(#__VA_ARGS__, __VA_ARGS__)
#else
#define debug(...) (__VA_ARGS__)
#endif

/////////////////////////////////////////////////////////

const int MaxN = 500005;

std::vector<int> adj[MaxN];  // adjacency lists; vertices are numbered 1..N
int value[MaxN];             // value[i] for the fixed vertices i = 1..M
bool visited[MaxN];          // set once dfs() discovers a vertex
int N, M;
// optimal[v] = [lo, hi]: every value in this interval minimizes the cost of
// v's subtree (for a fixed vertex it is the single point value[v]).
std::pair<int, int> optimal[MaxN];
// increase[v] = minimum total |difference| over the edges inside v's subtree.
LL increase[MaxN];

// Reads N, M, the N-1 undirected edges, then the M fixed values.
void input() {
    scanf("%d%d", &N, &M);
    for (int i = 0; i < N - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= M; i++) scanf("%d", &value[i]);
}

// Distance from x to the closed interval [iv.first, iv.second] (first <= second).
// Computed in long long so that differences of large ints cannot overflow.
static LL distToInterval(LL x, const std::pair<int, int>& iv) {
    if (x < iv.first) return iv.first - x;
    if (x > iv.second) return x - iv.second;
    return 0;
}

// Fills optimal[] and increase[] for every vertex reachable from `root`.
// Uses an explicit stack instead of recursion so that a path-shaped tree with
// up to MaxN vertices cannot overflow the call stack.
void dfs(int root) {
    // Pass 1: discover vertices top-down, recording each vertex's parent.
    std::vector<int> parent(N + 1, 0);  // 0 = no parent (vertices are 1-based)
    std::vector<int> order;
    order.reserve(N);
    std::vector<int> pending{root};
    visited[root] = true;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        if (v <= M) continue;  // fixed vertices are treated as leaves: not expanded
        for (int s : adj[v]) {
            if (visited[s]) continue;
            visited[s] = true;
            parent[s] = v;
            pending.push_back(s);
        }
    }

    // Pass 2: each vertex appears after its parent in `order`, so walking the
    // order backwards finishes all children before their parent.
    std::vector<std::pair<int, int>> children;  // reused buffers
    std::vector<int> endpoints;
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        increase[v] = 0;
        if (v <= M) {
            optimal[v] = {value[v], value[v]};
            continue;
        }

        children.clear();
        for (int s : adj[v]) {
            if (parent[s] != v) continue;
            children.push_back(optimal[s]);
            increase[v] += increase[s];
        }
        assert(!children.empty() && "a free vertex must have at least one child");

        endpoints.clear();
        for (const auto& iv : children) {
            endpoints.push_back(iv.first);
            endpoints.push_back(iv.second);
        }
        std::sort(endpoints.begin(), endpoints.end());
        // The two middle values of the 2k sorted endpoints bound the optimum.
        const std::size_t k = endpoints.size() / 2;
        optimal[v] = {endpoints[k - 1], endpoints[k]};

        // Cost of the edges to the children when v sits anywhere in optimal[v].
        for (const auto& iv : children) {
            increase[v] += distToInterval(optimal[v].first, iv);
        }
    }
}

int main() {
    input();
    if (N == 2) {
        // A single edge; both endpoints are presumably the fixed vertices
        // (M == 2). dfs(N) would stop at N here because N <= M.
        printf("%lld\n", std::llabs(static_cast<LL>(value[1]) - value[2]));
        return 0;
    }
    assert(N > M);  // the root must be a free vertex
    dfs(N);
    printf("%lld\n", increase[N]);
}