#include <algorithm>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

using i64 = long long;

// Tree-labelling solver. Nodes 0..m-1 carry fixed values (main writes them into
// r_); every other node receives a value chosen so that the sum of
// |value(u) - value(v)| over all tree edges is minimal, and solve() returns that
// sum.
//
// Method: each valued neighbour w of a node contributes dist(x, [lo_w, hi_w]),
// which equals (|x - lo_w| + |x - hi_w|) / 2 up to a constant, so the optimal
// values for a node are the medians of the multiset of its neighbours' interval
// endpoints.
//
// NOTE(review): an unfixed node is only queued when its count of unvalued
// neighbours drops from 2 to 1, so an unfixed leaf is never reached and keeps
// the placeholder {-1, -1}. This assumes every leaf is among the fixed nodes
// 0..m-1 — confirm against the problem statement.
struct Problem {
    Problem(int n, int m) : n_(n), m_(m), adj_(n), r_(n, {-1, -1}) {}

    // Adds the undirected edge u-v (0-based indices).
    void addEdge(int u, int v) {
        adj_[u].push_back(v);
        adj_[v].push_back(u);
    }

    // Assigns values to all reachable unfixed nodes and returns the minimal
    // total edge cost.
    i64 solve();

    int n_, m_;                          // node count, number of fixed nodes
    std::vector<std::vector<int>> adj_;  // adjacency lists
    // r_[v] = {lo, hi}. Fixed nodes hold {x, x}. After the first pass an
    // unfixed node holds its range of optimal values; after the second pass it
    // holds a single value {x, x}. {-1, -1} is only an initial placeholder:
    // "has a value" is tracked separately, so -1 is handled like any value.
    std::vector<std::pair<int, int>> r_;

private:
    // Returns the sorted lo/hi endpoints of every neighbour of v marked known.
    std::vector<int> knownEndpoints(int v, const std::vector<char>& known) const;
    // Leaf-to-root pass: gives each reached unfixed node its optimal interval,
    // marks it known, and returns nodes in the order they were processed.
    std::vector<int> peel(std::vector<char>& known);
    // Root-to-leaf pass: pins every node in `order` to a single optimal value.
    void pin(const std::vector<int>& order, const std::vector<char>& known);
    // Sum of |r_[u].first - r_[v].first| over all edges, in 64-bit arithmetic.
    i64 edgeCost() const;
};

std::vector<int> Problem::knownEndpoints(int v, const std::vector<char>& known) const {
    std::vector<int> ends;
    ends.reserve(2 * adj_[v].size());
    for (int u : adj_[v]) {
        if (!known[u]) continue;
        ends.push_back(r_[u].first);
        ends.push_back(r_[u].second);
    }
    std::sort(ends.begin(), ends.end());
    return ends;
}

std::vector<int> Problem::peel(std::vector<char>& known) {
    // deg[v] counts neighbours of v that are not yet known. When it drops to 1,
    // every other neighbour is valued and v is queued.
    std::vector<int> deg(n_);
    for (int v = 0; v < n_; ++v) deg[v] = static_cast<int>(adj_[v].size());

    std::queue<int> ready;
    const int fixedCount = std::min(m_, n_);
    for (int i = 0; i < fixedCount; ++i)
        for (int u : adj_[i])  // iterating all neighbours also covers empty lists
            if (!known[u] && --deg[u] == 1) ready.push(u);

    std::vector<int> order;
    order.reserve(n_);
    while (!ready.empty()) {
        const int v = ready.front();
        ready.pop();
        order.push_back(v);

        // v stops being an outstanding dependency of its unknown neighbours.
        // If it has none, it is the last node of its peel and gets one value.
        bool allKnown = true;
        for (int u : adj_[v]) {
            if (known[u]) continue;
            allKnown = false;
            if (--deg[u] == 1) ready.push(u);
        }

        // v was queued by a neighbour that is valued by now, so `ends` holds at
        // least two entries, and always an even count (two per neighbour).
        const std::vector<int> ends = knownEndpoints(v, known);
        const std::size_t mid = ends.size() / 2;
        if (allKnown)
            r_[v] = {ends[mid], ends[mid]};
        else
            r_[v] = {ends[mid - 1], ends[mid]};  // the full median range
        known[v] = 1;
    }
    return order;
}

void Problem::pin(const std::vector<int>& order, const std::vector<char>& known) {
    // Reverse peel order visits a node after the neighbour it was waiting on.
    // That neighbour is already a single value while the others still hold
    // their intervals; the upper median of all endpoints is an optimal choice.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        const std::vector<int> ends = knownEndpoints(v, known);
        const int x = ends[ends.size() / 2];
        r_[v] = {x, x};
    }
}

i64 Problem::edgeCost() const {
    i64 total = 0;
    for (int v = 0; v < n_; ++v)
        for (int u : adj_[v])
            // Each edge is seen from both ends; only the larger end adds.
            // Widen before subtracting so the difference cannot overflow int.
            if (r_[v].first > r_[u].first)
                total += static_cast<i64>(r_[v].first) - r_[u].first;
    return total;
}

i64 Problem::solve() {
    // known[v] != 0 once v has a value: fixed nodes from the start, others when
    // the first pass reaches them.
    std::vector<char> known(n_, 0);
    const int fixedCount = std::min(m_, n_);
    for (int i = 0; i < fixedCount; ++i) known[i] = 1;

    const std::vector<int> order = peel(known);
    pin(order, known);
    return edgeCost();
}

// Reads "n m", then n-1 edges as 1-based node pairs, then the m values of nodes
// 1..m. Prints the minimal total edge cost. Returns 1 on malformed input.
int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2 || n < 0 || m < 0 || m > n) return 1;

    Problem p(n, m);
    for (int i = 0; i < n - 1; ++i) {
        int u = 0;
        int v = 0;
        if (std::scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) return 1;
        p.addEdge(u - 1, v - 1);
    }
    for (int i = 0; i < m; ++i) {
        int x = 0;
        if (std::scanf("%d", &x) != 1) return 1;
        p.r_[i] = {x, x};
    }
    std::printf("%lld\n", p.solve());
    return 0;
}