#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <utility>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;

// Reads a tree on nodes 1..n (n - 1 edges) and fixed labels r[1..m]. Every
// other node is free. Each free node is greedily given the interval of labels
// that minimises the summed distance to its children's intervals. The sum of
// those minima is printed.
// NOTE(review): a node is treated as fixed iff r[v] != 0, so labels are
// presumably non-zero and nodes 1..m presumably the leaves -- confirm against
// the problem statement.

const int N = 500001;
vector<int> g[N];  // adjacency lists, 1-based
int r[N];          // fixed label of node v, or 0 for a free node
ll ans;            // accumulated cost over all free nodes processed by dfs()

static int parent_of[N];            // traversal parent, filled by dfs()
static pair<int, int> range_of[N];  // optimal label interval of each visited node

// `ends` encodes one interval [lo_i, hi_i] per child as (lo_i, false) and
// (hi_i, true). Sweeps x over the sorted endpoints of
// f(x) = sum_i dist(x, [lo_i, hi_i]). That function is convex and piecewise
// linear, with breakpoints exactly at those endpoints. Adds min f to ::ans and
// returns [first, last], the range of endpoints where the minimum is attained.
// `ends` is sorted in place. An empty `ends` yields {0, 0} at zero cost.
static pair<int, int> best_range(vector<pair<int, bool>>& ends) {
    if (ends.empty()) return make_pair(0, 0);
    sort(ends.begin(), ends.end());

    int point = ends.front().first;  // smallest endpoint == smallest lo
    ll current = 0;                  // f(point): every interval lies at or right of it
    for (const auto& e : ends)
        if (!e.second) current += (ll)e.first - point;

    ll best = current;
    int first = point, last = point;
    int closed = 0;                          // intervals whose hi lies at or left of point
    int pending = (int)(ends.size() / 2);    // intervals whose lo has not been passed yet
    for (const auto& e : ends) {
        // Between the previous point and e.first the slope of f is closed - pending.
        current += (ll)(closed - pending) * ((ll)e.first - point);
        point = e.first;
        if (current < best) { best = current; first = point; }
        if (current == best) last = point;
        if (e.second) closed++; else pending--;
    }
    ans += best;
    return make_pair(first, last);
}

// Returns the optimal label interval of `v` within the part of the tree that
// hangs away from `prev`, and adds the cost of every free node there to ::ans.
// The traversal never descends past a fixed node. It is iterative (explicit
// stack) so a path-shaped tree of ~5e5 nodes cannot overflow the call stack.
pair<int, int> dfs(int v, int prev) {
    vector<int> order;       // pre-order: every node appears after its parent
    vector<int> todo(1, v);
    parent_of[v] = prev;
    while (!todo.empty()) {
        const int x = todo.back();
        todo.pop_back();
        order.push_back(x);
        if (r[x]) continue;  // fixed node: its label is known, do not expand
        for (int u : g[x])
            if (u != parent_of[x]) {
                parent_of[u] = x;
                todo.push_back(u);
            }
    }

    // Reverse pre-order guarantees all children are finished before their parent.
    vector<pair<int, bool>> ends;  // reused buffer; clear() keeps its capacity
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        if (r[x]) {
            range_of[x] = make_pair(r[x], r[x]);
            continue;
        }
        ends.clear();
        for (int u : g[x])
            if (u != parent_of[x]) {
                ends.emplace_back(range_of[u].first, false);
                ends.emplace_back(range_of[u].second, true);
            }
        range_of[x] = best_range(ends);
    }
    return range_of[v];
}

int main() {
    int n, m;
    if (scanf("%d %d", &n, &m) != 2) return 0;
    for (int i = 1; i < n; i++) {
        int a, b;
        if (scanf("%d %d", &a, &b) != 2) return 0;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    for (int i = 1; i <= m; i++) scanf("%d", r + i);

    if (n == m) {
        // Every node carries a fixed label, so the cost is simply the sum over
        // edges of |r[u] - r[v]|. This equals |r[1] - r[2]| for n == 2 and is
        // 0 for n == 1 (no edges).
        ll total = 0;
        for (int v = 1; v <= n; v++)
            for (int u : g[v])
                if (u > v) total += llabs((ll)r[u] - r[v]);
        printf("%lld\n", total);
    } else {
        dfs(m + 1, 0);  // node m + 1 is the first free node (m < n here)
        printf("%lld\n", ans);
    }
}