#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
// 64-bit signed integer used for sums that may exceed the range of int.
using i64 = long long;
// A tree on n_ nodes (0-based indices) together with a value interval per node.
// Usage (as in main): construct, add the edges, store the fixed values of the
// leaves 0..m_-1 into r_ as {x, x}, then call solve().
struct Problem
{
    // Allocates n empty adjacency lists; every r_ entry starts as {-1, -1}.
    Problem(int n, int m)
        : n_(n),
          m_(m),
          adj_(n),
          r_(n, pair<int, int>(-1, -1))
    {
    }

    // Records the undirected edge u -- v (0-based) in both adjacency lists.
    void addEdge(int u, int v)
    {
        adj_[u].emplace_back(v);
        adj_[v].emplace_back(u);
    }

    // Defined out of line: fills r_ for the remaining nodes and returns the
    // resulting total edge cost.
    i64 solve();

    int n_;                      // number of nodes
    int m_;                      // nodes 0..m_-1 are the leaves with fixed values
    vector<vector<int>> adj_;    // adjacency lists, adj_[v] = neighbours of v
    vector<pair<int, int>> r_;   // per-node value interval [first, second]
};
// Chooses a value for every node so that the total edge cost
//   sum over edges {u, v} of |r_u - r_v|
// is minimal while the leaves 0..m_-1 keep the values the caller stored in
// r_. Writes the chosen values into r_ (as {x, x}) and returns that cost.
//
// Forward pass: peel the tree from the leaves inwards. When node v has at
// most one unprocessed neighbour left (its "parent"), the optimal cost of
// v's subtree as a function of v's value is, up to a constant, half the sum
// of |x - p| over the endpoints p of its processed neighbours' intervals. It
// is therefore minimised on the median interval of those endpoints, which
// is stored in r_[v]. A node reached with no unprocessed neighbour is the
// root and is fixed to a single point.
//
// Backward pass: in reverse peeling order every node's parent already holds
// a point value (pushed twice as {p, p}). v is then set to the upper median
// of all its neighbours' endpoints, which is optimal given the parent.
//
// Assumes the edges form a tree in which exactly the nodes 0..m_-1 have
// degree 1 (as main() builds it) -- TODO confirm against the problem statement.
i64 Problem::solve()
{
    // known[v] != 0 once r_[v] holds a meaningful value/interval. Tracked
    // explicitly rather than through an r_ == -1 sentinel, so a leaf whose
    // fixed value is -1 is not mistaken for an unassigned node.
    vector<char> known(n_, 0);
    for (int i = 0; i < m_; ++i)
        known[i] = 1;

    // deg[v]: number of neighbours of v that have not been processed yet.
    vector<int> deg(n_);
    for (int i = 0; i < n_; ++i)
        deg[i] = (int)adj_[i].size();

    queue<int> q;
    for (int i = 0; i < m_; ++i)
    {
        if (adj_[i].empty())   // isolated node (e.g. n == 1): nothing to peel
            continue;
        int w = adj_[i][0];    // a leaf's single neighbour
        if (--deg[w] == 1)
            q.push(w);
    }

    vector<int> order;         // internal nodes in the order they were peeled
    vector<int> pts;           // scratch buffer: neighbour interval endpoints
    while (!q.empty())
    {
        int v = q.front();
        q.pop();
        order.push_back(v);

        pts.clear();
        bool isRoot = true;    // stays true if every neighbour is already known
        for (int u : adj_[v])
        {
            if (!known[u])
            {
                // u has not been processed yet: it is v's parent in the peeling.
                isRoot = false;
                if (--deg[u] == 1)
                    q.push(u);
                continue;
            }
            pts.push_back(r_[u].first);
            pts.push_back(r_[u].second);
        }

        // v was pushed by a known neighbour, so pts holds 2k >= 2 endpoints;
        // the median interval of an even-sized multiset is [pts[k-1], pts[k]].
        sort(pts.begin(), pts.end());
        const size_t k = pts.size() / 2;
        if (isRoot)
            r_[v] = {pts[k], pts[k]};
        else
            r_[v] = {pts[k - 1], pts[k]};
        known[v] = 1;
    }

    // Fix each node to a single point, outermost-parent first.
    for (auto it = order.rbegin(); it != order.rend(); ++it)
    {
        int v = *it;
        pts.clear();
        for (int u : adj_[v])
        {
            pts.push_back(r_[u].first);
            pts.push_back(r_[u].second);
        }
        sort(pts.begin(), pts.end());
        int x = pts[pts.size() / 2];
        r_[v] = {x, x};
    }

    // Each edge is counted once, from its larger endpoint. Widen to i64
    // before subtracting so the difference cannot overflow int.
    i64 result = 0;
    for (int i = 0; i < n_; ++i)
        for (int v : adj_[i])
            if (r_[i].first > r_[v].first)
                result += (i64)r_[i].first - (i64)r_[v].first;
    return result;
}
// Reads the input from stdin and prints the minimum total edge cost.
// Input format:
//   - "n m";
//   - n-1 edges "u v" (1-based node numbers);
//   - m values x, where the i-th value is fixed on node i (1-based).
// Returns exit status 1 without printing anything if the input is truncated
// or malformed, instead of computing with uninitialised or out-of-range data.
int main()
{
    int n, m;
    if (scanf("%d%d", &n, &m) != 2 || n < 0 || m < 0 || m > n)
        return 1;
    Problem p(n, m);
    for (int i = 0; i < n - 1; ++i)
    {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n)
            return 1;
        p.addEdge(u - 1, v - 1);
    }
    for (int i = 0; i < m; ++i)
    {
        int x;
        if (scanf("%d", &x) != 1)
            return 1;
        p.r_[i] = {x, x};   // leaf i's value is fixed: a degenerate interval
    }
    printf("%lld\n", p.solve());
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | #include <algorithm> #include <cstdio> #include <queue> #include <vector> using namespace std; typedef long long i64; struct Problem { Problem(int n, int m) : n_(n), m_(m), adj_(n), r_(n,{-1,-1}) {} void addEdge(int u, int v) { adj_[u].push_back(v); adj_[v].push_back(u); } i64 solve(); int n_,m_; vector<vector<int>> adj_; vector<pair<int,int>> r_; }; i64 Problem::solve() { vector<int> deg(n_); for (int i=0; i<n_; ++i) deg[i] = (int)adj_[i].size(); queue<int> q; for (int i=0; i<m_; ++i) if (--deg[adj_[i][0]] == 1) q.push(adj_[i][0]); vector<int> st; while (!q.empty()) { int v = q.front(); q.pop(); st.push_back(v); vector<int> ar; bool all = true; for (int u : adj_[v]) { if (r_[u].first == -1) { all = false; if (--deg[u] == 1) q.push(u); continue; } ar.push_back(r_[u].first); ar.push_back(r_[u].second); } sort(ar.begin(), ar.end()); if (all || ar.size()%2 == 1) r_[v] = {ar[(int)ar.size()/2], ar[(int)ar.size()/2]}; else r_[v] = {ar[((int)ar.size()-1)/2], ar[(int)ar.size()/2]}; } for (auto it = st.rbegin(); it != st.rend(); ++it) { int v = *it; vector<int> ar; for (int u : adj_[v]) { ar.push_back(r_[u].first); ar.push_back(r_[u].second); } sort(ar.begin(), ar.end()); r_[v] = {ar[(int)ar.size()/2], ar[(int)ar.size()/2]}; } i64 result = 0; for (int i=0; i<n_; ++i) for (int v : adj_[i]) if (r_[i].first > r_[v].first) result += r_[i].first - r_[v].first; return result; } int main() { int n,m; scanf("%d%d", &n, &m); Problem p(n,m); for (int i=0; i<n-1; ++i) { int u,v; scanf("%d%d", &u, &v); p.addEdge(u-1,v-1); } for (int i=0; i<m; ++i) { int x; scanf("%d", &x); p.r_[i] = {x,x}; } printf("%lld\n", p.solve()); } |
polski