#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
// Array capacity for 1-based node indices (the original sized everything
// 500005; presumably the task bound is n <= 500000 -- confirm).
const int MAXN = 500005;

// Adjacency lists of the tree (filled by main from the n-1 input edges).
vector <int> kr[MAXN];
// n = number of nodes, m = number of leaf gauges read by main.
int n, m;
// rozstaw[v] = gauge read for node v (main fills indices 1..m; presumably the
// leaves are exactly nodes 1..m -- NOTE(review): confirm against the task).
int rozstaw[MAXN];
// postulat[v] = [lo, hi], the interval of optimal gauges for v's subtree.
// main sets it for nodes 1..m; dfs sets it for every internal node it visits.
pair <int,int> postulat[MAXN];

// Interval that node `v` contributes to its parent: a node of degree <= 1 is a
// leaf pinned to its own gauge, anything else uses the interval dfs stored.
static pair <int,int> child_interval(int v){
    if (kr[v].size() <= 1){
        return make_pair(rozstaw[v], rozstaw[v]);
    }
    return postulat[v];
}

/*
 * Fills postulat[v] for every internal node v in the subtree of `start`, where
 * the tree is rooted so that `pop` is start's parent (-1 for the root).
 *
 * Each child c adds dist(x, postulat[c]) to the cost of choosing gauge x at
 * its parent (the child's subtree cost, with the edge, has integer slopes
 * clipped to -1/0/+1). Summing over k children, the minimisers are exactly
 * the values between the k-th and (k+1)-th smallest of the 2k child interval
 * endpoints. That median interval is what gets stored.
 *
 * Runs iteratively (explicit preorder, then reverse sweep) so that deep,
 * path-like trees cannot overflow the call stack.
 *
 * Returns the interval for `start`. A start of degree <= 1 is treated as a
 * leaf and yields (rozstaw[start], rozstaw[start]) without touching postulat.
 */
pair <int,int> dfs(int start, int pop){
    if (kr[start].size() <= 1){
        return make_pair(rozstaw[start], rozstaw[start]);
    }

    // Preorder list of (node, parent): every child appears after its parent.
    vector < pair <int,int> > order;
    vector < pair <int,int> > pending(1, make_pair(start, pop));
    while (!pending.empty()){
        const pair <int,int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        const vector <int> &adj = kr[cur.first];
        for (size_t i = 0; i < adj.size(); i++){
            if (adj[i] != cur.second){
                pending.push_back(make_pair(adj[i], cur.first));
            }
        }
    }

    // Walk the preorder backwards so every child is finished before its parent.
    vector <int> ends; // reused buffer of child interval endpoints
    for (size_t idx = order.size(); idx-- > 0; ){
        const int v = order[idx].first;
        const int parent = order[idx].second;
        const vector <int> &adj = kr[v];
        if (adj.size() <= 1){
            continue; // leaf: its interval is its own gauge
        }
        ends.clear();
        for (size_t i = 0; i < adj.size(); i++){
            if (adj[i] != parent){
                const pair <int,int> iv = child_interval(adj[i]);
                ends.push_back(iv.first);
                ends.push_back(iv.second);
            }
        }
        if (ends.empty()){
            // Only reachable on malformed input (e.g. duplicated parent edge).
            postulat[v] = make_pair(rozstaw[v], rozstaw[v]);
            continue;
        }
        const size_t k = ends.size() / 2;
        // ends[k] becomes the (k+1)-th smallest; everything before it is <= it,
        // so the largest of the first k is the k-th smallest.
        nth_element(ends.begin(), ends.begin() + k, ends.end());
        const int hi = ends[k];
        const int lo = *max_element(ends.begin(), ends.begin() + k);
        postulat[v] = make_pair(lo, hi);
    }
    return postulat[start];
}
// One pending visit for dfs2's explicit stack: `node`, reached from `parent`,
// whose own gauge has already been fixed to `value`.
struct Dfs2Frame {
    int node;
    int parent;
    int value;
};

/*
 * Assigns gauges top-down and returns the total edge cost in the subtree of
 * `start` (`pop` is start's parent, -1 for the root), given that `start`
 * itself uses gauge `value`.
 *
 * Each child takes its parent's gauge clamped into postulat[child]. The edge
 * costs the distance moved by the clamp (0 if the value already lies inside).
 * Nodes of degree 1 are treated as leaves and contribute nothing below them.
 *
 * Uses an explicit stack rather than recursion so deep trees cannot overflow
 * the call stack. Differences are computed in 64 bits so large or signed
 * gauges cannot overflow int.
 */
long long dfs2(int start, int pop, int value){
    long long total = 0;
    vector <Dfs2Frame> pending;
    const Dfs2Frame first = {start, pop, value};
    pending.push_back(first);
    while (!pending.empty()){
        const Dfs2Frame cur = pending.back();
        pending.pop_back();
        const vector <int> &adj = kr[cur.node];
        if (adj.size() == 1){
            continue; // leaf: nothing below it
        }
        for (size_t i = 0; i < adj.size(); i++){
            const int child = adj[i];
            if (child == cur.parent){
                continue;
            }
            const int lo = postulat[child].first;
            const int hi = postulat[child].second;
            int chosen = cur.value;
            if (cur.value > hi){
                chosen = hi;
                total += static_cast<long long>(cur.value) - hi;
            }
            else if (cur.value < lo){
                chosen = lo;
                total += static_cast<long long>(lo) - cur.value;
            }
            const Dfs2Frame next = {child, cur.node, chosen};
            pending.push_back(next);
        }
    }
    return total;
}
/*
 * Reads n, m, the n-1 tree edges and the m gauges for nodes 1..m. Prints the
 * minimal total edge cost computed by dfs (intervals) and dfs2 (assignment),
 * both rooted at node n.
 */
int main(){
    scanf("%d%d", &n, &m);
    for (int edge = 1; edge < n; ++edge){
        int u, v;
        scanf("%d%d", &u, &v);
        kr[u].push_back(v);
        kr[v].push_back(u);
    }
    for (int leaf = 1; leaf <= m; ++leaf){
        scanf("%d", rozstaw + leaf);
        postulat[leaf] = make_pair(rozstaw[leaf], rozstaw[leaf]);
    }

    // Single edge between two gauged nodes: the answer is just |r1 - r2|.
    // (Rooting at n would otherwise start from a degree-1 node.)
    if (n == 2 && m == 2){
        const int diff = rozstaw[1] - rozstaw[2];
        printf("%d\n", diff < 0 ? -diff : diff);
        return 0;
    }

    dfs(n, -1);
    const long long answer = dfs2(n, -1, postulat[n].first);
    printf("%lld\n", answer);
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | #include <cstdio> #include <vector> #include <algorithm> using namespace std; vector <int> kr[500005]; int n,m,rozstaw[500005]; pair <int,int> postulat[500005]; pair <int,int> dfs(int start, int pop){ if (kr[start].size() == 1){ return make_pair(rozstaw[start], rozstaw[start]); } vector < pair <int,int> > prz; vector < pair <int, pair <int, int> > > tab; int op = 0, after_me = 0; for (int i = 0; i < kr[start].size(); i++){ if (kr[start][i] != pop){ prz.push_back(dfs(kr[start][i], start)); tab.push_back(make_pair(prz[prz.size()-1].first, make_pair(0, prz.size()-1))); tab.push_back(make_pair(prz[prz.size()-1].second, make_pair(1, prz.size()-1))); } } sort(tab.begin(), tab.end()); pair <int,int> ret = make_pair(1000005, -1); for (int i = 0; i < tab.size(); i++){ bool we_are_before = false, we_are_after = false; if (after_me < prz.size() - after_me - op){ we_are_before = true; } if (after_me == prz.size() - after_me - op){ ret.first = min(ret.first, tab[i].first); ret.second = max(ret.second, tab[i].first); } if (tab[i].second.first == 0){ op++; } else { op--; after_me++; } if (after_me > prz.size() - after_me - op){ we_are_after = true; } if (after_me == prz.size() - after_me - op){ ret.first = min(ret.first, tab[i].first); ret.second = max(ret.second, tab[i].first); } if ((we_are_before) && (we_are_after)){ ret = make_pair(tab[i].first, tab[i].first); } } postulat[start] = ret; return ret; } long long dfs2(int start, int pop, int value){ if (kr[start].size() == 1){ return 0; } long long ret = 0; for (int i = 0; i < kr[start].size(); i++){ if (kr[start][i] != pop){ if (value > postulat[ kr[start][i] ].second){ ret += dfs2(kr[start][i], start, postulat[ kr[start][i] ].second); ret += value 
- postulat[ kr[start][i] ].second; } else if (value < postulat[ kr[start][i] ].first){ ret += dfs2(kr[start][i], start, postulat[ kr[start][i] ].first); ret += postulat[ kr[start][i] ].first - value; } else { ret += dfs2(kr[start][i], start, value); } } } return ret; } int main(){ scanf("%d%d",&n,&m); for (int i = 0; i < n-1; i++){ int a,b; scanf("%d%d",&a,&b); kr[a].push_back(b); kr[b].push_back(a); } for (int i = 1; i <= m; i++){ scanf("%d", &rozstaw[i]); postulat[i] = make_pair(rozstaw[i], rozstaw[i]); } if ((n == 2) && (m == 2)){ if (rozstaw[1] > rozstaw[2]) printf("%d\n", rozstaw[1] - rozstaw[2]); else printf("%d\n", rozstaw[2] - rozstaw[1]); return 0; } dfs(n, -1); printf("%lld\n", dfs2(n,-1, postulat[n].first)); } |
English