#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <list>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <vector>

// Input: n vertices, m, then n-1 edges (1-based endpoints), then m labels.
// Vertices 0..m-1 carry fixed labels; every other vertex may be given any
// integer label. Output: the minimum possible sum over all edges of
// |label(u) - label(v)|.
//
// Method: root the tree at a non-fixed vertex. For the subtree of v, let
// g_v(x) be its minimal internal cost when v is labelled x. g_v is convex and
// piecewise linear with integer slopes. Its minimum value cost[v] is attained
// on an interval [lo[v], hi[v]], and outside that interval |slope| >= 1.
// Hence, seen from v's parent at label x, child c contributes exactly
//     cost[c] + dist(x, [lo[c], hi[c]]).
// A sum of such terms is minimised at any median of the 2k interval endpoints.
// So after sorting the endpoints, the optimal interval for v is
// [e[k-1], e[k]].
namespace {

using LL = long long;

/// Distance from x to the closed interval [lo, hi]; 0 when x lies inside it.
LL distanceToInterval(LL x, LL lo, LL hi) {
    return std::max<LL>(0, lo - x) + std::max<LL>(0, x - hi);
}

/// @brief Minimal total edge cost sum |label(u) - label(v)| over the tree.
/// @param n           number of vertices (0-based ids 0..n-1).
/// @param m           vertices 0..m-1 have fixed labels; the rest are free.
/// @param adj         undirected adjacency lists of the tree.
/// @param fixedLabel  fixedLabel[v] is the label of fixed vertex v (v < m).
/// @return the minimum achievable cost; 0 when the tree has no edges.
/// Assumes adj describes a tree and that fixed vertices are not expanded
/// through (they are treated as leaves, as in the original solution).
LL minimalEdgeCost(int n, int m, const std::vector<std::vector<int>>& adj,
                   const std::vector<LL>& fixedLabel) {
    if (n <= 1) return 0;  // no edges at all

    if (m >= n) {
        // Every vertex is fixed, so nothing can be chosen; just sum the edges.
        LL total = 0;
        for (int v = 0; v < n; ++v)
            for (int u : adj[v])
                if (u > v) total += std::llabs(fixedLabel[v] - fixedLabel[u]);
        return total;
    }

    const int root = n - 1;  // n-1 >= m, so the root is a free vertex

    // Iterative pre-order DFS. A recursive walk can reach depth ~n on a
    // path-shaped tree and overflow the call stack for n ~ 5e5.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        if (v < m) continue;  // fixed vertex: its label is given, do not descend
        for (int u : adj[v]) {
            if (u == parent[v]) continue;
            parent[u] = v;
            pending.push_back(u);
        }
    }

    std::vector<LL> cost(n, 0), lo(n, 0), hi(n, 0);
    // constrained[v] == 0 means v's subtree holds no fixed vertex. Such a
    // subtree can copy its parent's label for free and imposes no constraint.
    std::vector<char> constrained(n, 0);
    std::vector<LL> endpoints;  // reused buffer across vertices

    // Reverse pre-order visits every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        if (v < m) {
            lo[v] = hi[v] = fixedLabel[v];
            constrained[v] = 1;
            continue;
        }

        endpoints.clear();
        for (int c : adj[v]) {
            if (c == parent[v] || !constrained[c]) continue;
            endpoints.push_back(lo[c]);
            endpoints.push_back(hi[c]);
        }
        if (endpoints.empty()) continue;  // free subtree: cost 0, any label

        std::sort(endpoints.begin(), endpoints.end());
        const std::size_t k = endpoints.size() / 2;
        const LL best = endpoints[k - 1];  // any median minimises the sum

        LL total = 0;
        for (int c : adj[v]) {
            if (c == parent[v] || !constrained[c]) continue;
            total += cost[c] + distanceToInterval(best, lo[c], hi[c]);
        }
        cost[v] = total;
        lo[v] = endpoints[k - 1];
        hi[v] = endpoints[k];
        constrained[v] = 1;
    }
    return cost[root];
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2) return 0;

    std::vector<std::vector<int>> adj(n > 0 ? n : 0);
    for (int i = 0; i + 1 < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2) return 1;
        --a;  // input vertices are 1-based
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::vector<long long> fixedLabel(m > 0 ? m : 0);
    for (long long& label : fixedLabel)
        if (std::scanf("%lld", &label) != 1) return 1;

    std::printf("%lld\n", minimalEdgeCost(n, m, adj, fixedLabel));
    return 0;
}