1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<cstdio>
#include<algorithm>
#define N 500010
int n,m,i,j,x,y,c,g[N],v[N<<1],nxt[N<<1],ed,d[N],l[N],r[N];
int del[N],G[N],V[N],NXT[N],h,t,q[N],f[N],a[N<<1];
long long ans,s,tmp,now;
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline int abs(int x){return x>0?x:-x;}
inline void add(int x,int y){d[x]++;v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
inline void adde(int x,int y){f[y]=x;V[++ed]=y;NXT[ed]=G[x];G[x]=ed;}
void dfs(int x){
  if(!G[x])return;
  int i;
  for(i=G[x];i;i=NXT[i])dfs(V[i]);
  for(tmp=1LL<<60,j=c=s=m=0,i=G[x];i;i=NXT[i])a[m++]=l[V[i]],a[m++]=r[V[i]],c--,s+=l[V[i]];
  for(std::sort(a,a+m),i=0;i<m;i++){
    c++,s-=a[i],now=s+1LL*a[i]*c;
    if(now<tmp)l[x]=a[i],tmp=now;
    if(now==tmp)r[x]=a[i];
  }
  ans+=tmp;
}
int main(){
  read(n),read(m);
  for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x);
  for(i=1;i<=m;i++)read(l[i]),r[i]=l[i];
  if(n==m){
    for(i=1;i<=n;i++)for(j=g[i];j;j=nxt[j])ans+=abs(l[i]-l[v[j]]);
    printf("%lld",ans/2);
    return 0;
  }
  for(ed=0,i=h=1;i<=m;i++)del[q[++t]=i]=1;
  for(;h<=t;h=x+1){
    for(i=h;i<=t;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])adde(v[j],q[i]);
    for(i=h,x=t;i<=x;i++)for(j=g[q[i]];j;j=nxt[j])if(!del[v[j]])if((--d[v[j]])<=1)del[q[++t]=v[j]]=1;
  }
  for(i=1;i<=n;i++)if(!f[i])adde(0,i);
  dfs(0);
  printf("%lld",ans);
  return 0;
}