#include <iomanip>
#include <iostream>
#include <utility>
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>
#include <set>
#include <map>
using namespace std;
// ---- Shorthand macros (competitive-programming style) ----
// ALL(x): begin/end iterator pair of container x, e.g. sort(ALL(events)).
#define ALL(x) x.begin(), x.end()
// VAR(a,b): declare `a` with the type of expression `b`, initialised to b.
// NOTE(review): __typeof is a GNU extension, not standard C++.
#define VAR(a,b) __typeof (b) a = b
// IN(a) / IN2(a,b): declare int variable(s) and read them from cin.
#define IN(a) int a; cin >> a
#define IN2(a,b) int a, b; cin >> a >> b
// REP(i,n): i = 0 .. n-1; the bound is evaluated once into _n.
#define REP(i,n) for (int _n=(n), i=0; i<_n; ++i)
// FOR(i,a,b): i = a .. b inclusive; FORD(i,a,b): i = a down to b inclusive.
#define FOR(i,a,b) for (int _b=(b), i=(a); i<=_b; ++i)
#define FORD(i,a,b) for (int _b=(b), i=(a); i>=_b; --i)
// FORE(i,a): iterator `i` walks every element of container a.
#define FORE(i,a) for (VAR(i,a.begin ()); i!=a.end (); ++i)
// Short aliases for push_back / make_pair / pair members.
#define PB push_back
#define MP make_pair
#define ST first
#define ND second
typedef vector<int> VI;
typedef long long LL;
typedef pair<int,int> PII;
// NOTE(review): LD is plain double here (not long double); not referenced in the code shown.
typedef double LD;
// DBG and INF are not referenced in the code shown.
const int DBG = 0, INF = int(1e9);
// ---- Global state shared by dfs() and main() ----
// v: adjacency lists of the tree; vertices are 0-based (main decrements input ids).
vector<VI> v;
// r[i]: fixed value of vertex i, for i < m (read in main).
VI r;
// n: number of vertices; m: number of fixed vertices (indices 0..m-1).
int n, m;
// res: sum of the per-vertex minimum costs added by dfs(); printed by main.
LL res = 0;
// res_range[u]: [first, second] range of optimal breakpoints for vertex u, filled by dfs().
vector<PII> res_range;
void dfs(int node, int parent) {
if (node < m) { res_range[node] = MP(r[node], r[node]); return; }
vector<PII> children;
children.reserve(v[node].size());
FORE(it, v[node])
if (*it != parent)
children.PB(res_range[*it]);
LL beg_sum = 0, end_sum = 0;
int begs_left = children.size(), ends_cnt = 0;
vector<PII> events;
FORE(it, children) {
events.PB(MP(it->first, 1));
beg_sum += it->first;
events.PB(MP(it->second, 0));
}
sort(ALL(events));
LL best_val = beg_sum;
int best_beg = 0, best_end = 0;
FORE(it, events) {
int pos = it->first, type = it->second;
LL cur = beg_sum - LL(begs_left) * pos + LL(ends_cnt) * pos - end_sum;
if (cur < best_val) {
best_val = cur;
best_beg = best_end = pos;
}
else if (cur == best_val) best_end = pos;
if (type == 1) {
beg_sum -= pos;
begs_left--;
}
else {
end_sum += pos;
ends_cnt++;
}
}
res += best_val;
res_range[node] = MP(best_beg, best_end);
}
// vis[u] becomes 1 once vertex u has been pushed on main()'s explicit DFS stack.
VI vis;
// One frame of main()'s explicit (non-recursive) DFS stack.
// first_vis == 1: first visit -> push the vertex's unvisited neighbours.
// first_vis == 0: post-order visit -> run the DP step for `node`.
struct state {
    int node;       // vertex index
    int parent;     // parent vertex index, -1 for the root
    int first_vis;  // 1 on the first visit, 0 on the post-order visit
    state(int v_node, int v_parent, int v_first)
        : node(v_node), parent(v_parent), first_vis(v_first) {}
};
int main() {
ios_base::sync_with_stdio(0);
cout.setf(ios::fixed);
cin >> n >> m;
v.resize(n);
REP(i, n - 1) {
IN2(a, b);
--a;
--b;
v[a].PB(b);
v[b].PB(a);
}
r.resize(m);
REP(i,m) cin >> r[i];
if (n == 2) {
assert(m == 2);
cout << abs(r[0] - r[1]) << endl;
return 0;
}
vis = VI(n, 0);
vector<state> st;
st.PB(state(m, -1, 1));
vis[m] = 1;
res_range.resize(n);
while (!st.empty()) {
int nxt = st.back().node, parent = st.back().parent, first_vis = st.back().first_vis;
st.pop_back();
if (first_vis) {
st.PB(state(nxt, parent, 0));
FORE(it, v[nxt])
if (*it != parent && !vis[*it]) {
vis[*it] = 1;
st.PB(state(*it, nxt, 1));
}
}
else {
dfs(nxt, parent);
}
}
cout << res << endl;
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | #include <iomanip> #include <iostream> #include <utility> #include <algorithm> #include <cassert> #include <string> #include <vector> #include <set> #include <map> using namespace std; #define ALL(x) x.begin(), x.end() #define VAR(a,b) __typeof (b) a = b #define IN(a) int a; cin >> a #define IN2(a,b) int a, b; cin >> a >> b #define REP(i,n) for (int _n=(n), i=0; i<_n; ++i) #define FOR(i,a,b) for (int _b=(b), i=(a); i<=_b; ++i) #define FORD(i,a,b) for (int _b=(b), i=(a); i>=_b; --i) #define FORE(i,a) for (VAR(i,a.begin ()); i!=a.end (); ++i) #define PB push_back #define MP make_pair #define ST first #define ND second typedef vector<int> VI; typedef long long LL; typedef pair<int,int> PII; typedef double LD; const int DBG = 0, INF = int(1e9); vector<VI> v; VI r; int n, m; LL res = 0; vector<PII> res_range; void dfs(int node, int parent) { if (node < m) { res_range[node] = MP(r[node], r[node]); return; } vector<PII> children; children.reserve(v[node].size()); FORE(it, v[node]) if (*it != parent) children.PB(res_range[*it]); LL beg_sum = 0, end_sum = 0; int begs_left = children.size(), ends_cnt = 0; vector<PII> events; FORE(it, children) { events.PB(MP(it->first, 1)); beg_sum += it->first; events.PB(MP(it->second, 0)); } sort(ALL(events)); LL best_val = beg_sum; int best_beg = 0, best_end = 0; FORE(it, events) { int pos = it->first, type = it->second; LL cur = beg_sum - LL(begs_left) * pos + LL(ends_cnt) * pos - end_sum; if (cur < best_val) { best_val = cur; best_beg = best_end = pos; } else if (cur == best_val) best_end = pos; if (type == 1) { beg_sum -= pos; begs_left--; } 
else { end_sum += pos; ends_cnt++; } } res += best_val; res_range[node] = MP(best_beg, best_end); } VI vis; struct state { int node, parent, first_vis; state(int node, int parent, int first_vis) : node(node), parent(parent), first_vis(first_vis) {} }; int main() { ios_base::sync_with_stdio(0); cout.setf(ios::fixed); cin >> n >> m; v.resize(n); REP(i, n - 1) { IN2(a, b); --a; --b; v[a].PB(b); v[b].PB(a); } r.resize(m); REP(i,m) cin >> r[i]; if (n == 2) { assert(m == 2); cout << abs(r[0] - r[1]) << endl; return 0; } vis = VI(n, 0); vector<state> st; st.PB(state(m, -1, 1)); vis[m] = 1; res_range.resize(n); while (!st.empty()) { int nxt = st.back().node, parent = st.back().parent, first_vis = st.back().first_vis; st.pop_back(); if (first_vis) { st.PB(state(nxt, parent, 0)); FORE(it, v[nxt]) if (*it != parent && !vis[*it]) { vis[*it] = 1; st.PB(state(*it, nxt, 1)); } } else { dfs(nxt, parent); } } cout << res << endl; return 0; } |
English