#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>
// Capacity of the global per-node arrays G and V; the node count n must not exceed it.
constexpr int MAX = 500010;
// Per-node state filled in by Solve and read by Count.
struct Node
{
    int l{};    // left end of the node's [l, r] range computed in Solve: the lower median
                // of its children's values, or the node's own value for a leaf
    int r{};    // right end of that range: the upper median of the children's values
    int w{};    // node value: the input value for a leaf, the chosen value otherwise
    int p{};    // NOTE(review): not read or written by the visible code
    bool col{}; // NOTE(review): not read or written by the visible code

    // Returns the larger of this node's left end `l` and `lo`.
    // The parameter is named `lo` so it no longer shadows the member `l`.
    int getPoint(int lo) const
    {
        return std::max(l, lo);
    }
};
int n;                   // number of nodes, read in main
int m;                   // nodes 0..m-1 are the leaves whose values are given in the input
std::vector<int> G[MAX]; // adjacency lists, 0-based node ids
Node V[MAX];             // per-node range and value, see Node
// Bottom-up pass over the subtree rooted at `v`; `parent` is the neighbour of `v`
// that is excluded (-1 when `v` is the root of the whole tree).
//
// Nodes with index < m are leaves: their range [l, r] collapses to their input
// value V[u].w. For an internal node u, `lo` is the smallest right end V[c].r
// among its children. Each child's value is set to max(V[c].l, lo), and u's range
// becomes the lower/upper median of those child values. The start node, when
// called with parent == -1, takes the left end of its range as its own value.
//
// Implemented iteratively rather than recursively, so path-like trees with
// hundreds of thousands of nodes cannot overflow the call stack. The nodes are
// collected in BFS order and processed in reverse, so every child is finished
// before its parent, exactly as in a post-order recursion.
void Solve(int v, int parent = -1)
{
    // order[k] is a node of the subtree, and from[k] is its parent.
    // Each node appears after its parent.
    std::vector<int> order{v};
    std::vector<int> from{parent};
    for (std::size_t k = 0; k < order.size(); ++k)
    {
        const int u = order[k];
        const int pu = from[k];
        for (int c : G[u])
            if (c != pu)
            {
                order.push_back(c);
                from.push_back(u);
            }
    }

    std::vector<int> med; // child values of the current node; buffer reused across nodes
    for (std::size_t k = order.size(); k-- > 0;)
    {
        const int u = order[k];
        const int pu = from[k];

        // Smallest right end among the children. The "no child yet" sentinel is
        // the true int maximum: MAX is an array capacity, not a bound on values,
        // and using it here would wrongly cap `lo` when all children's r exceed it.
        int lo = std::numeric_limits<int>::max();
        for (int c : G[u])
            if (c != pu)
                lo = std::min(lo, V[c].r);

        if (u < m) // leaf: its range is its single input value
        {
            V[u].l = V[u].r = V[u].w;
            continue;
        }

        med.clear();
        for (int c : G[u])
            if (c != pu)
            {
                V[c].w = V[c].getPoint(lo);
                med.push_back(V[c].w);
            }

        if (med.empty())
        {
            // An internal node with no children (malformed input) has nothing
            // constraining it. Collapse its range to its current value instead of
            // indexing into an empty vector.
            V[u].l = V[u].r = V[u].w;
            continue;
        }

        std::sort(med.begin(), med.end());
        V[u].l = med[(med.size() - 1) / 2]; // lower median
        V[u].r = med[med.size() / 2];       // upper median
        if (pu == -1)
            V[u].w = V[u].l;
    }
}
// Returns the sum of |V[u].w - V[c].w| over every edge (u, c) of the subtree
// rooted at `v`, where `parent` is the neighbour of `v` to exclude (-1 for none).
//
// Each difference is computed in 64-bit arithmetic. Subtracting two ints of
// opposite sign can overflow `int`, which is undefined behaviour. Traversal uses
// an explicit stack, so deep trees do not exhaust the call stack.
long long Count(int v, int parent = -1)
{
    long long total = 0;
    std::vector<int> pending{v};
    std::vector<int> pendingParent{parent};
    while (!pending.empty())
    {
        const int u = pending.back();
        const int pu = pendingParent.back();
        pending.pop_back();
        pendingParent.pop_back();
        for (int c : G[u])
            if (c != pu)
            {
                total += std::abs(static_cast<long long>(V[u].w) - V[c].w);
                pending.push_back(c);
                pendingParent.push_back(u);
            }
    }
    return total;
}
int main(int argc, char **argv)
{
std::ios_base::sync_with_stdio(0);
std::cin >> n >> m;
for (int i=1;i<n;++i)
{
int a,b;
std::cin >> a >> b;
--a;
--b;
G[a].push_back(b);
G[b].push_back(a);
}
for (int i=0;i<m;++i)
std::cin >> V[i].w;
Solve(n-1);
std::cout << Count(n-1) << std::endl;
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | #include <iostream> #include <vector> #include <algorithm> const int MAX = 500010; struct Node { int l, r; int w; int p; bool col; int getPoint(int l) { return std::max(this->l, l); } }; int n,m; long long result = 0; std::vector<int> G[MAX]; Node V[MAX]; void Solve(int v, int parent = -1) { int l = MAX; for (int i : G[v]) if (i != parent) { Solve(i, v); l = std::min(l, V[i].r); } if (v < m) // leaf { V[v].l = V[v].r = V[v].w; } else { std::vector<int> med; for (int i : G[v]) if (i != parent) { V[i].w = V[i].getPoint(l); med.push_back(V[i].w); } std::sort(med.begin(), med.end()); V[v].l = med[(med.size()-1)/2]; V[v].r = med[med.size()/2]; if (parent == -1) V[v].w = V[v].l; } } long long Count(int v, int parent = -1) { long long result = 0; for (int i : G[v]) if (i != parent) { result += Count(i, v); result += abs(V[v].w - V[i].w); } return result; } int main(int argc, char **argv) { std::ios_base::sync_with_stdio(0); std::cin >> n >> m; for (int i=1;i<n;++i) { int a,b; std::cin >> a >> b; --a; --b; G[a].push_back(b); G[b].push_back(a); } for (int i=0;i<m;++i) std::cin >> V[i].w; Solve(n-1); std::cout << Count(n-1) << std::endl; return 0; } |
English