// Minimum smoothing cost of a tree with fixed node values.
//
// Input (stdin):   n m, then n-1 edges "x y" (1-based node ids), then the m
//                  values of nodes 1..m.
// Output (stdout): the minimum, over all values assigned to nodes m+1..n, of
//                  the sum of |value(x) - value(y)| over every tree edge.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// Distance from v to the closed interval [lo, hi]; 0 when v lies inside it.
long long distance_to_range(long long v, long long lo, long long hi) {
    if (v < lo) return lo - v;
    if (v > hi) return v - hi;
    return 0;
}

// Minimum of sum |value(u) - value(v)| over all edges of a tree on nodes 1..n.
//
// `edges` holds the n-1 edges as 1-based node pairs. Nodes 1..m are fixed:
// node i has value weights[i - 1]. All other nodes are free.
//
// The tree is rooted at node n and processed children-before-parents, keeping
// this invariant for every processed node x: the cheapest cost of x's subtree
// plus the edge to x's parent, as a function of the parent's value p, is
//     (already accumulated cost) + dist(p, [lo[x], hi[x]]).
// - A fixed node with value w contributes the range [w, w].
// - For a free node, the children's terms add up to a convex piecewise-linear
//   function with integer slopes. Its minimum joins the running total, and the
//   set of minimisers becomes the node's range. Integer slopes mean one more
//   edge clips them to exactly +-1, which preserves the invariant.
//
// The traversal is iterative, so deep (path-like) trees cannot overflow the
// call stack.
long long min_smoothing_cost(int n, int m,
                             const std::vector<std::pair<int, int> >& edges,
                             const std::vector<long long>& weights) {
    if (n <= 1) return 0;  // a single node has no edges

    std::vector<std::vector<int> > adj(n + 1);
    for (std::size_t i = 0; i < edges.size(); ++i) {
        adj[edges[i].first].push_back(edges[i].second);
        adj[edges[i].second].push_back(edges[i].first);
    }

    // Iterative DFS from the root; `order` lists every node after its parent.
    // Node ids start at 1, so parent 0 means "no parent" (the root).
    const int root = n;
    std::vector<int> parent(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::vector<int> order;
    order.reserve(static_cast<std::size_t>(n));
    std::vector<int> pending(1, root);
    seen[root] = 1;
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (std::size_t i = 0; i < adj[x].size(); ++i) {
            const int y = adj[x][i];
            if (!seen[y]) {
                seen[y] = 1;
                parent[y] = x;
                pending.push_back(y);
            }
        }
    }

    std::vector<long long> lo(n + 1, 0);
    std::vector<long long> hi(n + 1, 0);
    // bounded[x] stays 0 only for a free node whose subtree holds no fixed
    // node. Such a subtree copies its parent's value at zero cost, so it
    // places no constraint on the parent.
    std::vector<char> bounded(n + 1, 0);
    std::vector<long long> ends;  // children's range endpoints (buffer reused)
    long long total = 0;

    for (std::size_t idx = order.size(); idx-- > 0;) {
        const int x = order[idx];

        if (x <= m) {
            // Fixed node: every constrained child pays its distance to w.
            // (Normally only a root that is itself fixed has children here,
            // e.g. n == 2.)
            const long long w = weights[x - 1];
            for (std::size_t i = 0; i < adj[x].size(); ++i) {
                const int y = adj[x][i];
                if (y != parent[x] && bounded[y])
                    total += distance_to_range(w, lo[y], hi[y]);
            }
            lo[x] = w;
            hi[x] = w;
            bounded[x] = 1;
            continue;
        }

        // Free node: minimise sum over children of dist(v, [lo_c, hi_c]).
        ends.clear();
        long long inner = 0;  // sum of (hi_c - lo_c) over constrained children
        for (std::size_t i = 0; i < adj[x].size(); ++i) {
            const int y = adj[x][i];
            if (y == parent[x] || !bounded[y]) continue;
            ends.push_back(lo[y]);
            ends.push_back(hi[y]);
            inner += hi[y] - lo[y];
        }
        if (ends.empty()) continue;  // unconstrained subtree, contributes 0

        // dist(v, [a, b]) == (|v - a| + |v - b| - (b - a)) / 2, so the sum is
        // half the total absolute deviation from all 2k endpoints minus a
        // constant. It is minimal exactly on the median segment
        // [ends[k-1], ends[k]] of the sorted endpoints.
        std::sort(ends.begin(), ends.end());
        const std::size_t k = ends.size() / 2;
        long long lower = 0;
        long long upper = 0;
        for (std::size_t j = 0; j < k; ++j) lower += ends[j];
        for (std::size_t j = k; j < ends.size(); ++j) upper += ends[j];
        total += (upper - lower - inner) / 2;
        lo[x] = ends[k - 1];
        hi[x] = ends[k];
        bounded[x] = 1;
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2) return 0;

    std::vector<std::pair<int, int> > edges;
    edges.reserve(static_cast<std::size_t>(n > 1 ? n - 1 : 0));
    for (int i = 1; i < n; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d%d", &x, &y) != 2) return 1;
        if (x < 1 || x > n || y < 1 || y > n) return 1;  // malformed edge
        edges.push_back(std::make_pair(x, y));
    }

    std::vector<long long> weights(static_cast<std::size_t>(m > 0 ? m : 0), 0LL);
    for (int i = 0; i < m; ++i)
        if (std::scanf("%lld", &weights[i]) != 1) return 1;

    std::printf("%lld\n", min_smoothing_cost(n, m, edges, weights));
    return 0;
}