#include <bits/stdc++.h>
using namespace std;
// ---- Contest shorthands -----------------------------------------------------
// PB            : member name push_back
// FORE(it, c)   : iterate `it` from c.begin() to c.end() (GNU __typeof)
// SZ(c)         : container size as a signed int
// REP(i, n)     : i = 0 .. n-1, bound evaluated once
// FOR(i, a, b)  : i = a .. b inclusive, ascending, bound evaluated once
// FORD(i, a, b) : i = a .. b inclusive, descending, bound evaluated once
#define PB push_back
#define FORE(i, t) for (__typeof(t.begin()) i = t.begin(); i != t.end(); ++i)
#define SZ(x) int((x).size())
#define REP(i, n) for (int i = 0, _ = (n); i < _; ++i)
#define FOR(i, a, b) for (int i = (a), _ = (b); i <= _; ++i)
#define FORD(i, a, b) for (int i = (a), _ = (b); i >= _; --i)

// Short type names used throughout the solution.
typedef long long ll;
typedef std::vector<int> vi;
typedef std::pair<int, int> pii;

const int INF = 1000000009;     // 1e9 + 9
const int MAX_N = 500000 + 3;   // capacity of the per-node global arrays
//struct Range {
//// ll cost;
// int a, b;
//};
// Absolute difference |a - b| of two 64-bit integers.
ll diff(ll a, ll b) {
    return a < b ? b - a : a - b;
}
// n: number of nodes; m: nodes 1..m are the ones whose value is read from input.
int n;
int m;
// r[i]: input value of node i for i <= m; one() stores -1 for the other nodes.
int r[MAX_N];
// children[x]: how many not-yet-visited neighbours f() found when expanding x.
int children[MAX_N];
// Adjacency lists; every edge is stored in both endpoints' lists (see add_edge).
vi adj[MAX_N];
// Visited flags shared by f() and solve(); one() clears them before each pass.
bool vis[MAX_N];
// ranges[x]: [low, high] interval of optimal values for node x, filled by f().
pii ranges[MAX_N];
// Records the undirected edge a-b: each endpoint is appended to the other's
// adjacency list.
void add_edge(int a, int b) {
    adj[b].push_back(a);
    adj[a].push_back(b);
}
// Bottom-up pass: fills ranges[] for every node reachable from x (through
// not-yet-visited nodes) and returns ranges[x].
//  - A node u <= m has a fixed value, so ranges[u] = [r[u], r[u]]; such a node
//    is never expanded further.
//  - Any other node u with k children collects the 2k endpoints of its
//    children's intervals; after sorting, the middle pair (indices k-1 and k)
//    bounds the interval of values minimising the sum of distances to the
//    children's intervals.
// Side effects: marks every reached node in vis[] and increments children[u]
// once per newly discovered neighbour of each expanded node u.
// Implemented with an explicit stack instead of recursion: a deep tree (e.g.
// a path of ~5e5 nodes) would otherwise need one call frame per level and can
// exceed a typical default stack limit.
// Precondition: every expanded node (u > m) has at least one undiscovered
// neighbour; otherwise the median lookup would index an empty vector.
pii f(int x) {
    vi order;              // nodes in discovery order (parents before children)
    vi todo(1, x);         // explicit DFS stack
    vi par(MAX_N, -1);     // par[v] = node that discovered v
    vis[x] = true;
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        order.PB(u);
        if (u <= m) {
            continue;      // fixed node: not expanded
        }
        FORE(vt, adj[u]) {
            int v = *vt;
            if (vis[v]) {
                continue;
            }
            vis[v] = true;
            par[v] = u;
            ++children[u];
            todo.PB(v);
        }
    }
    // Every child appears after its parent in `order`, so sweeping it
    // backwards finalises all children's intervals before their parent's.
    vi points;
    FORD(i, SZ(order) - 1, 0) {
        int u = order[i];
        if (u <= m) {
            ranges[u] = pii(r[u], r[u]);
            continue;
        }
        points.clear();    // reuse the buffer's capacity across nodes
        FORE(vt, adj[u]) {
            int v = *vt;
            if (par[v] != u) {
                continue;  // parent (or node not discovered from u)
            }
            points.PB(ranges[v].first);
            points.PB(ranges[v].second);
        }
        sort(points.begin(), points.end());
        ranges[u] = pii(points[SZ(points) / 2 - 1], points[SZ(points) / 2]);
    }
    return ranges[x];
}
// Top-down pass: node x receives the value `val` proposed by its parent, and
// moves it to the nearest point of ranges[x] (clamping). It pays the distance
// moved, then proposes the resulting value to each not-yet-visited neighbour.
// Returns the total distance paid over all nodes reached from x.
// Assumes f() has filled ranges[] for every reached node and that vis[] was
// cleared beforehand (one() does both).
// Uses an explicit stack of (node, proposed value) pairs instead of
// recursion, so very deep trees cannot overflow the call stack.
ll solve(int x, int val) {
    ll result = 0;
    vector<pii> todo(1, pii(x, val));
    vis[x] = true;
    while (!todo.empty()) {
        int u = todo.back().first;
        int cur = todo.back().second;
        todo.pop_back();
        if (cur > ranges[u].second) {
            result += diff(cur, ranges[u].second);
            cur = ranges[u].second;
        }
        if (cur < ranges[u].first) {
            result += diff(cur, ranges[u].first);
            cur = ranges[u].first;
        }
        FORE(vt, adj[u]) {
            int v = *vt;
            if (vis[v]) {
                continue;
            }
            vis[v] = true;   // mark on push so no node is queued twice
            todo.PB(pii(v, cur));
        }
    }
    return result;
}
void inline one() {
scanf("%d%d", &n, &m);
REP (i, n - 1) {
int a, b;
scanf("%d%d", &a, &b);
add_edge(a, b);
}
FOR (i, 1, n) {
if (i <= m) {
int x;
scanf("%d", &x);
r[i] = x;
} else {
r[i] = -1;
}
vis[i] = false;
children[i] = 0;
}
if (n == m) {
ll result = diff(r[1], r[2]);
printf("%lld\n", result);
return;
}
int root = n;
pii range = f(root);
// printf("%d..%d\n", range.first, range.second);
FOR (i, 1, n) {
vis[i] = false;
}
ll result = solve(root, range.first);
printf("%lld\n", result);
}
int main() {
    // Input holds a single test case. For a multi-test format, read the count
    // first: int z; scanf("%d", &z); while (z--) one();
    one();
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | #include <bits/stdc++.h> using namespace std; #define PB push_back #define FORE(i, t) for(__typeof(t.begin())i=t.begin();i!=t.end();++i) #define SZ(x) int((x).size()) #define REP(i, n) for(int i=0,_=(n);i<_;++i) #define FOR(i, a, b) for(int i=(a),_=(b);i<=_;++i) #define FORD(i, a, b) for(int i=(a),_=(b);i>=_;--i) typedef long long ll; typedef vector<int> vi; typedef pair<int, int> pii; const int INF = 1e9 + 9; const int MAX_N = 500003; //struct Range { //// ll cost; // int a, b; //}; ll diff(ll a, ll b) { ll result = a - b; if (result < 0) { return -result; } return result; } int n, m; int r[MAX_N]; int children[MAX_N]; vi adj[MAX_N]; bool vis[MAX_N]; pii ranges[MAX_N]; void add_edge(int a, int b) { adj[a].PB(b); adj[b].PB(a); } pii f(int x) { vis[x] = true; if (x <= m) { return ranges[x] = pii(r[x], r[x]); } // vector <pii> ranges; vi points; FORE(yt, adj[x]) { int y = *yt; if (vis[y]) { continue; } ++children[x]; pii range = f(y); points.PB(range.first); points.PB(range.second); } sort(points.begin(), points.end()); return ranges[x] = pii(points[SZ(points) / 2 - 1], points[SZ(points) / 2]); } ll solve(int x, int val) { // printf("x=%d val=%d\n", x, val); ll result = 0; if (val > ranges[x].second) { result += diff(val, ranges[x].second);// * (ll) children[x]; val = ranges[x].second; } if (val < ranges[x].first) { result += diff(val, ranges[x].first);// * (ll) children[x]; val = ranges[x].first; } // printf("x=%d start=%lld new_val=%d\n", x, result, val); vis[x] = true; FORE(yt, adj[x]) { int y = *yt; if (vis[y]) { continue; } result += solve(y, val); } // printf("x=%d 
end=%lld\n", x, result); return result; } void inline one() { scanf("%d%d", &n, &m); REP (i, n - 1) { int a, b; scanf("%d%d", &a, &b); add_edge(a, b); } FOR (i, 1, n) { if (i <= m) { int x; scanf("%d", &x); r[i] = x; } else { r[i] = -1; } vis[i] = false; children[i] = 0; } if (n == m) { ll result = diff(r[1], r[2]); printf("%lld\n", result); return; } int root = n; pii range = f(root); // printf("%d..%d\n", range.first, range.second); FOR (i, 1, n) { vis[i] = false; } ll result = solve(root, range.first); printf("%lld\n", result); } int main() { //int z; scanf("%d", &z); while(z--) one(); } |
English