// Minimum-cost tree labelling.
//
// Input (stdin): n m, then n-1 edges "a b" (1-based vertex numbers), then the m
// values of vertices 1..m. Vertices 1..m have fixed values and are treated as
// leaves (the traversal never continues past them); vertices m+1..n are free.
// Output (stdout, no trailing newline): the minimum, over all choices of values
// for the free vertices, of the sum of |x_u - x_v| over the tree edges.
//
// Method: root the tree at vertex m+1. For a subtree rooted at vertex v, let
// [lo_v, hi_v] be the set of optimal values of x_v. Seen from v's parent, the
// subtree costs (its own optimum) + dist(x_parent, [lo_v, hi_v]). A sum of such
// distances over v's children is minimised exactly on the median interval of
// all the children's interval endpoints, so each free vertex takes the left
// median as its value and contributes the distances from it to its children's
// intervals.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

namespace {

typedef std::vector<std::vector<int> > Graph;

// Distance from point x to the closed interval [lo, hi] (requires lo <= hi).
long long distance_to_interval(long long x, long long lo, long long hi) {
    if (x < lo) {
        return lo - x;
    }
    if (hi < x) {
        return x - hi;
    }
    return 0;
}

// |x - y| computed in 64-bit arithmetic.
long long abs_diff(long long x, long long y) {
    return x < y ? y - x : x - y;
}

// Total edge cost when every vertex has a fixed value (there is no free vertex).
long long cost_all_fixed(const Graph& adj, const std::vector<long long>& value) {
    long long total = 0;
    for (std::size_t v = 0; v < adj.size(); ++v) {
        for (std::size_t i = 0; i < adj[v].size(); ++i) {
            const std::size_t u = static_cast<std::size_t>(adj[v][i]);
            if (v < u) {  // count each undirected edge once
                total += abs_diff(value[v], value[u]);
            }
        }
    }
    return total;
}

// Minimum total edge cost for the tree `adj` whose vertices [0, m) have the
// fixed values `value` and whose vertices [m, n) are free. Requires m < n.
long long min_total_cost(const Graph& adj, int m, const std::vector<long long>& value) {
    const int n = static_cast<int>(adj.size());
    const int root = m;  // first free vertex

    // Iterative DFS with an explicit stack: recursing along a long path of
    // hundreds of thousands of vertices would overflow the call stack.
    std::vector<int> parent(n, -1);
    std::vector<char> seen(n, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending(1, root);
    seen[root] = 1;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        if (v < m) {
            continue;  // fixed vertices are leaves: do not expand past them
        }
        const std::vector<int>& neighbours = adj[v];
        for (std::size_t i = 0; i < neighbours.size(); ++i) {
            const int u = neighbours[i];
            if (!seen[u]) {
                seen[u] = 1;
                parent[u] = v;
                pending.push_back(u);
            }
        }
    }

    std::vector<long long> lo(n, 0);
    std::vector<long long> hi(n, 0);
    // constrained[v] is 0 when v's subtree holds no fixed vertex; such a subtree
    // can copy its parent's value at zero cost, so it constrains nothing.
    std::vector<char> constrained(n, 0);
    std::vector<long long> ends;  // reused buffer of children's interval endpoints
    long long total = 0;

    // Reverse DFS preorder visits every child before its parent.
    for (std::size_t k = order.size(); k-- > 0;) {
        const int v = order[k];
        if (v < m) {
            lo[v] = value[v];
            hi[v] = value[v];
            constrained[v] = 1;
            continue;
        }

        const std::vector<int>& neighbours = adj[v];
        ends.clear();
        for (std::size_t i = 0; i < neighbours.size(); ++i) {
            const int u = neighbours[i];
            if (parent[u] == v && constrained[u]) {
                ends.push_back(lo[u]);
                ends.push_back(hi[u]);
            }
        }
        if (ends.empty()) {
            continue;  // unconstrained subtree: contributes no cost
        }

        // Optimal values of x_v: the median interval of the 2k endpoints.
        std::sort(ends.begin(), ends.end());
        const std::size_t half = ends.size() / 2;  // >= 1, ends.size() is even
        lo[v] = ends[half - 1];
        hi[v] = ends[half];
        constrained[v] = 1;

        // Cost of the edges to the children with x_v fixed at the left median.
        for (std::size_t i = 0; i < neighbours.size(); ++i) {
            const int u = neighbours[i];
            if (parent[u] == v && constrained[u]) {
                total += distance_to_interval(lo[v], lo[u], hi[u]);
            }
        }
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2 || n < 1 || m < 0) {
        return 1;
    }

    Graph adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n) {
            return 1;
        }
        --a;
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::vector<long long> value(m);
    for (int i = 0; i < m; ++i) {
        if (std::scanf("%lld", &value[i]) != 1) {
            return 1;
        }
    }

    // With no free vertex (e.g. n == 2, m == 2) every edge cost is already fixed.
    const long long answer = (m >= n) ? cost_all_fixed(adj, value)
                                      : min_total_cost(adj, m, value);
    std::printf("%lld", answer);
    return 0;
}