#include <algorithm>
#include <climits>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// Vertex count of the tree, and the number of leading vertices (0..m-1)
// whose values are given in the input.
int n, m;
// Accumulated answer, printed by main().
long long ret = 0;
// visited[v] becomes true once Dfs expands vertex v.
vector<bool> visited;
// Adjacency offsets: the neighbours of v are edges[boundary[v]] up to
// edges[boundary[v + 1] - 1]; main() fills it with n + 1 entries.
vector<int> boundary;
// Given values of vertices 0..m-1.
vector<int> r;
// Each tree edge stored in both directions as (from, to), sorted by `from`.
vector<pair<int, int> > edges;
// Solves the subtree hanging below vertex `a` (neighbours already marked in
// `visited` are treated as the parent side and skipped).
//
// Returns the closed interval [first, second] of values vertex `a` can take
// while keeping the subtree's cost minimal. As a side effect, adds to the
// global `ret` the part of that minimal cost which is forced at `a`.
//
// Vertices 0..m-1 have the given value r[a]. They are returned as the
// degenerate interval [r[a], r[a]] and never expanded.
// NOTE(review): this assumes every vertex 0..m-1 is a leaf of the tree; if
// one had further neighbours, the edges beyond it would never be counted.
//
// Merge step: each child contributes the distance from x to its interval
// [lo_j, hi_j] to the cost of choosing x at `a`. Lower ends are sorted
// descending and upper ends ascending. Every position where lo > hi adds an
// unavoidable lo - hi. The optimal interval is read off around the first
// position where the ends no longer cross.
//
// NOTE(review): recursion depth equals the height of the tree, so a very
// deep (path-like) tree may exhaust the default call stack.
pair<int, int> Dfs(const int a) {
  if (a < m) return make_pair(r[a], r[a]);

  visited[a] = true;
  vector<int> mis;  // lower ends of the children's optimal intervals
  vector<int> mas;  // upper ends of the children's optimal intervals
  for (int e = boundary[a]; e < boundary[a + 1]; ++e) {
    const int next = edges[e].second;
    if (!visited[next]) {
      const pair<int, int> child = Dfs(next);
      mis.push_back(child.first);
      mas.push_back(child.second);
    }
  }
  // Nothing below constrains this vertex, so any value is optimal. Without
  // this guard, mis[0]/mas[0] below would index empty vectors.
  if (mis.empty()) return make_pair(INT_MIN, INT_MAX);

  sort(mis.rbegin(), mis.rend());  // descending
  sort(mas.begin(), mas.end());    // ascending
  const size_t k = mis.size();
  size_t i = 0;
  while (i < k && mis[i] > mas[i]) {
    // Widen before subtracting: the int difference itself could overflow.
    ret += static_cast<long long>(mis[i]) - mas[i];
    ++i;
  }
  if (i == 0) return make_pair(mis[0], mas[0]);  // all intervals intersect
  if (2 * i < k)
    return make_pair(max(mas[i - 1], mis[i]), min(mis[i - 1], mas[i]));
  return make_pair(mas[i - 1], mis[i - 1]);
}
// Runs the tree DP from vertex n - 1 as the root, after clearing every
// visited mark. The answer is accumulated by Dfs into the global `ret`.
void Go() {
  visited.assign(n, false);
  Dfs(n - 1);
}
// Reads the tree and the given values, then prints the minimum total cost.
//
// Input, in the order read below:
//   - n and m;
//   - n - 1 edges as 1-based vertex pairs;
//   - m values for vertices 1..m.
// Returns 1 without printing if the input cannot be read or n / m are
// out of range.
int main() {
  if (scanf("%d%d", &n, &m) != 2 || n < 1 || m < 0) return 1;
  vector<int> u(n - 1), v(n - 1);
  for (int i = 0; i < n - 1; ++i) {
    if (scanf("%d%d", &u[i], &v[i]) != 2) return 1;
    --u[i];  // convert to 0-based vertex ids
    --v[i];
  }
  r.resize(m);
  for (int i = 0; i < m; ++i)
    if (scanf("%d", &r[i]) != 1) return 1;

  if (n == m) {
    // Every vertex has a given value: the cost is simply the sum over edges.
    for (int i = 0; i < n - 1; ++i) {
      // Widen before subtracting so the difference cannot overflow int.
      const long long d = static_cast<long long>(r[u[i]]) - r[v[i]];
      ret += d < 0 ? -d : d;
    }
  } else {
    // Build the adjacency: store each edge in both directions, sort by
    // source vertex, and record where each source vertex's run starts.
    for (int i = 0; i < n - 1; ++i) {
      edges.push_back(make_pair(u[i], v[i]));
      edges.push_back(make_pair(v[i], u[i]));
    }
    sort(edges.begin(), edges.end());
    size_t i = 0;
    for (; i < edges.size(); ++i)
      while (static_cast<size_t>(edges[i].first) >= boundary.size())
        boundary.push_back(static_cast<int>(i));
    // Close the runs of trailing vertices so boundary has n + 1 entries.
    while (boundary.size() <= static_cast<size_t>(n))
      boundary.push_back(static_cast<int>(i));
    Go();
  }
  printf("%lld\n", ret);
  return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | #include <algorithm> #include <cstdio> #include <iostream> #include <vector> using namespace std; int n; int m; long long ret = 0; vector<bool> visited; vector<int> boundary; vector<int> r; vector<pair<int, int> > edges; pair<int, int> Dfs(const int a) { pair<int, int> p; if (a < m) { p = make_pair(r[a], r[a]); } else { visited[a] = true; vector<int> mis; vector<int> mas; for (int i = boundary[a]; i < boundary[a + 1]; ++i) { if (!visited[edges[i].second]) { p = Dfs(edges[i].second); mis.push_back(p.first); mas.push_back(p.second); } } sort(mis.rbegin(), mis.rend()); sort(mas.begin(), mas.end()); int i = 0; while (mis[i] > mas[i]) { ret += mis[i] - mas[i]; ++i; } if (i == 0) p = make_pair(mis[i], mas[i]); else if (2 * i < mis.size()) p = make_pair(max(mas[i - 1], mis[i]), min(mis[i - 1], mas[i])); else p = make_pair(mas[i - 1], mis[i - 1]); } // cerr << a << ' ' << p.first << ' ' << p.second << endl; return p; } void Go() { visited.resize(0); visited.resize(n, false); Dfs(n - 1); } int main() { scanf("%d%d", &n, &m); vector<int> u(n - 1), v(n - 1); for (int i = 0; i < n - 1; ++i) { scanf("%d%d", &u[i], &v[i]); --u[i]; --v[i]; } r.resize(m); for (int i = 0; i < m; ++i) scanf("%d", &r[i]); if (n == m) { for (int i = 0; i < n - 1; ++i) ret += abs(r[u[i]] - r[v[i]]); } else { for (int i = 0; i < n - 1; ++i) { edges.push_back(make_pair(u[i], v[i])); edges.push_back(make_pair(v[i], u[i])); } sort(edges.begin(), edges.end()); int i = 0; while (i < edges.size()) { while (edges[i].first >= boundary.size()) boundary.push_back(i); ++i; } while (boundary.size() <= n) boundary.push_back(i); Go(); } printf("%lld\n", ret); return 0; } |
English