// Tree labelling.  Vertices 1..m are leaves whose integer values are read from
// the input; vertices m+1..n may take any integer value.  The program prints
// the minimum possible sum of |value(u) - value(v)| over all tree edges.
//
// Rooting the tree at vertex m+1, the cheapest way to attach the subtree of u
// to a parent holding value x costs  best(u) + dist(x, [intervalBegin[u],
// intervalEnd[u]]),  where that interval is the set of optimal values for u.
// An internal vertex therefore minimises the convex piecewise-linear function
//     f(x) = sum over children u of dist(x, [intervalBegin[u], intervalEnd[u]]),
// whose breakpoints are exactly the children's interval endpoints.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

namespace {

const int N = 500005;

int n, m;
int value[N];                          // fixed values of the leaves 1..m
int prz[N];                            // parent in the rooted tree (0 = none)
int intervalBegin[N], intervalEnd[N];  // optimal value interval per subtree
std::vector<int> V[N];                 // adjacency lists
long long ans;                         // accumulated minimum cost

// Leaves are exactly the vertices whose values are read from the input.
bool isLeaf(int w) { return w <= m; }

// Chooses the optimal value interval for internal vertex w (all of whose
// children must already be settled) and adds the cost of w's child edges to
// ans.  `events` is a scratch buffer reused across calls.
void settleInternal(int w, std::vector<std::pair<int, int> >& events) {
    events.clear();
    for (int u : V[w]) {
        if (u == prz[w]) continue;
        events.push_back({intervalBegin[u], 0});  // 0: a child interval starts
        events.push_back({intervalEnd[u], 1});    // 1: a child interval ends
    }
    if (events.empty()) {
        // No children: f is identically 0; keep the historical [1, 1] choice.
        intervalBegin[w] = intervalEnd[w] = 1;
        return;
    }
    // At equal positions starts sort before ends, so the first event is the
    // smallest interval start and every interval lies at or right of it.
    std::sort(events.begin(), events.end());

    const int start = events.front().first;
    long long sum = 0;  // f(start) = sum of (begin - start)
    for (const auto& e : events)
        if (e.second == 0) sum += static_cast<long long>(e.first) - start;

    // Between breakpoints the slope of f is
    // (#intervals entirely left of x) - (#intervals entirely right of x).
    int rightCount = static_cast<int>(events.size() / 2);
    int leftCount = 0;
    long long lastPosition = start;
    long long bestSum = sum;
    int bestBegin = start;
    int bestEnd = start;
    for (const auto& e : events) {
        const int position = e.first;
        sum += (position - lastPosition) * (leftCount - rightCount);
        lastPosition = position;
        if (sum < bestSum) {
            // New strict minimum: the optimal interval restarts here.
            bestSum = sum;
            bestBegin = bestEnd = position;
        } else if (sum == bestSum) {
            bestEnd = position;  // still on the minimum plateau
        }
        if (e.second == 0)
            --rightCount;
        else
            ++leftCount;
    }
    intervalBegin[w] = bestBegin;
    intervalEnd[w] = bestEnd;
    ans += bestSum;
}

// Settles every vertex in the subtree of `root` bottom-up and accumulates the
// total cost into ans.  Uses an explicit stack rather than recursion: a
// path-like tree with hundreds of thousands of vertices would otherwise risk
// overflowing the call stack.
void DFS(int root) {
    std::vector<int> order;  // preorder: every parent precedes its children
    order.reserve(n);
    std::vector<int> pending(1, root);
    prz[root] = 0;
    while (!pending.empty()) {
        const int w = pending.back();
        pending.pop_back();
        order.push_back(w);
        if (isLeaf(w)) {
            intervalBegin[w] = intervalEnd[w] = value[w];
            continue;  // a leaf's value is fixed; do not descend past it
        }
        for (int u : V[w]) {
            if (u != prz[w]) {
                prz[u] = w;
                pending.push_back(u);
            }
        }
    }
    // Reverse preorder visits children before their parents.
    std::vector<std::pair<int, int> > events;
    for (auto it = order.rbegin(); it != order.rend(); ++it)
        if (!isLeaf(*it)) settleInternal(*it, events);
}

}  // namespace

int main() {
    if (std::scanf("%d %d", &n, &m) != 2) return 1;
    // The per-vertex tables have fixed size N; reject inputs that do not fit.
    if (n < 1 || n >= N || m < 0 || m > n) return 1;
    for (int i = 1; i < n; ++i) {
        int a, b;
        if (std::scanf("%d %d", &a, &b) != 2) return 1;
        if (a < 1 || a > n || b < 1 || b > n) return 1;
        V[a].push_back(b);
        V[b].push_back(a);
    }
    for (int i = 1; i <= m; ++i)
        if (std::scanf("%d", &value[i]) != 1) return 1;

    if (m == n) {
        // Every vertex is fixed, so nothing is optimised: sum the edge costs
        // directly (covers the two-vertex tree as well as a single vertex).
        long long total = 0;
        for (int a = 1; a <= n; ++a)
            for (int b : V[a])
                if (a < b)
                    total += std::llabs(static_cast<long long>(value[a]) - value[b]);
        std::printf("%lld\n", total);
        return 0;
    }
    DFS(m + 1);
    std::printf("%lld\n", ans);
    return 0;
}