#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Reads a tree on n vertices (n - 1 edges), then values for vertices 1..m,
// and prints the cost accumulated by a bottom-up "median interval" greedy
// rooted at vertex n: every internal vertex takes the two middle endpoints of
// its children's intervals as its own interval, and each child pays the
// distance from its interval to the parent's lower endpoint.
// NOTE(review): presumably this is the minimum, over values for the free
// vertices, of sum |value(u) - value(v)| over tree edges, with vertices 1..m
// being exactly the leaves -- confirm against the task statement.

namespace {

using LL = long long;

constexpr int MAXN = 500111;

std::vector<int> adj[MAXN];  // adjacency lists of the tree
int parent[MAXN];            // parent in the traversal rooted at n (0 = none)
int lo[MAXN], hi[MAXN];      // interval [lo, hi] per vertex; lo == hi for input values

// Distance from point x to the closed interval [a, b] (requires a <= b).
// Evaluated in 64 bits so differences of large int values cannot overflow.
LL distToInterval(LL x, LL a, LL b) {
    if (x < a) return a - x;
    if (x > b) return x - b;
    return 0;
}

// Runs the greedy bottom-up from `root` and returns the total cost.
// A vertex of degree 1 is treated as a leaf: its interval is whatever was read
// from input and it is not processed further. The traversal is iterative so
// that very deep trees (up to ~MAXN vertices) cannot overflow the call stack.
LL solve(int root) {
    // Single-vertex tree, or a root that is itself a leaf: nothing to place.
    if (adj[root].size() <= 1) return 0;

    // Iterative preorder; walking it backwards visits every child before its parent.
    std::vector<int> order;
    std::vector<int> pending{root};
    parent[root] = 0;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int c : adj[v]) {
            if (c == parent[v]) continue;
            parent[c] = v;
            pending.push_back(c);
        }
    }

    LL total = 0;
    std::vector<int> ends;  // reused buffer: endpoints of all children's intervals
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        if (adj[v].size() == 1) continue;  // leaf: interval fixed by input

        ends.clear();
        for (int c : adj[v]) {
            if (c == parent[v]) continue;
            ends.push_back(lo[c]);
            ends.push_back(hi[c]);
        }

        // Internal vertices have at least one child, so ends.size() >= 2 and
        // half >= 1. The two middle order statistics (sorted indices half-1
        // and half) become this vertex's interval; no full sort is needed.
        const int half = static_cast<int>(ends.size()) / 2;
        std::nth_element(ends.begin(), ends.begin() + half, ends.end());
        hi[v] = ends[half];
        lo[v] = *std::max_element(ends.begin(), ends.begin() + half);

        // Each child pays the distance from its interval to this vertex's value lo[v].
        for (int c : adj[v]) {
            if (c == parent[v]) continue;
            total += distToInterval(lo[v], lo[c], hi[c]);
        }
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2) return 1;
    // Reject sizes that would index outside the fixed-size global arrays.
    if (n < 1 || n >= MAXN || m < 0 || m >= MAXN) return 1;

    for (int i = 2; i <= n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2) return 1;
        if (a < 1 || a > n || b < 1 || b > n) return 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    for (int i = 1; i <= m; ++i) {
        int x = 0;
        if (std::scanf("%d", &x) != 1) return 1;
        lo[i] = hi[i] = x;
    }

    // With two given vertices the root n is itself a leaf, so the traversal
    // would not see the single edge; its cost is the difference of the values.
    if (n == 2 && m == 2) {
        std::printf("%lld\n", std::llabs(static_cast<LL>(lo[1]) - lo[2]));
        return 0;
    }

    std::printf("%lld\n", solve(n));
    return 0;
}