#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>
#include <set>
using namespace std;

// Capacity of every per-node array; node ids are 0-based, so this presumably
// bounds n (NOTE(review): confirm against the problem's limits).
constexpr int MAX = 500000;

// Shorthand used elsewhere in the file for vector::push_back.
#define PB push_back

// n: number of nodes (main reads n-1 edges); k: number of nodes whose value is
// read from input (W[0..k-1]); node k is used as the DFS root.
int n, k;

vector<int> G[MAX];     // adjacency lists of the tree
int W[MAX];             // value of each node (input for 0..k-1, computed otherwise)
pair<int, int> P[MAX];  // per-node (lower, upper) pair produced by dfs1
vector<int> tmp;        // scratch buffer reused by dfs1
pair<int, int> dfs1(int u, int p) {
if (G[u].size() == 1) {
return {W[u], W[u]};
}
for (int v : G[u])
if (v != p) {
P[v] = dfs1(v, u);
}
tmp.clear();
for (int v : G[u])
if (v != p) {
tmp.PB(P[v].first);
tmp.PB(P[v].second);
}
sort(tmp.begin(), tmp.end());
return {tmp[(tmp.size() - 1) / 2], tmp[tmp.size() / 2]};
}
// Top-down pass from u, entered from parent p (p == -1 for the whole tree).
// W[u] must already be set.
//
// Every node v below u that is not a leaf (degree != 1) gets W[v] = its
// parent's value clamped into [P[v].first, P[v].second]. Leaves keep their
// existing W value. Each node's value depends only on its parent's final
// value, so the order in which siblings are handled does not matter.
//
// Uses an explicit worklist instead of recursion so tree height does not
// become call-stack depth.
void dfs2(int u, int p) {
    // (node whose W is final, its parent)
    std::vector<std::pair<int, int>> pending;
    pending.push_back({u, p});
    while (!pending.empty()) {
        const int x = pending.back().first;
        const int px = pending.back().second;
        pending.pop_back();
        for (int v : G[x])
            if (v != px && G[v].size() != 1) {
                if (W[x] >= P[v].second) W[v] = P[v].second;
                else if (W[x] <= P[v].first) W[v] = P[v].first;
                else W[v] = W[x];
                pending.push_back({v, x});
            }
    }
}
int main() {
scanf("%d %d", &n, &k);
for (int i=1;i<n;i++)
{
int a, b;
scanf("%d %d", &a, &b);
a--, b--;
G[a].PB(b);
G[b].PB(a);
}
for (int i=0;i<k;i++)
scanf("%d", W+i);
if(n == 2 && k == 2) {
printf("%d\n", abs(W[0] - W[1]));
return 0;
}
W[k] = dfs1(k, -1).first;
dfs2(k, -1);
long long res = 0;
for (int u=0;u<n;u++)
for (int v : G[u])
if (v < u)
res += abs(W[u] - W[v]);
printf("%lld\n", res);
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 | #include <algorithm> #include <cassert> #include <cstdio> #include <vector> #include <set> using namespace std; const int MAX = 500000; #define PB push_back vector<int> G[MAX]; int W[MAX]; int n, k; pair<int, int> P[MAX]; vector<int> tmp; pair<int, int> dfs1(int u, int p) { if (G[u].size() == 1) { return {W[u], W[u]}; } for (int v : G[u]) if (v != p) { P[v] = dfs1(v, u); } tmp.clear(); for (int v : G[u]) if (v != p) { tmp.PB(P[v].first); tmp.PB(P[v].second); } sort(tmp.begin(), tmp.end()); return {tmp[(tmp.size() - 1) / 2], tmp[tmp.size() / 2]}; } void dfs2(int u, int p) { for (int v : G[u]) if (v != p && G[v].size() != 1) { if (W[u] >= P[v].second) W[v] = P[v].second; else if(W[u] <= P[v].first) W[v] = P[v].first; else W[v] = W[u]; dfs2(v, u); } } int main() { scanf("%d %d", &n, &k); for (int i=1;i<n;i++) { int a, b; scanf("%d %d", &a, &b); a--, b--; G[a].PB(b); G[b].PB(a); } for (int i=0;i<k;i++) scanf("%d", W+i); if(n == 2 && k == 2) { printf("%d\n", abs(W[0] - W[1])); return 0; } W[k] = dfs1(k, -1).first; dfs2(k, -1); long long res = 0; for (int u=0;u<n;u++) for (int v : G[u]) if (v < u) res += abs(W[u] - W[v]); printf("%lld\n", res); return 0; } |
English