#include <algorithm>
#include <cstdio>
#include <limits>
#include <queue>
#include <utility>
#include <vector>

// Reads a tree whose vertices 1..m carry fixed integer values r_1..r_m and
// prints the minimum, over integer values chosen for the remaining vertices,
// of the sum of |value(u) - value(v)| over all tree edges.
//
// Input (stdin):   n m, then n-1 edges "a b" (1-based), then r_1 .. r_m.
// Output (stdout): the minimum total cost, printed as a 64-bit integer.
// Exit status 1 on malformed input (failed read, out-of-range vertex, m > n).
//
// NOTE(review): the algorithm relies on vertices 1..m being exactly the
// leaves of the tree (every other vertex has degree >= 2); confirm against
// the task statement.

namespace {

// Closed interval [first, second] of root values at which a subtree's cost
// is minimal.
using Interval = std::pair<int, int>;

// Computes the minimum total edge cost described above.
//
// The tree is peeled from the leaves inward: a vertex is removed once at most
// one of its neighbours is still present; that remaining neighbour becomes its
// parent, and the already-removed neighbours are its children.
//
// For each subtree, the cost as a function of the subtree root's value x is
// convex and piecewise linear. It is minimal on an interval I and grows with
// slope >= 1 outside it (a leaf's cost is infinite away from its fixed value).
// So a child c contributes min_c + dist(x, I_c) to its parent, and the parent's
// optimal interval is the median interval of its children's 2k interval
// endpoints. The extra cost paid at a vertex is sum_c dist(x, I_c) evaluated
// at any point of its own interval; the lower end is used.
//
// graph: adjacency lists (0-based), graph.size() == n.
// r:     fixed values of vertices 0..m-1, r.size() == m.
long long min_total_difference(int n, int m,
                               const std::vector<std::vector<int>>& graph,
                               const std::vector<int>& r) {
  std::vector<int> deg(n);     // number of neighbours not yet removed
  std::vector<int> parent(n);  // parent[v] == v means "no parent assigned yet"
  for (int v = 0; v < n; ++v) {
    deg[v] = static_cast<int>(graph[v].size());
    parent[v] = v;
  }

  std::queue<int> ready;
  for (int leaf = 0; leaf < m; ++leaf) ready.push(leaf);

  std::vector<Interval> best(n);
  std::vector<int> points;  // reused buffer of children's interval endpoints
  long long cost = 0;

  while (!ready.empty()) {
    const int v = ready.front();
    ready.pop();

    for (const int u : graph[v]) {
      if (--deg[u] == 1) ready.push(u);
      // The neighbour that has not been removed yet is v's parent.
      if (parent[u] == u) parent[v] = u;
    }

    if (v < m) {
      best[v] = {r[v], r[v]};  // a leaf's value is fixed
    } else {
      points.clear();
      for (const int c : graph[v]) {
        if (parent[c] == v) {
          points.push_back(best[c].first);
          points.push_back(best[c].second);
        }
      }
      if (points.empty()) {
        // No children (only possible on malformed input): every value is
        // optimal. Guard instead of indexing an empty vector.
        best[v] = {std::numeric_limits<int>::min(),
                   std::numeric_limits<int>::max()};
      } else {
        std::sort(points.begin(), points.end());
        // points.size() is even (2k); these are the two middle elements.
        best[v] = {points[(points.size() - 1) / 2], points[points.size() / 2]};
      }
    }

    // Pay the distance from v's chosen value to each child's optimal interval.
    // This is done for leaves too: when the last vertex removed is a leaf
    // (the n == 2 tree), its single child edge must still be counted.
    // Arithmetic is in 64 bits so large value gaps cannot overflow.
    const long long x = best[v].first;
    for (const int c : graph[v]) {
      if (parent[c] != v) continue;
      if (best[c].second < x) {
        cost += x - best[c].second;
      } else if (best[c].first > x) {
        cost += best[c].first - x;
      }
    }
  }
  return cost;
}

}  // namespace

int main() {
  int n = 0;
  int m = 0;
  if (std::scanf("%d%d", &n, &m) != 2 || n < 0 || m < 0 || m > n) return 1;

  std::vector<std::vector<int>> graph(n);
  for (int i = 0; i + 1 < n; ++i) {
    int a = 0;
    int b = 0;
    if (std::scanf("%d%d", &a, &b) != 2) return 1;
    if (a < 1 || a > n || b < 1 || b > n) return 1;
    --a;
    --b;
    graph[a].push_back(b);
    graph[b].push_back(a);
  }

  std::vector<int> r(m);
  for (int& value : r) {
    if (std::scanf("%d", &value) != 1) return 1;
  }

  std::printf("%lld\n", min_total_difference(n, m, graph, r));
  return 0;
}