#include<cstdio> #include<algorithm> #define N 500010 int n,m,i,j,x,y,c,g[N],v[N<<1],nxt[N<<1],ed,d[N],l[N],r[N]; int del[N],G[N],V[N],NXT[N],h,t,q[N],f[N],a[N<<1]; long long ans,s,tmp,now; inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';} inline int abs(int x){return x>0?x:-x;} inline void add(int x,int y){d[x]++;v[++ed]=y;nxt[ed]=g[x];g[x]=ed;} inline void adde(int x,int y){f[y]=x;V[++ed]=y;NXT[ed]=G[x];G[x]=ed;} void dfs(int x){ if(!G[x])return; int i; for(i=G[x];i;i=NXT[i])dfs(V[i]); for(tmp=1LL<<60,j=c=s=m=0,i=G[x];i;i=NXT[i])a[m++]=l[V[i]],a[m++]=r[V[i]],c--,s+=l[V[i]]; for(std::sort(a,a+m),i=0;i<m;i++){ c++,s-=a[i],now=s+1LL*a[i]*c; if(now<tmp)l[x]=a[i],tmp=now; if(now==tmp)r[x]=a[i]; } ans+=tmp; } int main(){ read(n),read(m); for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x); for(i=1;i<=m;i++)read(l[i]),r[i]=l[i]; if(n==m){ for(i=1;i<=n;i++)for(j=g[i];j;j=nxt[j])ans+=abs(l[i]-l[v[j]]); printf("%lld",ans/2); return 0; } for(ed=0,i=h=1;i<=m;i++)del[q[++t]=i]=1; for(;h<=t;h=x+1){ for(i=h;i<=t;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])adde(v[j],q[i]); for(i=h,x=t;i<=x;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])if((--d[v[j]])<=1)del[q[++t]=v[j]]=1; } for(i=1;i<=n;i++)if(!f[i])adde(0,i); dfs(0); printf("%lld",ans); return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | #include<cstdio> #include<algorithm> #define N 500010 int n,m,i,j,x,y,c,g[N],v[N<<1],nxt[N<<1],ed,d[N],l[N],r[N]; int del[N],G[N],V[N],NXT[N],h,t,q[N],f[N],a[N<<1]; long long ans,s,tmp,now; inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';} inline int abs(int x){return x>0?x:-x;} inline void add(int x,int y){d[x]++;v[++ed]=y;nxt[ed]=g[x];g[x]=ed;} inline void adde(int x,int y){f[y]=x;V[++ed]=y;NXT[ed]=G[x];G[x]=ed;} void dfs(int x){ if(!G[x])return; int i; for(i=G[x];i;i=NXT[i])dfs(V[i]); for(tmp=1LL<<60,j=c=s=m=0,i=G[x];i;i=NXT[i])a[m++]=l[V[i]],a[m++]=r[V[i]],c--,s+=l[V[i]]; for(std::sort(a,a+m),i=0;i<m;i++){ c++,s-=a[i],now=s+1LL*a[i]*c; if(now<tmp)l[x]=a[i],tmp=now; if(now==tmp)r[x]=a[i]; } ans+=tmp; } int main(){ read(n),read(m); for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x); for(i=1;i<=m;i++)read(l[i]),r[i]=l[i]; if(n==m){ for(i=1;i<=n;i++)for(j=g[i];j;j=nxt[j])ans+=abs(l[i]-l[v[j]]); printf("%lld",ans/2); return 0; } for(ed=0,i=h=1;i<=m;i++)del[q[++t]=i]=1; for(;h<=t;h=x+1){ for(i=h;i<=t;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])adde(v[j],q[i]); for(i=h,x=t;i<=x;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])if((--d[v[j]])<=1)del[q[++t]=v[j]]=1; } for(i=1;i<=n;i++)if(!f[i])adde(0,i); dfs(0); printf("%lld",ans); return 0; } |