#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <list>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
using namespace std;

// Short type aliases used throughout the file.
typedef long long LL;
typedef unsigned long long ULL;
typedef long double LD;
typedef vector<int> VI;
typedef pair<int,int> PII;

// Loop and container shorthands.
// REP/FOR count upwards over [0,n) / [a,b); FORD counts downwards over [b,a).
// FOREACH walks any container with begin()/end(); it relies on the GCC
// `__typeof` extension to name the iterator type.
#define REP(i,n) for(int i=0;i<(n);++i)
#define SIZE(c) ((int)((c).size()))
#define FOR(i,a,b) for (int i=(a); i<(b); ++i)
#define FOREACH(i,x) for (__typeof((x).begin()) i=(x).begin(); i!=(x).end(); ++i)
#define FORD(i,a,b) for (int i=(a)-1; i>=(b); --i)
#define ALL(v) (v).begin(), (v).end()
#define pb push_back
#define mp make_pair
#define st first
#define nd second

// Capacity of every per-node array below.
const int MAXN = 500005;

// N is the number of nodes (0-based ids). Nodes with id < M are treated as
// leaves whose value R[id] is fixed; nodes with id >= M are internal.
int N, M;
// Adjacency lists of the undirected tree. A vector keeps each list contiguous;
// only push_back and forward iteration are used on it.
vector<int> adj[MAXN];
// Fixed values of the leaves (only indices < M are meaningful).
int R[MAXN];
// mini[v]: cost accumulated for the subtree hanging below v, as computed by go().
LL mini[MAXN];
// range[v]: closed interval [first, second] of values go() selected for v.
PII range[MAXN];

// Fills mini[] and range[] for v and every node below it, where p is v's
// parent (-1 when v is the root of the traversal).
//
// For a leaf (id < M): mini = 0 and range = [R[id], R[id]].
// For an internal node: the interval endpoints of all children are sorted and
// the two middle ones become range[v]; mini[v] is the sum, over the children,
// of the child's own mini plus the distance from range[v].first to the child's
// interval.
// An internal node without children puts no constraint on its value, so it
// gets mini = 0 and range = [INT_MIN, INT_MAX]; such an interval costs nothing
// and leaves the middle endpoints of its siblings unchanged in the parent.
//
// The traversal is iterative so that the depth of the tree is not limited by
// the size of the call stack.
void go(int v, int p = -1) {
    // Phase 1: list the subtree so that every node appears after its parent.
    vector<PII> order;                  // (node, parent) pairs
    order.pb(mp(v, p));
    for (size_t head = 0; head < order.size(); ++head) {
        const int cur = order[head].st;
        const int par = order[head].nd;
        if (cur < M) continue;          // leaves are never expanded
        FOREACH(it, adj[cur]) {
            if (*it != par) order.pb(mp(*it, cur));
        }
    }

    // Phase 2: walk the list backwards, so children are finished before
    // their parent is evaluated.
    vector<int> pts;                    // reused buffer of interval endpoints
    for (size_t idx = order.size(); idx-- > 0; ) {
        const int cur = order[idx].st;
        const int par = order[idx].nd;

        if (cur < M) {
            mini[cur] = 0;
            range[cur].st = range[cur].nd = R[cur];
            continue;
        }

        pts.clear();
        FOREACH(it, adj[cur]) {
            if (*it == par) continue;
            pts.pb(range[*it].st);
            pts.pb(range[*it].nd);
        }

        if (pts.empty()) {
            // No children: any value is free of charge.
            mini[cur] = 0;
            range[cur] = PII(INT_MIN, INT_MAX);
            continue;
        }

        sort(pts.begin(), pts.end());
        const size_t K = pts.size() / 2;    // pts holds 2 * (#children) >= 2 entries
        const int val = pts[K-1];

        LL total = 0;
        FOREACH(it, adj[cur]) {
            if (*it == par) continue;
            // 64-bit differences: the endpoints may lie far apart.
            const LL below = (LL)range[*it].st - val;
            const LL above = (LL)val - range[*it].nd;
            total += mini[*it] + max(below, LL(0)) + max(above, LL(0));
        }
        mini[cur] = total;
        range[cur].st = val;
        range[cur].nd = pts[K];
    }
}
int main() {
scanf("%d%d", &N, &M);
REP(i,N-1) {
int a, b;
scanf("%d%d", &a, &b);
--a, --b;
adj[a].pb(b);
adj[b].pb(a);
}
REP(i,M) {
scanf("%d", &R[i]);
}
if (M == N) {
printf("%d\n", abs(R[1] - R[0]));
return 0;
}
go(N-1);
printf("%lld\n", mini[N-1]);
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | #include <algorithm> #include <cstdio> #include <cstdlib> #include <list> #include <map> #include <queue> #include <set> #include <stack> #include <vector> #include <cmath> #include <cstring> #include <string> #include <iostream> #include <complex> #include <sstream> #include <cassert> using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef long double LD; typedef vector<int> VI; typedef pair<int,int> PII; #define REP(i,n) for(int i=0;i<(n);++i) #define SIZE(c) ((int)((c).size())) #define FOR(i,a,b) for (int i=(a); i<(b); ++i) #define FOREACH(i,x) for (__typeof((x).begin()) i=(x).begin(); i!=(x).end(); ++i) #define FORD(i,a,b) for (int i=(a)-1; i>=(b); --i) #define ALL(v) (v).begin(), (v).end() #define pb push_back #define mp make_pair #define st first #define nd second int N, M; list<int> adj[500005]; int R[500005]; LL mini[500005]; PII range[500005]; void go(int v, int p = -1) { if (v < M) { mini[v] = 0; range[v].st = range[v].nd = R[v]; return; } FOREACH(it, adj[v]) { if (*it == p) continue; go(*it, v); } vector<int> pts; FOREACH(it, adj[v]) { if (*it == p) continue; pts.pb(range[*it].st); pts.pb(range[*it].nd); } sort(pts.begin(), pts.end()); int K = pts.size() / 2; int val = pts[K-1]; mini[v] = 0; FOREACH(it, adj[v]) { if (*it == p) continue; mini[v] += mini[*it] + max(0, range[*it].st - val) + max(0, val - range[*it].nd); } range[v].st = pts[K-1]; range[v].nd = pts[K]; } int main() { scanf("%d%d", &N, &M); REP(i,N-1) { int a, b; scanf("%d%d", &a, &b); --a, --b; adj[a].pb(b); adj[b].pb(a); } REP(i,M) { scanf("%d", &R[i]); } if (M == N) { printf("%d\n", abs(R[1] - R[0])); return 0; } go(N-1); printf("%lld\n", mini[N-1]); } |
English