// Reads a tree with n nodes, of which nodes 1..m (1-based in the input) carry
// fixed integer values, and prints one integer.
//
// Input format (as consumed by main):
//   n m
//   n-1 lines "a b"   -- undirected edges, 1-based endpoints
//   m integers        -- the fixed values of nodes 1..m
//
// For every node with index >= m the program computes the smallest possible
// sum of distances from a single point to the value ranges of its children,
// and prints the total of those minima.
// NOTE(review): this is presumably the minimum total of |value differences|
// over all edges when the non-fixed nodes may be assigned freely -- confirm
// against the problem statement. The per-node sweep is only meaningful when
// the fixed nodes 0..m-1 are leaves, because a fixed node returns early and
// never looks at its own children.
#include <iomanip>
#include <iostream>
#include <utility>
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>
#include <set>
#include <map>

using namespace std;

typedef vector<int> VI;
typedef long long LL;
typedef pair<int, int> PII;

vector<VI> v;           // adjacency lists, 0-based node indices
VI r;                   // fixed values of nodes 0..m-1
int n, m;               // node count, number of fixed-value nodes
LL res = 0;             // running total of the per-node minimal costs
vector<PII> res_range;  // per node: [first, second] range of optimal values
VI vis;                 // traversal marks, 1 once a node has been pushed

// Stack frame for the iterative post-order traversal in main.
// first_vis is 1 when the node is being expanded and 0 when it is revisited
// after all of its children have been processed.
struct state {
    int node, parent, first_vis;
    state(int node, int parent, int first_vis)
        : node(node), parent(parent), first_vis(first_vis) {}
};

// Absolute difference of two ints, computed in 64 bits so that the
// subtraction cannot overflow.
static LL abs_diff(int a, int b) {
    const LL d = LL(a) - LL(b);
    return d < 0 ? -d : d;
}

// Processes one node once all of its children have been processed.
//
// node   -- node to process
// parent -- its parent in the rooted tree (-1 for the root); every neighbour
//           equal to parent is skipped
//
// A fixed node (node < m) gets the degenerate range [r[node], r[node]].
// Any other node minimises  sum over children of dist(x, child range)  over
// the range endpoints of its children, adds that minimum to res and stores
// the first and last endpoint that attain it in res_range[node].
// A node without children contributes 0 and gets the range [0, 0].
void dfs(int node, int parent) {
    if (node < m) {
        res_range[node] = make_pair(r[node], r[node]);
        return;
    }

    // One event per range endpoint: (position, 1) for a range start and
    // (position, 0) for a range end. Sorting pairs puts ends before starts
    // at equal positions; either order gives the same cost at that position.
    vector<PII> events;
    events.reserve(2 * v[node].size());
    LL beg_sum = 0;     // sum of range starts not yet swept past
    int begs_left = 0;  // number of range starts not yet swept past
    for (size_t i = 0; i < v[node].size(); ++i) {
        const int child = v[node][i];
        if (child == parent) continue;
        const PII& range = res_range[child];
        events.push_back(make_pair(range.first, 1));
        events.push_back(make_pair(range.second, 0));
        beg_sum += range.first;
        ++begs_left;
    }
    sort(events.begin(), events.end());

    LL end_sum = 0;    // sum of range ends already swept past
    int ends_cnt = 0;  // number of range ends already swept past
    // The best candidate is seeded from the first event instead of from the
    // cost at x = 0, so that negative coordinates are handled correctly.
    bool have_best = false;
    LL best_val = 0;
    int best_beg = 0, best_end = 0;
    for (size_t i = 0; i < events.size(); ++i) {
        const int pos = events[i].first;
        const int type = events[i].second;
        // Ranges starting at or after pos cost (start - pos); ranges that
        // ended at or before pos cost (pos - end); the rest contain pos.
        const LL cur = (beg_sum - LL(begs_left) * pos)
                     + (LL(ends_cnt) * pos - end_sum);
        if (!have_best || cur < best_val) {
            have_best = true;
            best_val = cur;
            best_beg = best_end = pos;
        } else if (cur == best_val) {
            best_end = pos;
        }
        if (type == 1) {
            beg_sum -= pos;
            --begs_left;
        } else {
            end_sum += pos;
            ++ends_cnt;
        }
    }

    res += best_val;
    res_range[node] = make_pair(best_beg, best_end);
}

// Reports malformed input on stderr and yields the process exit code.
static int bad_input(const char* what) {
    cerr << "invalid input: " << what << '\n';
    return 1;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(0);

    if (!(cin >> n >> m) || n < 1 || m < 0 || m > n)
        return bad_input("expected n >= 1 and 0 <= m <= n");

    v.resize(n);
    for (int i = 0; i < n - 1; ++i) {
        int a, b;
        if (!(cin >> a >> b) || a < 1 || a > n || b < 1 || b > n)
            return bad_input("edge endpoint missing or out of range");
        --a;
        --b;
        v[a].push_back(b);
        v[b].push_back(a);
    }

    r.resize(m);
    for (int i = 0; i < m; ++i)
        if (!(cin >> r[i]))
            return bad_input("missing node value");

    // Sanity check kept from the original: a two-node tree must have both
    // nodes fixed.
    assert(n != 2 || m == 2);

    if (m == n) {
        // Every node is fixed, so there is no node to root the traversal at;
        // the answer is just the cost of the given edges. For n == 2 this is
        // |r[0] - r[1]|, for n == 1 it is 0.
        for (int a = 0; a < n; ++a)
            for (size_t i = 0; i < v[a].size(); ++i)
                if (a < v[a][i])  // count each undirected edge once
                    res += abs_diff(r[a], r[v[a][i]]);
        cout << res << '\n';
        return 0;
    }

    // Iterative post-order traversal rooted at node m, the first non-fixed
    // node. Each node is pushed twice: once to expand its children and once
    // to be processed after them.
    res_range.resize(n);
    vis.assign(n, 0);
    vector<state> st;
    st.push_back(state(m, -1, 1));
    vis[m] = 1;
    while (!st.empty()) {
        const state cur = st.back();
        st.pop_back();
        if (cur.first_vis) {
            st.push_back(state(cur.node, cur.parent, 0));
            for (size_t i = 0; i < v[cur.node].size(); ++i) {
                const int next = v[cur.node][i];
                if (next != cur.parent && !vis[next]) {
                    vis[next] = 1;
                    st.push_back(state(next, cur.node, 1));
                }
            }
        } else {
            dfs(cur.node, cur.parent);
        }
    }

    cout << res << '\n';
    return 0;
}