#include <climits>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>
// Global problem state shared by oblicz() and main().
int n;   // number of vertices (main reads n - 1 edges)
int m;   // vertices 0..m-1 have fixed values r[0..m-1]
int n1;  // n - 1, the edge count used by main's input loop
int r[500000];                  // fixed values of vertices 0..m-1
std::vector<int> mapa[500000];  // adjacency lists with 0-based vertex ids
long long wynik = 0LL;          // accumulated answer, printed by main
// Bottom-up pass over the subtree rooted at v. The tree is entered from
// rodzic (pass -1 for the overall root).
//
// Vertices with id < m are leaves whose value is fixed to r[id]; every other
// vertex is free. For each free vertex the function computes an interval
// [lo, hi]: the set of minimisers of
//     sum over children c of dist(x, [lo_c, hi_c]),
// i.e. the range between the two middle endpoints of the children's
// intervals (a leaf's interval is the single point r[c]). It then adds to the
// global wynik the distance from lo to each child's interval.
//
// Outputs: a, b -- the interval computed for v itself. When v is a leaf,
// a == b == r[v] and nothing is added to wynik.
//
// NOTE(review): presumably this solves "assign values to the free vertices
// minimising the sum of |differences| over the edges"; the
// median-of-endpoints rule is the exact optimum for that objective.
//
// Implemented iteratively. The previous recursive version used one stack
// frame (holding three std::vector objects) per tree level, so a path-shaped
// tree with ~500000 vertices could exhaust a typical default thread stack.
void oblicz(int v, int rodzic, int &a, int &b)
{
    // A leaf's value is fixed, so its interval is a single point.
    if (v < m)
    {
        a = r[v];
        b = r[v];
        return;
    }

    // Breadth-first listing of the subtree below v (never stepping into
    // rodzic). BFS appends each vertex's children contiguously, so the
    // children of order[i] occupy positions [first_kid[i], first_kid[i + 1]).
    // Leaves (ids < m) are not expanded, mirroring the base case above.
    std::vector<int> order(1, v);
    std::vector<int> came_from(1, rodzic);
    std::vector<std::size_t> first_kid;
    for (std::size_t i = 0; i < order.size(); ++i)
    {
        // Copy before push_back below may reallocate the vectors.
        const int u = order[i];
        const int up = came_from[i];
        first_kid.push_back(order.size());
        if (u < m)
        {
            continue;
        }
        for (std::vector<int>::iterator it = mapa[u].begin(); it != mapa[u].end(); ++it)
        {
            if (*it != up)
            {
                order.push_back(*it);
                came_from.push_back(u);
            }
        }
    }
    first_kid.push_back(order.size());

    // Reverse BFS order visits every child before its parent.
    const std::size_t count = order.size();
    std::vector<int> lo(count);
    std::vector<int> hi(count);
    std::vector<int> ends;  // children's interval endpoints; buffer reused per vertex
    for (std::size_t i = count; i-- > 0; )
    {
        const int u = order[i];
        if (u < m)
        {
            lo[i] = r[u];
            hi[i] = r[u];
            continue;
        }
        const std::size_t kb = first_kid[i];
        const std::size_t ke = first_kid[i + 1];
        if (kb == ke)
        {
            // A free vertex with no children (only on malformed input) can take
            // any value at zero cost. [INT_MIN, INT_MAX] leaves the parent's
            // middle endpoints unchanged and adds nothing to wynik. The
            // recursive version indexed an empty vector here.
            lo[i] = INT_MIN;
            hi[i] = INT_MAX;
            continue;
        }
        ends.clear();
        for (std::size_t c = kb; c < ke; ++c)
        {
            ends.push_back(lo[c]);
            ends.push_back(hi[c]);
        }
        // Minimisers of a sum of distances to intervals lie between the two
        // middle endpoints.
        std::sort(ends.begin(), ends.end());
        const std::size_t half = ends.size() / 2;
        const int best = ends[half - 1];
        lo[i] = best;
        hi[i] = ends[half];
        // Edge costs to the children when this vertex takes value `best`.
        // Widen before subtracting: the difference of two ints may overflow.
        for (std::size_t c = kb; c < ke; ++c)
        {
            if (best < lo[c])
            {
                wynik += static_cast<long long>(lo[c]) - best;
            }
            else if (hi[c] < best)
            {
                wynik += static_cast<long long>(best) - hi[c];
            }
        }
    }
    a = lo[0];
    b = hi[0];
}
// Reads the tree and the fixed values, then prints the total cost accumulated
// by oblicz (or the direct edge sum when no vertex is free).
// Input: n m, then n - 1 edges as 1-based vertex pairs, then m values r[0..m-1].
int main()
{
    scanf("%d%d", &n, &m);
    n1 = n - 1;
    // Store the n - 1 edges as 0-based adjacency lists.
    for (int i = 0; i < n1; ++i)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        --a; --b;
        mapa[a].push_back(b);
        mapa[b].push_back(a);
    }
    // Fixed values of vertices 0..m-1.
    for (int i = 0; i < m; ++i)
    {
        scanf("%d", r + i);
    }
    if (m >= n)
    {
        // No free vertex: every value is fixed, so the answer is just the sum
        // of |difference| over the edges. This covers the original
        // n == 2 && m == 2 case and also the single-vertex tree, where
        // oblicz(m, ...) would otherwise start from a vertex with no edges.
        // The arithmetic is done in long long so it cannot overflow int.
        long long total = 0;
        for (int u = 0; u < n; ++u)
        {
            for (std::vector<int>::iterator it = mapa[u].begin(); it != mapa[u].end(); ++it)
            {
                if (*it > u)  // count each undirected edge once
                {
                    const long long d = static_cast<long long>(r[u]) - r[*it];
                    total += d < 0 ? -d : d;
                }
            }
        }
        printf("%lld", total);
        return 0;
    }
    // Vertex m is the first free vertex; use it as the root.
    int aa, bb;
    oblicz(m, -1, aa, bb);
    printf("%lld", wynik);
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | #include <cstdio> #include <cstdlib> #include <vector> #include <algorithm> int n, m, n1; int r[500000]; std::vector<int> mapa[500000]; long long wynik = 0; void oblicz(int v, int rodzic, int &a, int &b) { if (v < m) { a = r[v]; b = r[v]; return; } std::vector<int> av; std::vector<int> bv; std::vector<int> tmp; for (std::vector<int>::iterator it = mapa[v].begin(); it != mapa[v].end(); ++it) { if (*it == rodzic) { continue; } int at, bt; oblicz(*it, v, at, bt); av.push_back(at); bv.push_back(bt); tmp.push_back(at); tmp.push_back(bt); } std::sort(tmp.begin(), tmp.end()); a = tmp[(tmp.size() >> 1) - 1]; b = tmp[tmp.size() >> 1]; std::vector<int>::iterator at = av.begin(); std::vector<int>::iterator bt = bv.begin(); while (at != av.end()) { if (a < *at) { wynik += *at - a; } else if (*bt < a) { wynik += a - *bt; } ++at; ++bt; } } int main() { scanf("%d%d", &n, &m); n1 = n - 1; for (int i = 0; i < n1; ++i) { int a, b; scanf("%d%d", &a, &b); --a; --b; mapa[a].push_back(b); mapa[b].push_back(a); } for (int i = 0; i < m; ++i) { scanf("%d", r + i); } if (n == 2 && m == 2) { printf("%d", abs(r[0] - r[1])); return 0; } int aa, bb; oblicz(m, -1, aa, bb); printf("%lld", wynik); return 0; } |
English