1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<iostream>
#include<algorithm>
#include<vector>
#define st first
#define nd second
#define mp make_pair
#define PII pair<int,int>
#define pb push_back
#define LL long long
#define INF 1000000000000000000LL
using namespace std;
PII mediana(vector<PII >v){
	if(v.size()==1)return v[0];
	PII res=mp(-1,-1);
	vector<PII>ev;
	vector<LL>wyn;
	LL akt=0,minim=INF;
	for(int i=0;i<v.size();i++){
		ev.pb(mp(v[i].st,1));
		ev.pb(mp(v[i].nd,-1));
	}
	for(int i=0;i<ev.size();i++){
		if(ev[i].nd==1)akt+=ev[i].st-ev[0].st;
	}
	sort(ev.begin(),ev.end());
	int lewo=0,prawo=v.size();
	for(int i=0;i<ev.size();i++){
		if(ev[i].nd==1){
			prawo--;	
		}
		else{
			lewo++;	
		}
		wyn.pb(akt);
		if(akt<minim)minim=akt;
		int d=0;
		if(i+1<ev.size())d=ev[i+1].st-ev[i].st;
		akt-=(LL)d*prawo;
		akt+=(LL)d*lewo;
	}
	for(int i=0;i<ev.size();i++){
		if(wyn[i]==minim){
			if(res.st==-1)res.st=ev[i].st;
			res.nd=ev[i].st;
		}
	}
	return res;
}
vector<int>g[500000];
PII przedz[500000];
bool odw[500000];
void dfs(int x){
	odw[x]=1;
	vector<PII>syn;
	for(int i=0;i<g[x].size();i++){
		if(odw[g[x][i]])continue;
		dfs(g[x][i]);
		syn.pb(przedz[g[x][i]]);
	}
	if(syn.size()>0)przedz[x]=mediana(syn);
}
LL count(int x,int wyb){
	odw[x]=1;
	LL res=0;
	for(int i=0;i<g[x].size();i++){
		int w=g[x][i];
		if(odw[w])continue;
		int co=wyb;
		if(co>przedz[w].nd)co=przedz[w].nd;
		if(co<przedz[w].st)co=przedz[w].st;
		res+=abs(wyb-co);
		res+=count(w,co);
	}
	return res;
}
main(){
	ios_base::sync_with_stdio(0);
	int n,m;
	cin>>n>>m;
	for(int i=0;i<n-1;i++){
		int a,b;
		cin>>a>>b;
		a--;b--;
		g[a].pb(b);
		g[b].pb(a);
	}
	for(int i=0;i<m;i++){
		int a;
		cin>>a;
		przedz[i]=mp(a,a);
	}
	if(n==2){
		cout<<abs(przedz[0].st-przedz[1].st)<<"\n";
		return 0;
	}
	dfs(n-1);
	for(int i=0;i<n;i++)odw[i]=0;
	cout<<count(n-1,przedz[n-1].st)<<"\n";
}