#include <cstdio> #include <cassert> #include <vector> #include <set> #include <map> #include <cinttypes> #include <cstdint> #include <algorithm> #include <utility> using namespace std; int n, foreign; set<int> visited; vector<vector<int> > g; vector<int> roz; struct res { pair<int, int> rozs; uint64_t cost; }; int abs(int i ) { return i >= 0?i:-i; } res visit(int node) { visited.insert(node); vector<int> rozs; vector<pair<int, int>> rozss; uint64_t cost = 0; for (int s : g[node]) { if (visited.count(s)) { continue; } if (s < foreign) { rozs.push_back(roz[s]); rozss.push_back(make_pair(roz[s], roz[s])); } else { res r = visit(s); rozs.push_back(r.rozs.first); rozs.push_back(r.rozs.second); rozss.push_back(r.rozs); cost += r.cost; } } sort(rozs.begin(), rozs.end()); // fprintf(stderr, "%d: ", node+1); // fprintf(stderr, "rozs: ("); // for (int r : rozs) { // fprintf(stderr, "%d ", r); // } // fprintf(stderr, ")"); res r; assert(!rozs.empty()); if (rozs.size() % 2) { r.rozs.first = r.rozs.second = rozs[rozs.size() / 2]; // fprintf(stderr, "(%d) ", r.rozs.first); } else { r.rozs.first = rozs[rozs.size() / 2 - 1]; r.rozs.second = rozs[rozs.size() / 2]; // fprintf(stderr, "(%d,%d) ", r.rozs.first, r.rozs.second); } r.cost = cost; for (pair<int, int> roz : rozss) { if (r.rozs.first >= roz.first && r.rozs.first <= roz.second) { continue; } r.cost += min(abs(roz.first - r.rozs.first), abs(roz.second - r.rozs.first)); } // fprintf(stderr, "%" PRIu64 "\n", r.cost); return r; } int main() { scanf("%d%d",&n,&foreign); g.resize(n); for(int i =0; i < n- 1; ++i) { int u,v; scanf("%d%d",&u,&v); g[u-1].push_back(v-1); g[v-1].push_back(u-1); } roz.resize(foreign); for(int i =0; i<foreign; ++i) { scanf("%d",&roz[i]); } uint64_t res = foreign < g.size() ? visit(foreign).cost : 0; for (int i = 0; i < n ; ++i) { if (visited.count(i) == 0) { visited.insert(i); for(int n : g[i]) { if (visited.count(n) == 0) { visited.insert(n); res += abs(roz[i] - roz[n]); } } } } printf("%" PRIu64 "\n", res); }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | #include <cstdio> #include <cassert> #include <vector> #include <set> #include <map> #include <cinttypes> #include <cstdint> #include <algorithm> #include <utility> using namespace std; int n, foreign; set<int> visited; vector<vector<int> > g; vector<int> roz; struct res { pair<int, int> rozs; uint64_t cost; }; int abs(int i ) { return i >= 0?i:-i; } res visit(int node) { visited.insert(node); vector<int> rozs; vector<pair<int, int>> rozss; uint64_t cost = 0; for (int s : g[node]) { if (visited.count(s)) { continue; } if (s < foreign) { rozs.push_back(roz[s]); rozss.push_back(make_pair(roz[s], roz[s])); } else { res r = visit(s); rozs.push_back(r.rozs.first); rozs.push_back(r.rozs.second); rozss.push_back(r.rozs); cost += r.cost; } } sort(rozs.begin(), rozs.end()); // fprintf(stderr, "%d: ", node+1); // fprintf(stderr, "rozs: ("); // for (int r : rozs) { // fprintf(stderr, "%d ", r); // } // fprintf(stderr, ")"); res r; assert(!rozs.empty()); if (rozs.size() % 2) { r.rozs.first = r.rozs.second = rozs[rozs.size() / 2]; // fprintf(stderr, "(%d) ", r.rozs.first); } else { r.rozs.first = rozs[rozs.size() / 2 - 1]; r.rozs.second = rozs[rozs.size() / 2]; // fprintf(stderr, "(%d,%d) ", r.rozs.first, r.rozs.second); } r.cost = cost; for (pair<int, int> roz : rozss) { if (r.rozs.first >= roz.first && r.rozs.first <= roz.second) { continue; } r.cost += min(abs(roz.first - r.rozs.first), abs(roz.second - r.rozs.first)); } // fprintf(stderr, "%" PRIu64 "\n", r.cost); return r; } int main() { scanf("%d%d",&n,&foreign); g.resize(n); for(int i =0; i < n- 1; ++i) { int u,v; scanf("%d%d",&u,&v); g[u-1].push_back(v-1); g[v-1].push_back(u-1); } roz.resize(foreign); for(int i =0; i<foreign; ++i) { scanf("%d",&roz[i]); } uint64_t res = foreign < g.size() ? visit(foreign).cost : 0; for (int i = 0; i < n ; ++i) { if (visited.count(i) == 0) { visited.insert(i); for(int n : g[i]) { if (visited.count(n) == 0) { visited.insert(n); res += abs(roz[i] - roz[n]); } } } } printf("%" PRIu64 "\n", res); } |