#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Reads a tree from stdin and prints one integer cost.
//
// Input format (as consumed by main):
//   n m            - node count and number of nodes with a given value
//   n-1 lines a b  - undirected edges
//   m integers     - rozstaw[1..m], the values given for nodes 1..m
//
// The tree is rooted at node n. Every node of degree 1 is treated as a leaf
// whose value is fixed to rozstaw[node]. Every other node gets an interval
// of preferred values (postulat), built bottom-up from its children. A
// top-down pass then picks one value per node and sums |parent - child| over
// all traversed edges.
// NOTE(review): this looks like the classic "minimise the sum of absolute
// differences along tree edges" task (Polish identifiers: kr ~ edges,
// rozstaw ~ gauge, postulat ~ demand) - confirm against the task statement.

// Capacity of the per-node arrays; usable node ids are 0 .. kMaxNodes - 1.
const int kMaxNodes = 500005;

// Interval stored for a non-leaf node that has no children at all
// (lower bound above upper bound, i.e. "nothing collected").
const int kEmptyLow = 1000005;
const int kEmptyHigh = -1;

std::vector<int> kr[kMaxNodes];           // adjacency lists
int n, m;                                 // node count / number of given values
int rozstaw[kMaxNodes];                   // given value per node (0 if none was read)
std::pair<int, int> postulat[kMaxNodes];  // [first, second] preferred interval per node

namespace {

// One pending entry of an explicit-stack traversal.
struct Visit {
    int node;    // node to process
    int parent;  // neighbour the node was reached from
    int value;   // value chosen for `node` (used by dfs2 only)
};

Visit make_visit(int node, int parent, int value) {
    Visit v;
    v.node = node;
    v.parent = parent;
    v.value = value;
    return v;
}

// A node with exactly one incident edge is a leaf; this applies to the
// traversal start as well.
bool is_leaf(int node) {
    return kr[node].size() == 1;
}

}  // namespace

// Computes postulat[] for every non-leaf node in the subtree of `start`.
//
// start - node to process
// pop   - neighbour of `start` to treat as its parent (-1 for none)
//
// Returns the interval of `start`: (rozstaw, rozstaw) for a leaf, otherwise
// postulat[start]. postulat[] entries of leaves are not written here.
//
// For a node with k children the interval is the pair of middle values
// (positions k-1 and k, 0-based) of the 2k sorted child interval endpoints.
// The traversal is iterative so that deep, path-like trees cannot overflow
// the call stack.
std::pair<int, int> dfs(int start, int pop) {
    if (is_leaf(start)) {
        return std::make_pair(rozstaw[start], rozstaw[start]);
    }

    // Breadth-first order: every node appears before all of its children.
    std::vector<Visit> order;
    order.push_back(make_visit(start, pop, 0));
    for (std::size_t i = 0; i < order.size(); ++i) {
        const Visit cur = order[i];  // copy: push_back below may reallocate
        if (is_leaf(cur.node)) {
            continue;
        }
        const std::vector<int>& adj = kr[cur.node];
        for (std::size_t j = 0; j < adj.size(); ++j) {
            if (adj[j] != cur.parent) {
                order.push_back(make_visit(adj[j], cur.node, 0));
            }
        }
    }

    // Reverse order: children are finished before their parent is handled.
    std::vector<int> endpoints;
    for (std::size_t i = order.size(); i-- > 0;) {
        const Visit& cur = order[i];
        if (is_leaf(cur.node)) {
            continue;
        }

        endpoints.clear();
        const std::vector<int>& adj = kr[cur.node];
        for (std::size_t j = 0; j < adj.size(); ++j) {
            const int child = adj[j];
            if (child == cur.parent) {
                continue;
            }
            if (is_leaf(child)) {
                endpoints.push_back(rozstaw[child]);
                endpoints.push_back(rozstaw[child]);
            } else {
                endpoints.push_back(postulat[child].first);
                endpoints.push_back(postulat[child].second);
            }
        }

        if (endpoints.empty()) {
            postulat[cur.node] = std::make_pair(kEmptyLow, kEmptyHigh);
        } else {
            std::sort(endpoints.begin(), endpoints.end());
            const std::size_t half = endpoints.size() / 2;
            postulat[cur.node] =
                std::make_pair(endpoints[half - 1], endpoints[half]);
        }
    }

    return postulat[start];
}

// Sums the edge costs in the subtree of `start` once `start` holds `value`.
//
// start - node to process
// pop   - neighbour of `start` to treat as its parent (-1 for none)
// value - value assigned to `start`
//
// Each child receives `value` moved to the nearest bound of its postulat
// interval when `value` lies outside it, and the distance moved is added to
// the total. Leaves contribute nothing further. Requires dfs() to have
// filled postulat[] for the non-leaf nodes; leaf entries come from main.
long long dfs2(int start, int pop, int value) {
    long long total = 0;

    std::vector<Visit> pending;
    pending.push_back(make_visit(start, pop, value));
    while (!pending.empty()) {
        const Visit cur = pending.back();
        pending.pop_back();
        if (is_leaf(cur.node)) {
            continue;
        }

        const std::vector<int>& adj = kr[cur.node];
        for (std::size_t j = 0; j < adj.size(); ++j) {
            const int child = adj[j];
            if (child == cur.parent) {
                continue;
            }

            const std::pair<int, int>& range = postulat[child];
            int chosen = cur.value;
            // Upper bound is tested first, then the lower one; the order
            // matters only for an empty interval (first > second).
            if (cur.value > range.second) {
                chosen = range.second;
                total += static_cast<long long>(cur.value) - range.second;
            } else if (cur.value < range.first) {
                chosen = range.first;
                total += static_cast<long long>(range.first) - cur.value;
            }
            pending.push_back(make_visit(child, cur.node, chosen));
        }
    }

    return total;
}

// Returns true when `id` can safely index the per-node arrays.
static bool fits_arrays(int id) {
    return id >= 0 && id < kMaxNodes;
}

int main() {
    if (std::scanf("%d%d", &n, &m) != 2 || !fits_arrays(n) || m >= kMaxNodes) {
        std::fprintf(stderr, "invalid or out-of-range n/m\n");
        return 1;
    }

    for (int i = 0; i < n - 1; ++i) {
        int a, b;
        if (std::scanf("%d%d", &a, &b) != 2 || !fits_arrays(a) || !fits_arrays(b)) {
            std::fprintf(stderr, "invalid edge\n");
            return 1;
        }
        kr[a].push_back(b);
        kr[b].push_back(a);
    }

    for (int i = 1; i <= m; ++i) {
        if (std::scanf("%d", &rozstaw[i]) != 1) {
            std::fprintf(stderr, "invalid value\n");
            return 1;
        }
        // Nodes with a given value start with a single-point interval.
        postulat[i] = std::make_pair(rozstaw[i], rozstaw[i]);
    }

    // Two nodes that both carry a given value: the root itself is a leaf,
    // so the answer is just the distance between the two values.
    if (n == 2 && m == 2) {
        long long diff = static_cast<long long>(rozstaw[1]) - rozstaw[2];
        if (diff < 0) {
            diff = -diff;
        }
        std::printf("%lld\n", diff);
        return 0;
    }

    dfs(n, -1);
    std::printf("%lld\n", dfs2(n, -1, postulat[n].first));
    return 0;
}