// Tree labelling with fixed leaves.
//
// Input : n m, then n-1 edges "u v", then a[1..m].
// Output: the minimum, over all choices of values x for the vertices m+1..n,
//         of the sum over edges (u, v) of |x_u - x_v|, where x_i = a[i] is
//         fixed for vertices 1..m.
// NOTE(review): the task statement is reconstructed from the code (the leaf
// branch in dfs and the per-edge distance terms); it assumes vertices 1..m
// are exactly the leaves of the tree -- confirm against the original problem.
//
// Method: root the tree at a free vertex and work bottom-up. For a free
// vertex v, the cheapest cost of v's subtree as a function of x_v is a
// constant plus the sum over children k of dist(x_v, [le[k], ri[k]]), where
// [le[k], ri[k]] is the set of optimal values for x_k (a fixed leaf has
// le = ri = a[k]). That sum is convex and piecewise linear; its minimisers
// form the interval [le[v], ri[v]], found by sweeping the sorted interval
// endpoints until the slope stops being negative. Placing v at le[v] and
// charging each child edge its distance accumulates the optimal cost in res.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using namespace std;

typedef pair<int, int> pii;

// Capacity of the per-vertex tables; vertices are numbered 1..n.
const int MAXN = 500500;

int n, m;
int a[MAXN];         // fixed values of vertices 1..m
int le[MAXN];        // left end of each vertex's optimal-value interval
int ri[MAXN];        // right end of each vertex's optimal-value interval
int parentOf[MAXN];  // parent of each vertex in the rooted traversal
vector<int> g[MAXN]; // adjacency lists
long long res;       // accumulated minimum cost

// Sets [le[v], ri[v]] for a free vertex v from its children's intervals and
// adds to res the cost of the edges between v and its children, with v
// placed at le[v]. `ends` is caller-provided scratch space reused per vertex.
static void combineChildren(int v, int parent, vector<pii>& ends) {
    ends.clear();
    int startsAfter = 0;  // children whose interval starts right of the sweep point
    for (size_t j = 0; j < g[v].size(); ++j) {
        const int k = g[v][j];
        if (k == parent) continue;
        ends.push_back(make_pair(le[k], 0));  // 0 marks a left endpoint
        ends.push_back(make_pair(ri[k], 1));  // 1 marks a right endpoint
        ++startsAfter;
    }
    sort(ends.begin(), ends.end());

    int endedBefore = 0;  // children whose interval ends at or before the sweep point
    for (size_t lo = 0, hi = 0; lo < ends.size(); lo = hi) {
        // Advance the sweep point past every endpoint equal to ends[lo].first.
        for (hi = lo; hi < ends.size() && ends[hi].first == ends[lo].first; ++hi) {
            if (ends[hi].second) ++endedBefore;
            else --startsAfter;
        }
        // The cost slope just right of the sweep point is endedBefore - startsAfter.
        if (endedBefore == startsAfter && hi < ends.size()) {
            // Zero slope up to the next endpoint: a whole interval of optima.
            le[v] = ends[lo].first;
            ri[v] = ends[hi].first;
            break;
        }
        if (endedBefore >= startsAfter) {
            // Slope stops being negative exactly here: a single optimum.
            le[v] = ri[v] = ends[lo].first;
            break;
        }
    }

    // Charge each child edge the distance from le[v] to the child's interval.
    // Widen before subtracting so large or negative values cannot overflow int.
    for (size_t j = 0; j < g[v].size(); ++j) {
        const int k = g[v][j];
        if (k == parent) continue;
        if (ri[k] < le[v]) res += static_cast<long long>(le[v]) - ri[k];
        if (le[k] > le[v]) res += static_cast<long long>(le[k]) - le[v];
    }
}

// Computes le/ri for every vertex in the subtree of `root` (entered from
// `parent`) and adds that subtree's optimal edge cost to res. The traversal
// is iterative, so a deep tree (up to MAXN vertices) cannot overflow the
// call stack the way the recursive version could.
void dfs(int root, int parent) {
    vector<int> order;             // vertices in DFS preorder: parents before children
    vector<int> pending(1, root);  // explicit traversal stack
    parentOf[root] = parent;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        if (v <= m) {
            // Fixed leaf: its only optimal value is a[v]; do not descend.
            le[v] = ri[v] = a[v];
            continue;
        }
        for (size_t j = 0; j < g[v].size(); ++j) {
            const int k = g[v][j];
            if (k == parentOf[v]) continue;
            parentOf[k] = v;
            pending.push_back(k);
        }
    }

    // Reverse preorder handles every child before its parent.
    vector<pii> ends;
    for (size_t t = order.size(); t-- > 0;) {
        const int v = order[t];
        if (v > m) combineChildren(v, parentOf[v], ends);
    }
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 1;
    // Reject sizes that would index past the fixed-capacity tables.
    if (n >= MAXN || m >= MAXN) return 1;
    for (int e = 1; e < n; ++e) {
        int u = 0, v = 0;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        if (u < 1 || u > n || v < 1 || v > n) return 1;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int v = 1; v <= m; ++v)
        if (scanf("%d", &a[v]) != 1) return 1;

    if (n <= m) {
        // The root candidate n is itself fixed, which a valid tree allows only
        // for n <= 2: the lone edge (if any) costs |a[1] - a[2]|.
        const long long d = n == 2 ? static_cast<long long>(a[1]) - a[2] : 0;
        printf("%lld\n", d < 0 ? -d : d);
        return 0;
    }
    // Vertex n is free here, so rooting at it lets the sweep treat all of its
    // neighbours as children.
    dfs(n, 0);
    printf("%lld\n", res);
    return 0;
}