// Reads a tree from stdin and prints a single integer cost to stdout.
//
// Input format (as consumed by main below):
//   n m
//   n - 1 lines "u v"   -- 1-based endpoints of an edge
//   m integers          -- values r[0..m-1] attached to vertices 0..m-1
//
// Vertices with index < m carry a fixed value; the remaining vertices are
// "free".  When every vertex is fixed (n == m) the answer is the sum of
// |r[u] - r[v]| over all edges.  Otherwise a depth-first search rooted at
// vertex n - 1 accumulates the cost in `ret`.
// NOTE(review): this looks like "choose values for the free vertices so that
// the sum of |x_u - x_v| over all edges is minimal" -- inferred from the code,
// confirm against the original problem statement.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

int n;              // number of vertices
int m;              // vertices 0..m-1 have fixed values stored in r
long long ret = 0;  // accumulated answer

std::vector<bool> visited;                // free vertices already entered by Dfs
std::vector<int> boundary;                // edges[boundary[a] .. boundary[a + 1]) start at a
std::vector<int> r;                       // fixed values, size m
std::vector<std::pair<int, int> > edges;  // both directions of every edge, sorted

// Processes the subtree hanging below vertex `a` and returns an interval
// (low, high) with low <= high for that vertex.
//
// A fixed vertex returns the degenerate interval (r[a], r[a]) and is not
// descended into; presumably the input guarantees fixed vertices are leaves --
// TODO confirm.  A free vertex collects the intervals of its unvisited
// neighbours, pairs the largest lower bounds with the smallest upper bounds,
// adds the gap of every non-overlapping pair to `ret`, and derives its own
// interval from where that pairing stops.
//
// NOTE: the recursion depth equals the depth of the tree below the root.
std::pair<int, int> Dfs(const int a) {
  if (a < m) {
    return std::make_pair(r[a], r[a]);
  }
  visited[a] = true;

  std::vector<int> lows;   // lower bounds of the child intervals
  std::vector<int> highs;  // upper bounds of the child intervals
  for (int e = boundary[a]; e < boundary[a + 1]; ++e) {
    const int next = edges[e].second;
    if (!visited[next]) {
      const std::pair<int, int> child = Dfs(next);
      lows.push_back(child.first);
      highs.push_back(child.second);
    }
  }

  // A free vertex without children is unconstrained.  The original code read
  // element 0 of two empty vectors here; returning the widest interval makes
  // the vertex contribute no cost in its parent, because its lower bound is
  // never greater than any upper bound and vice versa.
  if (lows.empty()) {
    return std::make_pair(std::numeric_limits<int>::min(),
                          std::numeric_limits<int>::max());
  }

  std::sort(lows.rbegin(), lows.rend());   // descending
  std::sort(highs.begin(), highs.end());   // ascending

  const std::size_t count = lows.size();
  std::size_t paired = 0;
  // Each disjoint pair costs exactly its gap.  The subtraction is done in
  // long long so widely separated values cannot overflow int.
  while (paired < count && lows[paired] > highs[paired]) {
    ret += static_cast<long long>(lows[paired]) - highs[paired];
    ++paired;
  }

  if (paired == 0) {
    // All child intervals overlap: keep their common intersection.
    return std::make_pair(lows[0], highs[0]);
  }
  if (2 * paired < count) {
    // Some children were left unpaired: clamp by the last paired bounds and
    // the first unpaired ones.
    return std::make_pair(std::max(highs[paired - 1], lows[paired]),
                          std::min(lows[paired - 1], highs[paired]));
  }
  // Every child took part in a pair: the innermost pair spans the interval.
  return std::make_pair(highs[paired - 1], lows[paired - 1]);
}

// Resets the visited marks and runs the search from the last vertex, which is
// free whenever this is called (main only calls it when n != m).
void Go() {
  visited.assign(static_cast<std::size_t>(n), false);
  Dfs(n - 1);
}

// Parses the input, validates it, computes the answer and prints it.
// Returns 1 (after a message on stderr) when the input is malformed.
int main() {
  if (std::scanf("%d%d", &n, &m) != 2 || n < 1 || m < 0) {
    std::fprintf(stderr, "invalid input: expected n >= 1 and m >= 0\n");
    return 1;
  }

  const std::size_t edge_count = static_cast<std::size_t>(n - 1);
  std::vector<int> u(edge_count);
  std::vector<int> v(edge_count);
  for (std::size_t k = 0; k < edge_count; ++k) {
    if (std::scanf("%d%d", &u[k], &v[k]) != 2 || u[k] < 1 || u[k] > n ||
        v[k] < 1 || v[k] > n) {
      std::fprintf(stderr, "invalid input: bad edge\n");
      return 1;
    }
    // Switch to 0-based vertex indices.
    --u[k];
    --v[k];
  }

  r.resize(static_cast<std::size_t>(m));
  for (int k = 0; k < m; ++k) {
    if (std::scanf("%d", &r[k]) != 1) {
      std::fprintf(stderr, "invalid input: missing vertex value\n");
      return 1;
    }
  }

  if (n == m) {
    // Every vertex is fixed, so each edge simply costs the distance between
    // its endpoint values (computed in long long to avoid int overflow).
    for (std::size_t k = 0; k < edge_count; ++k) {
      const long long diff = static_cast<long long>(r[u[k]]) - r[v[k]];
      ret += diff < 0 ? -diff : diff;
    }
  } else {
    // Store every edge in both directions and sort, so that the edges leaving
    // a vertex form one contiguous run.
    edges.reserve(2 * edge_count);
    for (std::size_t k = 0; k < edge_count; ++k) {
      edges.push_back(std::make_pair(u[k], v[k]));
      edges.push_back(std::make_pair(v[k], u[k]));
    }
    std::sort(edges.begin(), edges.end());

    // boundary[a] is the index of the first edge whose source is >= a; the
    // trailing entries make boundary[a + 1] valid for every vertex a.
    std::size_t e = 0;
    for (; e < edges.size(); ++e) {
      while (static_cast<std::size_t>(edges[e].first) >= boundary.size()) {
        boundary.push_back(static_cast<int>(e));
      }
    }
    while (boundary.size() <= static_cast<std::size_t>(n)) {
      boundary.push_back(static_cast<int>(e));
    }

    Go();
  }

  std::printf("%lld\n", ret);
  return 0;
}