1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include<cstdio>
#include<algorithm>
#include<vector>
#include<set>
#include<map>
#include<queue>
#include<cmath>
#include<iostream>
#include<string>
using namespace std;
#define F first
#define S second
#define MP make_pair
#define PB push_back
#define LL long long
#define PII pair<int, int>
#define PLL pair<LL, LL>

// Global problem state; vertices are 1-based.
int n, k;                 // n = vertex count; vertices 1..k get fixed input values (read in main)
LL res;                   // running total of the minimum cost, accumulated by DFS
vector<int> V[1000005];   // adjacency lists (each input edge is stored in both directions)
PII T[1000005];           // T[v] = closed interval [T[v].F, T[v].S] of cost-optimal values for v;
                          // a single point (a, a) for vertices whose value is given in the input
bool vis[1000005];        // DFS visited flags
vector<int> X;            // scratch buffer: interval endpoints of one vertex's children

// Settles vertex v once every child (each neighbour except `parent`) already has its
// interval in T: stores v's optimal interval in T[v] and returns the cost of the edges
// from v to its children when v takes any value inside that interval.
//
// The cost sum_c dist(x, T[c]) is convex and piecewise linear in x. With the 2m
// endpoints of the m child intervals sorted, its minimum is attained exactly on
// [X[m-1], X[m]]. This holds for any int values, so no sentinel seeds are needed.
// A vertex without children contributes 0 and its T entry is left unchanged.
static LL settleVertex(int v, int parent)
{
	X.clear();
	for(int i=0; i<(int)V[v].size(); i++)
	{
		int c=V[v][i];
		if(c==parent)
			continue;
		X.PB(T[c].F);
		X.PB(T[c].S);
	}
	if(X.empty())   // isolated vertex (n == 1): no edges, nothing to pay
		return 0;
	sort(X.begin(), X.end());
	int m=(int)X.size()/2;
	int lo=X[m-1], hi=X[m];
	// Evaluate the cost at lo (any point of [lo, hi] gives the same minimum).
	LL cost=0;
	for(int i=0; i<(int)V[v].size(); i++)
	{
		int c=V[v][i];
		if(c==parent)
			continue;
		if(lo<T[c].F)
			cost+=(LL)T[c].F-lo;
		else if(lo>T[c].S)
			cost+=(LL)lo-T[c].S;
	}
	T[v]=MP(lo, hi);
	return cost;
}

// Adds to res the minimum total cost of the tree rooted at v (call with prev == v for
// the root). Degree-1 vertices are treated as leaves and keep their preset T value.
// Uses an explicit stack instead of recursion: recursion depth could reach the vertex
// count (arrays allow ~10^6), which can overflow the call stack on path-like trees.
void DFS(int v, int prev)
{
	vector<PII> order;   // (vertex, parent) of every non-leaf vertex, in preorder
	vector<PII> todo;    // explicit DFS stack of (vertex, parent)
	vis[v]=1;
	todo.PB(MP(v, prev));
	while(!todo.empty())
	{
		PII cur=todo.back();
		todo.pop_back();
		int u=cur.F;
		if((int)V[u].size()==1)
			continue;   // leaf: its value is fixed, nothing to choose
		order.PB(cur);
		for(int i=0; i<(int)V[u].size(); i++)
		{
			int w=V[u][i];
			if(!vis[w])
			{
				vis[w]=1;
				todo.PB(MP(w, u));
			}
		}
	}
	// Walking the preorder backwards settles every child before its parent.
	for(int i=(int)order.size()-1; i>=0; i--)
		res+=settleVertex(order[i].F, order[i].S);
}

int main()
{
	// Input: n k, then n-1 edges "a b", then k values for vertices 1..k.
	scanf("%d%d", &n, &k);
	for(int i=1; i<n; i++)
	{
		int a, b;
		scanf("%d%d", &a, &b);
		V[a].PB(b);
		V[b].PB(a);
	}
	for(int i=1; i<=k; i++)
	{
		int a;
		scanf("%d", &a);
		T[i]=MP(a, a);   // a fixed value is a degenerate interval [a, a]
	}
	// A single vertex has no edges, so the cost is 0.
	if(n<=1)
	{
		printf("0\n");
		return 0;
	}
	// With two vertices both have degree 1, so the tree walk would treat the root as a
	// leaf and add nothing; answer directly.
	if(n==2)
	{
		if(k==1)
			printf("0\n");   // the free vertex can copy the given value
		else
		{
			// 64-bit difference so the subtraction cannot overflow int.
			long long diff=(long long)T[2].F-T[1].F;
			printf("%lld\n", diff<0 ? -diff : diff);
		}
		return 0;
	}
	// Root the tree at vertex n. NOTE(review): assumes vertex n is not a degree-1 vertex
	// (presumably true when the value-carrying vertices are exactly 1..k with k < n).
	DFS(n, n);
	printf("%lld\n", res);
	return 0;
}