#include <cstdio>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
// Large sentinel value; not referenced by the code visible in this file.
const ll inf = 1e12;
// Running total of the per-node minimum costs accumulated by dfs().
ll ans;
// Capacity of the per-node arrays; valid indices are 0 .. N-1.
const int N = 500001;
// Adjacency lists of the undirected graph, indexed by node id.
vector<int> g[N];
// Fixed value of a node; 0 means the node has no fixed value and dfs()
// must choose one for it.
int r[N];

// One entry of the explicit traversal stack used by dfs().
struct DfsFrame
{
    int node;                         // node being processed
    int parent;                       // neighbour we arrived from (skipped)
    size_t next;                      // next index into g[node] to look at
    vector<pair<int, bool>> events;   // interval endpoints of finished children
    ll base;                          // sum of the children's lower endpoints

    DfsFrame(int node_, int parent_)
        : node(node_), parent(parent_), next(0), base(0) {}
};

// Given the endpoints of the children's optimal intervals (`false` = lower
// end, `true` = upper end) and the sum of the lower ends, sweeps the sorted
// breakpoints to minimise the sum of distances from a point to every interval.
// Adds the minimum to the global `ans` and returns the range of breakpoints
// that attain it. With no events it adds `cost` unchanged and returns (0, 0).
static pair<int, int> settle_node(vector<pair<int, bool>>& events, ll cost)
{
    sort(events.begin(), events.end());

    ll best = cost;
    int first = 0, last = 0;
    int point = 0;
    int passed = 0;                                    // intervals whose upper end is behind the sweep
    int ahead = static_cast<int>(events.size() / 2);   // intervals whose lower end is still in front
    bool seeded = false;

    for (const auto& e : events)
    {
        // Between breakpoints the cost is linear with slope passed - ahead.
        // Widen before subtracting so far-apart coordinates cannot overflow int.
        cost += static_cast<ll>(passed - ahead) * (static_cast<ll>(e.first) - point);
        point = e.first;

        // Seed the minimum from the first breakpoint instead of from the cost
        // at coordinate 0, so the result does not depend on the values being
        // non-negative.
        if (!seeded || cost < best)
        {
            best = cost;
            first = point;
            seeded = true;
        }
        if (cost == best) last = point;

        if (e.second) ++passed;
        else --ahead;
    }

    ans += best;
    return make_pair(first, last);
}

// Computes, for the subtree reached from `v` without going back through
// `prev`, the range [first, second] of values for `v` that minimises the sum
// of distances to the children's ranges; every visited node's minimum is added
// to the global `ans`.
//
// A node with r[v] != 0 is returned as the single-point range (r[v], r[v])
// without visiting its neighbours and without touching `ans`.
//
// The traversal keeps its own stack on the heap rather than recursing, so the
// depth of the tree is not limited by the size of the call stack.
pair<int, int> dfs(int v, int prev)
{
    if (r[v]) return make_pair(r[v], r[v]);

    vector<DfsFrame> pending;
    pending.emplace_back(v, prev);
    pair<int, int> result(0, 0);

    while (!pending.empty())
    {
        DfsFrame& top = pending.back();
        const vector<int>& adj = g[top.node];

        if (top.next < adj.size())
        {
            const int child = adj[top.next++];
            if (child == top.parent) continue;

            if (r[child])
            {
                // Fixed child: contributes the degenerate interval [r, r].
                top.events.emplace_back(r[child], false);
                top.events.emplace_back(r[child], true);
                top.base += r[child];
            }
            else
            {
                // Copy before growing the stack: emplace_back may reallocate
                // and invalidate `top`.
                const int node = top.node;
                pending.emplace_back(child, node);
            }
            continue;
        }

        // All neighbours handled: resolve this node and report to its parent.
        result = settle_node(top.events, top.base);
        pending.pop_back();
        if (!pending.empty())
        {
            DfsFrame& up = pending.back();
            up.events.emplace_back(result.first, false);
            up.events.emplace_back(result.second, true);
            up.base += result.first;
        }
    }
    return result;
}
int main()
{
int n, m;
scanf("%d %d", &n, &m);
for(int i = 1; i < n; i++)
{
int a, b;
scanf("%d %d", &a, &b);
g[a].push_back(b);
g[b].push_back(a);
}
for(int i = 1; i <= m; i++)
scanf("%d", r + i);
if(n == m)
printf("%d\n", abs(r[1] - r[2]));
else
{
dfs(m + 1, 0);
printf("%lld\n", ans);
}
}
// NOTE: the original file ended with a flattened, line-numbered duplicate of the listing above and a stray "English" tag (web-scrape residue); it was not valid C++ and has been reduced to this comment.