#include<bits/stdc++.h>
// Loop helpers: FOR counts var upward from lo to hi inclusive,
// FORD counts var downward from hi to lo inclusive.
#define FOR(var,lo,hi) for(int var=(lo);var<=(hi);++var)
#define FORD(var,hi,lo) for(int var=(hi);var>=(lo);--var)
// Whole-container iterator range, e.g. sort(ALL(v)).
#define ALL(c) (c).begin(),(c).end()
// Short spellings for pair members and common std helpers.
#define e1 first
#define e2 second
#define MP make_pair
#define PB push_back
#define EB emplace_back
using namespace std;

// Integer and pair shorthands.
using LL = long long;
using PII = pair<int,int>;
using PLL = pair<LL,LL>;
using PIP = pair<int,PII>;
using PPP = pair<PLL,PLL>;

// Capacity of every per-node array below; node ids index them directly.
constexpr int MAXN = 500111;
// Adjacency list of each node.
vector<int> kraw[MAXN];
// Parent of each node once the tree is rooted (zero-initialized; 0 is not a node id).
int ojc[MAXN];
// Per-node value interval [va[v], vb[v]], filled from input and by the traversal.
int va[MAXN];
int vb[MAXN];
// Computes the total edge cost of the subtree rooted at v.
//
// Nodes of degree 1 are treated as leaves: they keep the interval [va, vb] already stored
// for them and are not expanded further. Every other node u gets
// [va[u], vb[u]] = the two middle values of the sorted endpoints of its children's
// intervals. The cost of each child edge is the distance from va[u] to that child's
// interval (0 when va[u] lies inside it). The sum of these costs is returned.
//
// Side effects: writes ojc[] (parent of every visited node except v) and va/vb of every
// non-leaf node in the subtree. ojc[v] is only read; it must be v's parent or a value that
// is not a neighbour of v (e.g. the zero-initialized 0).
//
// Uses an explicit stack instead of recursion so that very deep trees (node ids go up to
// MAXN) cannot overflow the call stack.
LL dfs(int v){
// A root that is itself a leaf has nothing below it.
if(kraw[v].size()==1) return 0;

// Pre-order walk: every node is appended to `order` after its parent.
vector<int> order;
vector<int> pending(1,v);
while(!pending.empty()){
int u=pending.back();
pending.pop_back();
order.PB(u);
if(kraw[u].size()==1) continue; // leaf: keeps its stored interval
for(int b:kraw[u]){
if(b==ojc[u]) continue;
ojc[b]=u;
pending.PB(b);
}
}

// Process in reverse pre-order so every child is finished before its parent.
LL ans=0;
vector<int> ends; // endpoints of the current node's child intervals (buffer reused)
for(int i=(int)order.size()-1;i>=0;i--){
int u=order[i];
if(kraw[u].size()==1) continue; // leaf
ends.clear();
for(int b:kraw[u]){
if(b==ojc[u]) continue;
ends.PB(va[b]);ends.PB(vb[b]);
}
// Only an isolated root (no neighbours at all) gets here with no children.
if(ends.empty()) continue;
sort(ALL(ends));
int half=(int)ends.size()/2;
va[u]=ends[half-1];vb[u]=ends[half];
for(int b:kraw[u]){
if(b==ojc[u]) continue;
// Intervals satisfy va<=vb, so outside the interval the nearest endpoint is the
// one on va[u]'s side. Subtract in 64-bit so large values cannot overflow int.
if(va[u]<va[b]) ans+=(LL)va[b]-va[u];
else if(va[u]>vb[b]) ans+=(LL)va[u]-vb[b];
}
}
return ans;
}
main(){
int n,m;scanf("%d%d",&n,&m);
FOR(i,2,n){
int a,b;scanf("%d%d",&a,&b);
kraw[a].PB(b);
kraw[b].PB(a);
}
FOR(i,1,m){
int x;scanf("%d",&x);
va[i]=vb[i]=x;
}
if(n==2&&m==2){
printf("%d\n",max(va[1]-va[2],va[2]-va[1]));
return 0;
}
printf("%lld\n",dfs(n));
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | #include<bits/stdc++.h> #define FOR(i,s,e) for(int i=(s);i<=(e);i++) #define FORD(i,s,e) for(int i=(s);i>=(e);i--) #define ALL(k) (k).begin(),(k).end() #define e1 first #define e2 second #define MP make_pair #define PB push_back #define EB emplace_back using namespace std; typedef long long LL; typedef pair<int,int> PII; typedef pair<LL,LL> PLL; typedef pair<int,PII> PIP; typedef pair<PLL,PLL> PPP; const int MAXN=500111; vector<int> kraw[MAXN]; int ojc[MAXN]; int va[MAXN],vb[MAXN]; LL dfs(int v){ vector<int> V; LL ans=0; if(kraw[v].size()==1) return 0; for(auto b:kraw[v]){ if(b==ojc[v]) continue; ojc[b]=v; ans+=dfs(b); V.PB(va[b]);V.PB(vb[b]); } sort(ALL(V)); int vs=((int)V.size())/2; va[v]=V[vs-1],vb[v]=V[vs]; //printf("%d %d %d %lld\n",v,va[v],vb[v],ans); //for(auto it:V) printf("%d ",it); //puts(""); for(auto b:kraw[v]){ if(b==ojc[v]) continue; LL r1=max(va[v]-va[b],va[b]-va[v]); LL r2=max(va[v]-vb[b],vb[b]-va[v]); //printf("%d;; %d %d %d;; %lld %lld\n",b,va[b],va[v],vb[b],r1,r2); if(va[b]<=va[v]&&va[v]<=vb[b]) continue; ans+=min(r1,r2); } //printf("%lld\n",ans); return ans; } main(){ int n,m;scanf("%d%d",&n,&m); FOR(i,2,n){ int a,b;scanf("%d%d",&a,&b); kraw[a].PB(b); kraw[b].PB(a); } FOR(i,1,m){ int x;scanf("%d",&x); va[i]=vb[i]=x; } if(n==2&&m==2){ printf("%d\n",max(va[1]-va[2],va[2]-va[1])); return 0; } printf("%lld\n",dfs(n)); } |
English