1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<cassert>
#include<iostream>
#include<algorithm>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<map>
#define REP(i,n) for(int i=0;i<(n);i++)
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
#define FORD(i,a,b) for(int i=(a);i>=(b);i--)
#define foreach(i,c) for(__typeof((c).begin())i=(c).begin();i!=(c).end();i++)
#define all(c) (c).begin(),(c).end()
#define scanf(...) scanf(__VA_ARGS__)?:0
#define eprintf(...) fprintf(stderr,__VA_ARGS__),fflush(stderr)
#define e1 first
#define e2 second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define infLL 1000000000000000023ll
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned int uint;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<ll,int> pli;
typedef pair<int,ll> pil;
// Sizes read first in main(): n-1 edges are read, and vertices 1..m get a value.
int n;
int m;
// Scratch endpoints used by main() while reading each edge.
int a;
int b;
// r[i] is the value read for vertex i (1 <= i <= m).
int r[500001];
// Running total accumulated by dfswyn() and printed by main().
ll wyn;
// prz[i] is a (low, high) pair for vertex i: (r[i], r[i]) for i <= m,
// otherwise filled in by dfswyn().
pii prz[500001];
// Adjacency lists; every edge read is stored in both directions.
vector<int> v[500001];
// odw[i] is set as soon as dfswyn() enters vertex i.
bitset<500001> odw;
void dfswyn(int a)
{
	odw[a]=true;
	if (a<=m) return;
	vector<pii> s;
	vector<int> u,ilewo,iprawo;
	foreach(it,v[a]) if (!odw[*it]) dfswyn(*it),s.pb(prz[*it]);
	foreach(it,s) u.pb(it->e1),u.pb(it->e2);
	sort(all(u));
	u.erase(unique(all(u)),u.end());
	ilewo.resize(u.size());
	iprawo.resize(u.size());
	int p=u.size();
	foreach(it,s)
	{
		int lewy=lower_bound(all(u),it->e1)-u.begin(),prawy=lower_bound(all(u),it->e2)-u.begin();
		ilewo[prawy]++;
		iprawo[lewy]++;
	}
	FOR(i,1,p-1) ilewo[i]+=ilewo[i-1];
	FORD(i,p-2,0) iprawo[i]+=iprawo[i+1];
	ll wl=0,wp=0,w=infLL;
	foreach(it,s) wp+=it->e1-u[0];
	FOR(i,0,p-1)
	{
		w=min(w,wl+wp);
		if (i<p-1) wl+=(ll)ilewo[i]*(u[i+1]-u[i]),wp-=(ll)iprawo[i+1]*(u[i+1]-u[i]);
	}
	wl=0; wp=0; int lk=-1,pk=-1;
	foreach(it,s) wp+=it->e1-u[0];
	FOR(i,0,p-1)
	{
		if (wl+wp==w)
		{
			if (lk==-1) lk=i;
			pk=i;
		}
		if (i<p-1) wl+=(ll)ilewo[i]*(u[i+1]-u[i]),wp-=(ll)iprawo[i+1]*(u[i+1]-u[i]);
	}
	prz[a]=mp(u[lk],u[pk]);
	wyn+=w;
}
int main()
{
	scanf("%d%d",&n,&m);
	REP(i,n-1) scanf("%d%d",&a,&b),v[a].pb(b),v[b].pb(a);
	FOR(i,1,m) scanf("%d",&r[i]),prz[i]=mp(r[i],r[i]);
	if (n==m) assert(n==2),printf("%d",abs(r[1]-r[2])),exit(0);
	dfswyn(m+1);
	printf("%lld\n",wyn);
}