#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

// Input: N M, then N-1 edges "a b" (1-based vertices), then M integers giving
// the values of vertices 1..M. Vertices 1..M are fixed; vertices M+1..N are
// free. The program prints cost[N] as computed by calc(N) below. By the
// recurrence it implements, that is the minimum over all integer assignments
// to the free vertices of sum |value(a) - value(b)| over the tree's edges.
// (Presumably this is the task's required output; the statement is not part
// of this file.)
//
// Why an interval per vertex suffices: for a child n, the cheapest cost of
// n's subtree plus the edge (parent, n), as a function of the parent's value
// x, is cost[n] + dist(x, [minv[n], maxv[n]]). The subtree cost is convex and
// piecewise linear with integer slopes, and the edge term caps its slope at
// +/-1.

namespace {

int N, M;
std::vector<std::vector<int> > edges;  // adjacency lists, indexed 1..N
std::vector<long long> cost;           // best total edge cost inside v's subtree
std::vector<int> minv;                 // [minv[v], maxv[v]]: values of v that
std::vector<int> maxv;                 // attain cost[v]

// Sweep-event kinds. At equal coordinates std::pair ordering sorts them as
// probe < enter < leave. So each coordinate is first scored with the counts
// from before any change there, then after the enters, then after the leaves.
const int kProbe = -1;
const int kEnter = 0;
const int kLeave = 1;

// Distance from x to the closed interval [lo, hi]. Computed in 64-bit so that
// differences between extreme int values cannot overflow.
long long dist_to_interval(int x, int lo, int hi) {
    if (x < lo) return static_cast<long long>(lo) - x;
    if (x > hi) return static_cast<long long>(x) - hi;
    return 0;
}

// Sets [minv[v], maxv[v]] to the values x minimising
//   sum over children n of dist(x, [minv[n], maxv[n]]).
// Sweeps all interval endpoints left to right, maintaining
//   right = children whose interval has not been entered yet,
//   left  = children whose interval has already been left,
// and keeps the coordinates where |right - left| is smallest.
// Precondition: `children` is non-empty. `events` is caller-owned scratch
// space, reused across calls to avoid reallocating per vertex.
void choose_interval(int v, const std::vector<int>& children,
                     std::vector<std::pair<int, int> >& events) {
    events.clear();
    for (size_t i = 0; i < children.size(); ++i) {
        const int n = children[i];
        events.push_back(std::make_pair(minv[n], kProbe));
        events.push_back(std::make_pair(minv[n], kEnter));
        events.push_back(std::make_pair(maxv[n], kProbe));
        events.push_back(std::make_pair(maxv[n], kLeave));
    }
    std::sort(events.begin(), events.end());

    int right = static_cast<int>(children.size());
    int left = 0;
    int best = right - left;
    minv[v] = maxv[v] = events[0].first;
    for (size_t i = 0; i < events.size(); ++i) {
        if (events[i].second == kLeave) {
            ++left;
        } else if (events[i].second == kEnter) {
            --right;
        }
        const int imbalance = std::abs(right - left);
        if (imbalance < best) {
            best = imbalance;
            minv[v] = maxv[v] = events[i].first;
        } else if (imbalance == best) {
            maxv[v] = events[i].first;  // extend the optimal range rightwards
        }
    }
}

}  // namespace

// Fills cost/minv/maxv for every vertex reachable from `root`, treating
// `root` as the tree's root. Uses an explicit stack instead of recursion, so
// deep (path-like) trees cannot overflow the call stack.
// Expects minv/maxv of vertices 1..M to already hold their fixed values.
void calc(int root) {
    std::vector<char> visited(N + 1, 0);
    std::vector<int> parent(N + 1, 0);  // 0 = no parent (vertices are 1-based)
    std::vector<int> order;             // discovery order: parents before children
    order.reserve(N);

    std::vector<int> pending(1, root);
    visited[root] = 1;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (size_t i = 0; i < edges[v].size(); ++i) {
            const int n = edges[v][i];
            if (!visited[n]) {
                visited[n] = 1;
                parent[n] = v;
                pending.push_back(n);
            }
        }
    }

    std::vector<int> children;
    std::vector<std::pair<int, int> > events;
    // Reverse discovery order guarantees every child is finished before its parent.
    for (size_t k = order.size(); k-- > 0;) {
        const int v = order[k];
        children.clear();
        for (size_t i = 0; i < edges[v].size(); ++i) {
            const int n = edges[v][i];
            if (parent[n] == v) children.push_back(n);
        }

        if (v > M) {
            if (children.empty()) {
                // Free vertex with nothing below it: every value is optimal.
                minv[v] = INT_MIN;
                maxv[v] = INT_MAX;
            } else {
                choose_interval(v, children, events);
            }
        }
        // Fixed vertices (v <= M) keep their given value as [minv, maxv].

        long long total = 0;
        for (size_t i = 0; i < children.size(); ++i) {
            const int n = children[i];
            total += cost[n] + dist_to_interval(minv[v], minv[n], maxv[n]);
        }
        cost[v] = total;
    }
}

int main() {
    if (std::scanf("%d %d", &N, &M) != 2 || N < 1 || M < 0 || M > N) return 1;
    edges.assign(N + 1, std::vector<int>());
    cost.assign(N + 1, 0);
    minv.assign(N + 1, 0);
    maxv.assign(N + 1, 0);

    for (int i = 1; i < N; ++i) {
        int a, b;
        if (std::scanf("%d %d", &a, &b) != 2 || a < 1 || a > N || b < 1 || b > N) {
            return 1;
        }
        edges[a].push_back(b);
        edges[b].push_back(a);
    }
    for (int i = 1; i <= M; ++i) {
        int w;
        if (std::scanf("%d", &w) != 1) return 1;
        minv[i] = maxv[i] = w;
    }

    calc(N);
    std::printf("%lld\n", cost[N]);
    return 0;
}