#include<bits/stdc++.h>
using namespace std;
// Short-hand helpers used throughout the solution.
#define REP(i, n) for(int i = 0; i < (n); i++)
#define SIZE(s) ((int) (s).size())
// Read one int / one C string from stdin and abort on failure.
// scanf is evaluated OUTSIDE assert(): with NDEBUG defined, assert(expr) does not
// evaluate expr, so the previous form silently skipped reading input entirely.
#define SCAND(x) (scanf("%d", x) == 1 ? (void) 0 : assert(!"SCAND: scanf failed"))
#define SCANS(x) (scanf("%s", x) == 1 ? (void) 0 : assert(!"SCANS: scanf failed"))
#define ALL(s) s.begin(), s.end()
#define MP make_pair
#define ST first
#define ND second
using LL = long long;
using PII = pair<int, int>;
// Capacity of every per-vertex array below (maximum supported vertex count).
const int N = 5e5;
// n = number of vertices; vertices 0..m-1 carry fixed values read from input.
int n, m;
// Adjacency lists of the (undirected) tree.
vector<int> g[N];
// Value of each vertex: input for vertices < m, computed by dfs2 for the rest.
int r[N];
// Interval [lo, hi] of optimal values for each vertex, filled in by dfs1.
PII ranges[N];
void dfs1(int u, int p = -1){
if(u < m){
ranges[u] = MP(r[u], r[u]);
return;
}
vector<PII> events;
for(int v : g[u]) if(v != p){
dfs1(v, u);
events.emplace_back(ranges[v].ST, -1);
events.emplace_back(ranges[v].ND, +1);
}
int k = events.size() / 2;
nth_element(events.begin(), events.begin() + k - 1, events.end());
nth_element(events.begin() + k, events.begin() + k, events.end());
ranges[u] = MP(events[k-1].ST, events[k].ST);
}
void dfs2(int u, int p = -1){
if(u < m) return;
if(p == -1){
r[u] = ranges[u].ST;
} else {
r[u] = r[p];
if(r[u] < ranges[u].ST) r[u] = ranges[u].ST;
if(r[u] > ranges[u].ND) r[u] = ranges[u].ND;
}
for(int v : g[u]) if(v != p){
dfs2(v, u);
}
}
// Reads the tree (n vertices, n-1 edges, 1-based) and the fixed values of
// vertices 1..m, fills in the remaining values, and prints the minimal
// total sum of |r_u - r_v| over all edges.
int main(){
    // scanf is kept outside assert(): under NDEBUG assert does not evaluate its
    // argument, which previously meant no input was read at all.
    if(scanf("%d%d", &n, &m) != 2) return 1;
    REP(i, n-1){
        int u, v;
        if(scanf("%d%d", &u, &v) != 2) return 1;
        u--; v--;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    REP(i, m) if(scanf("%d", &r[i]) != 1) return 1;
    // If every vertex is fixed (m == n), there is nothing to choose. The
    // generic edge sum below handles that case, including n == 1.
    // Otherwise root the tree at the first non-fixed vertex m.
    if(m < n){
        dfs1(m);
        dfs2(m);
    }
    LL res = 0;
    REP(u, n){
        // Count each undirected edge once; subtract in 64 bits so that large
        // values of opposite sign cannot overflow int.
        for(int v : g[u]) if(u < v) res += abs((LL) r[u] - r[v]);
    }
    printf("%lld\n", res);
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | #include<bits/stdc++.h> using namespace std; #define REP(i, n) for(int i = 0; i < (n); i++) #define SIZE(s) ((int) (s).size()) #define SCAND(x) assert(scanf("%d", x) == 1) #define SCANS(x) assert(scanf("%s", x) == 1) #define ALL(s) s.begin(), s.end() #define MP make_pair #define ST first #define ND second using LL = long long; using PII = pair<int, int>; const int N = 5e5; int n, m; vector<int> g[N]; int r[N]; PII ranges[N]; void dfs1(int u, int p = -1){ if(u < m){ ranges[u] = MP(r[u], r[u]); return; } vector<PII> events; for(int v : g[u]) if(v != p){ dfs1(v, u); events.emplace_back(ranges[v].ST, -1); events.emplace_back(ranges[v].ND, +1); } int k = events.size() / 2; nth_element(events.begin(), events.begin() + k - 1, events.end()); nth_element(events.begin() + k, events.begin() + k, events.end()); ranges[u] = MP(events[k-1].ST, events[k].ST); } void dfs2(int u, int p = -1){ if(u < m) return; if(p == -1){ r[u] = ranges[u].ST; } else { r[u] = r[p]; if(r[u] < ranges[u].ST) r[u] = ranges[u].ST; if(r[u] > ranges[u].ND) r[u] = ranges[u].ND; } for(int v : g[u]) if(v != p){ dfs2(v, u); } } int main(){ assert(scanf("%d%d", &n, &m) == 2); REP(i, n-1){ int u, v; assert(scanf("%d%d", &u, &v) == 2); u--; v--; g[u].push_back(v); g[v].push_back(u); } REP(i, m) assert(scanf("%d", &r[i]) == 1); if(m == n){ printf("%d\n", abs(r[0] - r[1])); return 0; } dfs1(m); dfs2(m); LL res = 0; REP(u, n){ for(int v : g[u]) if(u < v) res += abs(r[u] - r[v]); } printf("%lld\n", res); } |
English