#include <bits/stdc++.h> using namespace std; #define PB push_back #define FORE(i, t) for(__typeof(t.begin())i=t.begin();i!=t.end();++i) #define SZ(x) int((x).size()) #define REP(i, n) for(int i=0,_=(n);i<_;++i) #define FOR(i, a, b) for(int i=(a),_=(b);i<=_;++i) #define FORD(i, a, b) for(int i=(a),_=(b);i>=_;--i) typedef long long ll; typedef vector<int> vi; typedef pair<int, int> pii; const int INF = 1e9 + 9; const int MAX_N = 500003; //struct Range { //// ll cost; // int a, b; //}; ll diff(ll a, ll b) { ll result = a - b; if (result < 0) { return -result; } return result; } int n, m; int r[MAX_N]; int children[MAX_N]; vi adj[MAX_N]; bool vis[MAX_N]; pii ranges[MAX_N]; void add_edge(int a, int b) { adj[a].PB(b); adj[b].PB(a); } pii f(int x) { vis[x] = true; if (x <= m) { return ranges[x] = pii(r[x], r[x]); } // vector <pii> ranges; vi points; FORE(yt, adj[x]) { int y = *yt; if (vis[y]) { continue; } ++children[x]; pii range = f(y); points.PB(range.first); points.PB(range.second); } sort(points.begin(), points.end()); return ranges[x] = pii(points[SZ(points) / 2 - 1], points[SZ(points) / 2]); } ll solve(int x, int val) { // printf("x=%d val=%d\n", x, val); ll result = 0; if (val > ranges[x].second) { result += diff(val, ranges[x].second);// * (ll) children[x]; val = ranges[x].second; } if (val < ranges[x].first) { result += diff(val, ranges[x].first);// * (ll) children[x]; val = ranges[x].first; } // printf("x=%d start=%lld new_val=%d\n", x, result, val); vis[x] = true; FORE(yt, adj[x]) { int y = *yt; if (vis[y]) { continue; } result += solve(y, val); } // printf("x=%d end=%lld\n", x, result); return result; } void inline one() { scanf("%d%d", &n, &m); REP (i, n - 1) { int a, b; scanf("%d%d", &a, &b); add_edge(a, b); } FOR (i, 1, n) { if (i <= m) { int x; scanf("%d", &x); r[i] = x; } else { r[i] = -1; } vis[i] = false; children[i] = 0; } if (n == m) { ll result = diff(r[1], r[2]); printf("%lld\n", result); return; } int root = n; pii range = f(root); // printf("%d..%d\n", range.first, range.second); FOR (i, 1, n) { vis[i] = false; } ll result = solve(root, range.first); printf("%lld\n", result); } int main() { //int z; scanf("%d", &z); while(z--) one(); }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | #include <bits/stdc++.h> using namespace std; #define PB push_back #define FORE(i, t) for(__typeof(t.begin())i=t.begin();i!=t.end();++i) #define SZ(x) int((x).size()) #define REP(i, n) for(int i=0,_=(n);i<_;++i) #define FOR(i, a, b) for(int i=(a),_=(b);i<=_;++i) #define FORD(i, a, b) for(int i=(a),_=(b);i>=_;--i) typedef long long ll; typedef vector<int> vi; typedef pair<int, int> pii; const int INF = 1e9 + 9; const int MAX_N = 500003; //struct Range { //// ll cost; // int a, b; //}; ll diff(ll a, ll b) { ll result = a - b; if (result < 0) { return -result; } return result; } int n, m; int r[MAX_N]; int children[MAX_N]; vi adj[MAX_N]; bool vis[MAX_N]; pii ranges[MAX_N]; void add_edge(int a, int b) { adj[a].PB(b); adj[b].PB(a); } pii f(int x) { vis[x] = true; if (x <= m) { return ranges[x] = pii(r[x], r[x]); } // vector <pii> ranges; vi points; FORE(yt, adj[x]) { int y = *yt; if (vis[y]) { continue; } ++children[x]; pii range = f(y); points.PB(range.first); points.PB(range.second); } sort(points.begin(), points.end()); return ranges[x] = pii(points[SZ(points) / 2 - 1], points[SZ(points) / 2]); } ll solve(int x, int val) { // printf("x=%d val=%d\n", x, val); ll result = 0; if (val > ranges[x].second) { result += diff(val, ranges[x].second);// * (ll) children[x]; val = ranges[x].second; } if (val < ranges[x].first) { result += diff(val, ranges[x].first);// * (ll) children[x]; val = ranges[x].first; } // printf("x=%d start=%lld new_val=%d\n", x, result, val); vis[x] = true; FORE(yt, adj[x]) { int y = *yt; if (vis[y]) { continue; } result += solve(y, val); } // printf("x=%d end=%lld\n", x, result); return result; } void inline one() { scanf("%d%d", &n, &m); REP (i, n - 1) { int a, b; scanf("%d%d", &a, &b); add_edge(a, b); } FOR (i, 1, n) { if (i <= m) { int x; scanf("%d", &x); r[i] = x; } else { r[i] = -1; } vis[i] = false; children[i] = 0; } if (n == m) { ll result = diff(r[1], r[2]); printf("%lld\n", result); return; } int root = n; pii range = f(root); // printf("%d..%d\n", range.first, range.second); FOR (i, 1, n) { vis[i] = false; } ll result = solve(root, range.first); printf("%lld\n", result); } int main() { //int z; scanf("%d", &z); while(z--) one(); } |