#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <set>
#include <utility>
#include <vector>

// Tree-labelling solver.
//
// Input (stdin): "n k", then n-1 undirected edges "a b" (1-based node ids),
// then k integers giving the fixed values of nodes 1..k.
// Output (stdout): the total sum of |value(u) - value(v)| over all edges after
// values have been chosen for the remaining nodes to make that sum minimal.

namespace {

using Adjacency = std::vector<std::vector<int>>;

// Sum of |w[u] - w[v]| over every undirected edge. Differences and the total
// are computed in 64 bits so neither can overflow int.
long long edge_cost(const Adjacency& adj, const std::vector<int>& w) {
    long long total = 0;
    for (int u = 0; u < static_cast<int>(adj.size()); ++u) {
        for (int v : adj[u]) {
            if (v < u) {  // each undirected edge appears twice; count it once
                total += std::llabs(static_cast<long long>(w[u]) - w[v]);
            }
        }
    }
    return total;
}

// Chooses values for the free nodes of the tree rooted at `root` and writes
// them into `w`. Any node other than the root with exactly one neighbour is a
// leaf and keeps its input value.
//
// Bottom-up pass: every subtree root gets a closed interval [lo, hi] formed
// by the median pair of its children's sorted interval endpoints.
// Top-down pass: the root takes lo, and every other free node takes its
// parent's value clamped into its own interval.
//
// Both passes walk an explicit BFS order instead of recursing, so a deep
// (path-shaped) tree cannot exhaust the call stack.
void assign_free_nodes(const Adjacency& adj, int root, std::vector<int>& w) {
    const int n = static_cast<int>(adj.size());
    auto is_leaf = [&](int u) { return u != root && adj[u].size() == 1; };

    // BFS from the root: `order` lists every node after its parent.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(root);
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int u = order[i];
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
        }
    }

    // Bottom-up: reverse BFS order guarantees children are finished first.
    std::vector<std::pair<int, int>> range(n);
    std::vector<int> ends;  // scratch buffer reused across nodes
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        if (is_leaf(u)) {
            range[u] = {w[u], w[u]};
            continue;
        }
        ends.clear();
        for (int v : adj[u]) {
            if (v != parent[u]) {
                ends.push_back(range[v].first);
                ends.push_back(range[v].second);
            }
        }
        if (ends.empty()) {
            // Isolated root (only possible when n == 1): nothing constrains it.
            range[u] = {w[u], w[u]};
            continue;
        }
        std::sort(ends.begin(), ends.end());
        range[u] = {ends[(ends.size() - 1) / 2], ends[ends.size() / 2]};
    }

    // Top-down: forward BFS order guarantees the parent is assigned first.
    w[root] = range[root].first;
    for (int v : order) {
        if (v == root || is_leaf(v)) continue;
        const int from_parent = w[parent[v]];
        // lo <= hi always holds, so this clamps from_parent into [lo, hi].
        w[v] = std::min(std::max(from_parent, range[v].first), range[v].second);
    }
}

}  // namespace

int main() {
    int n = 0;
    int k = 0;
    if (std::scanf("%d %d", &n, &k) != 2 || n < 1 || k < 0 || k > n) return 1;

    Adjacency adj(n);
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d %d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n) {
            return 1;
        }
        --a;  // convert to 0-based ids
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::vector<int> w(n, 0);
    for (int i = 0; i < k; ++i) {
        if (std::scanf("%d", &w[i]) != 1) return 1;
    }

    // Node k (0-based) is the first free node and serves as the root. When
    // k == n every value is given, so there is nothing to assign.
    if (k < n) assign_free_nodes(adj, k, w);

    std::printf("%lld\n", edge_cost(adj, w));
    return 0;
}