#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

// Capacity of the static per-vertex arrays, i.e. the largest supported vertex count.
const int MAX = 500010;

// Per-vertex state shared by Solve() and Count().
struct Node
{
    int l, r;  // closed interval [l, r] computed for the vertex by Solve()
    int w;     // value of the vertex: read from input for v < m, assigned by Solve() otherwise
    int p;     // parent of the vertex in the most recent traversal (set by PreOrder())
    bool col;  // not used anywhere in this file

    // Returns `bound` raised to at least the lower end of this vertex's interval.
    int getPoint(int bound) const { return std::max(l, bound); }
};

int n, m;               // vertex count, and number of vertices whose value is read from input
long long result = 0;   // not used by the code in this file
std::vector<int> G[MAX];  // adjacency lists (0-based vertex ids)
Node V[MAX];

// Returns the vertices reachable from `root` without going back through
// `parent`, in pre-order, and records every visited vertex's parent in V[].p
// (the root gets `parent` itself).
//
// Uses an explicit stack instead of recursion so that a path-shaped tree with
// hundreds of thousands of vertices cannot overflow the call stack.
// Assumes the graph is a tree (no cycles, no repeated edges).
static std::vector<int> PreOrder(int root, int parent)
{
    std::vector<int> order;
    std::vector<int> pending(1, root);
    V[root].p = parent;

    while (!pending.empty())
    {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);

        for (int child : G[v])
            if (child != V[v].p)
            {
                V[child].p = v;
                pending.push_back(child);
            }
    }
    return order;
}

// Computes [l, r] for vertex `v` and assigns w to each of its children.
// Every child of `v` must already have its own [l, r] computed.
static void SolveVertex(int v, int parent)
{
    if (v < m) // leaf
    {
        // The value is fixed by the input, so the interval is a single point.
        V[v].l = V[v].r = V[v].w;
        return;
    }

    // Smallest upper end among the children's intervals. The sentinel is the
    // largest int rather than MAX: MAX bounds the vertex count, not the values.
    int bound = std::numeric_limits<int>::max();
    for (int child : G[v])
        if (child != parent)
            bound = std::min(bound, V[child].r);

    // Fix each child's value and collect the chosen values.
    std::vector<int> med;
    for (int child : G[v])
        if (child != parent)
        {
            V[child].w = V[child].getPoint(bound);
            med.push_back(V[child].w);
        }

    if (med.empty())
    {
        // Degenerate input: a non-leaf vertex without children. Indexing the
        // empty vector below would be undefined behaviour, so pin the vertex
        // at its current value instead.
        V[v].l = V[v].r = V[v].w;
        return;
    }

    // The interval of `v` is spanned by the lower and upper median of the
    // values just assigned to its children.
    std::sort(med.begin(), med.end());
    V[v].l = med[(med.size() - 1) / 2];
    V[v].r = med[med.size() / 2];

    // The root has no parent to assign its value, so it takes the lower median.
    if (parent == -1)
        V[v].w = V[v].l;
}

// Processes the subtree of `v` (entered from `parent`, -1 for the root)
// bottom-up: every vertex is handled after all of its descendants.
void Solve(int v, int parent = -1)
{
    const std::vector<int> order = PreOrder(v, parent);
    for (std::vector<int>::const_reverse_iterator it = order.rbegin(); it != order.rend(); ++it)
        SolveVertex(*it, V[*it].p);
}

// Returns the sum of |w(parent) - w(child)| over every edge in the subtree of
// `v` (entered from `parent`, -1 for the root).
long long Count(int v, int parent = -1)
{
    long long total = 0;
    const std::vector<int> order = PreOrder(v, parent);

    for (int u : order)
        for (int child : G[u])
            if (child != V[u].p)
            {
                // Widen before subtracting so the difference cannot overflow int.
                const long long diff = static_cast<long long>(V[u].w) - V[child].w;
                total += std::abs(diff);
            }
    return total;
}

// Input:  n m, then n-1 edges "a b" (1-based), then m values for vertices 1..m.
// Output: Count() for the tree rooted at vertex n, after Solve() has run.
// Returns 1 without printing a result when the input is malformed.
int main(int /*argc*/, char ** /*argv*/)
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // n and m index fixed-size global arrays, so reject out-of-range sizes.
    if (!(std::cin >> n >> m) || n < 1 || n > MAX || m < 0 || m > n)
    {
        std::cerr << "invalid n/m\n";
        return 1;
    }

    for (int i = 1; i < n; ++i)
    {
        int a, b;
        if (!(std::cin >> a >> b))
        {
            std::cerr << "invalid edge\n";
            return 1;
        }
        --a;
        --b;
        if (a < 0 || a >= n || b < 0 || b >= n)
        {
            std::cerr << "edge endpoint out of range\n";
            return 1;
        }
        G[a].push_back(b);
        G[b].push_back(a);
    }

    for (int i = 0; i < m; ++i)
        if (!(std::cin >> V[i].w))
        {
            std::cerr << "invalid vertex value\n";
            return 1;
        }

    Solve(n - 1);
    std::cout << Count(n - 1) << '\n';
    return 0;
}