#include <cstdio>
#include <cassert>
#include <vector>
#include <set>
#include <map>
#include <cinttypes>
#include <cstdint>
#include <algorithm>
#include <utility>
using namespace std;
int n, foreign;
set<int> visited;
vector<vector<int> > g;
vector<int> roz;
struct res {
pair<int, int> rozs;
uint64_t cost;
};
int abs(int i ) {
return i >= 0?i:-i;
}
res visit(int node) {
visited.insert(node);
vector<int> rozs;
vector<pair<int, int>> rozss;
uint64_t cost = 0;
for (int s : g[node]) {
if (visited.count(s)) {
continue;
}
if (s < foreign) {
rozs.push_back(roz[s]);
rozss.push_back(make_pair(roz[s], roz[s]));
} else {
res r = visit(s);
rozs.push_back(r.rozs.first);
rozs.push_back(r.rozs.second);
rozss.push_back(r.rozs);
cost += r.cost;
}
}
sort(rozs.begin(), rozs.end());
// fprintf(stderr, "%d: ", node+1);
// fprintf(stderr, "rozs: (");
// for (int r : rozs) {
// fprintf(stderr, "%d ", r);
// }
// fprintf(stderr, ")");
res r;
assert(!rozs.empty());
if (rozs.size() % 2) {
r.rozs.first = r.rozs.second = rozs[rozs.size() / 2];
// fprintf(stderr, "(%d) ", r.rozs.first);
} else {
r.rozs.first = rozs[rozs.size() / 2 - 1];
r.rozs.second = rozs[rozs.size() / 2];
// fprintf(stderr, "(%d,%d) ", r.rozs.first, r.rozs.second);
}
r.cost = cost;
for (pair<int, int> roz : rozss) {
if (r.rozs.first >= roz.first && r.rozs.first <= roz.second) {
continue;
}
r.cost += min(abs(roz.first - r.rozs.first), abs(roz.second - r.rozs.first));
}
// fprintf(stderr, "%" PRIu64 "\n", r.cost);
return r;
}
int main() {
scanf("%d%d",&n,&foreign);
g.resize(n);
for(int i =0; i < n- 1; ++i) {
int u,v;
scanf("%d%d",&u,&v);
g[u-1].push_back(v-1);
g[v-1].push_back(u-1);
}
roz.resize(foreign);
for(int i =0; i<foreign; ++i) {
scanf("%d",&roz[i]);
}
uint64_t res = foreign < g.size() ? visit(foreign).cost : 0;
for (int i = 0; i < n ; ++i) {
if (visited.count(i) == 0) {
visited.insert(i);
for(int n : g[i]) {
if (visited.count(n) == 0) {
visited.insert(n);
res += abs(roz[i] - roz[n]);
}
}
}
}
printf("%" PRIu64 "\n", res);
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | #include <cstdio> #include <cassert> #include <vector> #include <set> #include <map> #include <cinttypes> #include <cstdint> #include <algorithm> #include <utility> using namespace std; int n, foreign; set<int> visited; vector<vector<int> > g; vector<int> roz; struct res { pair<int, int> rozs; uint64_t cost; }; int abs(int i ) { return i >= 0?i:-i; } res visit(int node) { visited.insert(node); vector<int> rozs; vector<pair<int, int>> rozss; uint64_t cost = 0; for (int s : g[node]) { if (visited.count(s)) { continue; } if (s < foreign) { rozs.push_back(roz[s]); rozss.push_back(make_pair(roz[s], roz[s])); } else { res r = visit(s); rozs.push_back(r.rozs.first); rozs.push_back(r.rozs.second); rozss.push_back(r.rozs); cost += r.cost; } } sort(rozs.begin(), rozs.end()); // fprintf(stderr, "%d: ", node+1); // fprintf(stderr, "rozs: ("); // for (int r : rozs) { // fprintf(stderr, "%d ", r); // } // fprintf(stderr, ")"); res r; assert(!rozs.empty()); if (rozs.size() % 2) { r.rozs.first = r.rozs.second = rozs[rozs.size() / 2]; // fprintf(stderr, "(%d) ", r.rozs.first); } else { r.rozs.first = rozs[rozs.size() / 2 - 1]; r.rozs.second = rozs[rozs.size() / 2]; // fprintf(stderr, "(%d,%d) ", r.rozs.first, r.rozs.second); } r.cost = cost; for (pair<int, int> roz : rozss) { if (r.rozs.first >= roz.first && r.rozs.first <= roz.second) { continue; } r.cost += min(abs(roz.first - r.rozs.first), abs(roz.second - r.rozs.first)); } // fprintf(stderr, "%" PRIu64 "\n", r.cost); return r; } int main() { scanf("%d%d",&n,&foreign); g.resize(n); for(int i =0; i < n- 1; ++i) { int u,v; scanf("%d%d",&u,&v); g[u-1].push_back(v-1); g[v-1].push_back(u-1); } roz.resize(foreign); for(int i =0; i<foreign; ++i) { scanf("%d",&roz[i]); } uint64_t res = foreign < g.size() ? visit(foreign).cost : 0; for (int i = 0; i < n ; ++i) { if (visited.count(i) == 0) { visited.insert(i); for(int n : g[i]) { if (visited.count(n) == 0) { visited.insert(n); res += abs(roz[i] - roz[n]); } } } } printf("%" PRIu64 "\n", res); } |
English