// Reads n and m, then n-1 undirected edges, then m integer values for nodes
// 1..m from stdin; runs ABC(1) followed by DEF(1, 0) and prints DEF's result.
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

// Capacity of every per-node array (indices used by the code start at 1).
static const int NODE_CAP = 500010;
// Capacity of the scratch buffer; ABC stores two entries per finished neighbour.
static const int TAB_CAP = 1000010;

// v[x]   : adjacency list of node x (main inserts every edge in both directions).
std::vector<int> v[NODE_CAP];
// p[x]   : first per-node value; main reads it from input for nodes 1..m.
//          NOTE(review): p/k look like the low/high end of a value range per
//          node -- confirm against the problem statement.
int p[NODE_CAP];
// k[x]   : second per-node value; main initialises it to p[x] for nodes 1..m.
long long k[NODE_CAP];
// d[x]   : visit state -- 0 = untouched, 1 = entered, 2 = finished by ABC.
//          DEF resets it to 1 as it walks the nodes again.
int d[NODE_CAP];
// tab    : scratch buffer that ABC fills and sorts for the node being processed.
int tab[TAB_CAP];

// First pass: depth-first walk starting at x over the adjacency lists in v.
//
// On entry d[x] becomes 1; on exit it becomes 2. After all not-yet-visited
// neighbours have been processed recursively, the p and k values of every
// neighbour that is already finished (d == 2) are gathered into tab and sorted.
// If x has more than one neighbour, p[x] and k[x] are overwritten with the two
// middle elements of that sorted list; otherwise they are left as they were.
//
// NOTE(review): recursion depth equals the depth of the walk; for a very deep
// (path-like) input this may need a large stack -- confirm judge limits.
void ABC(int x) {
	const std::vector<int> &adj = v[x];
	const std::size_t deg = adj.size();

	d[x] = 1;
	for (std::size_t idx = 0; idx != deg; ++idx) {
		const int nb = adj[idx];
		if (d[nb] == 0)
			ABC(nb);
	}

	// Two entries (p, then k) per finished neighbour.
	int filled = 0;
	for (std::size_t idx = 0; idx != deg; ++idx) {
		const int nb = adj[idx];
		if (d[nb] != 2)
			continue;
		tab[filled++] = p[nb];
		tab[filled++] = k[nb];
	}
	std::sort(tab, tab + filled);

	if (deg > 1) {
		// filled is even, so mid-1 and mid are the two central positions.
		const int mid = filled / 2;
		p[x] = tab[mid - 1];
		k[x] = tab[mid];
	}
	d[x] = 2;
}

// Second pass: walks the nodes again from x (px is the node we came from,
// 0 for the starting node) and returns the accumulated sum of
// |k[x] - k[child]| over every edge traversed below x.
//
// Nodes still marked d == 2 (finished by ABC, not yet reached by this pass)
// are treated as children; d[x] is set to 1 on entry so a node is never
// entered twice.
//
// When px is non-zero, p[x] and k[x] are first made equal:
//   - p[px] <  p[x]  -> both become p[x]
//   - k[px] >  k[x]  -> both become k[x]
//   - otherwise      -> both become p[px]
// NOTE(review): this presumably clamps the parent's chosen value into x's
// range [p[x], k[x]] -- confirm against the problem statement.
//
// The absolute difference is computed by hand instead of calling an
// unqualified abs(): with only <cstdio>/<vector>/<algorithm> included, the
// long long overload of abs is not guaranteed to be visible, and falling back
// to abs(int) would truncate the operand.
//
// NOTE(review): recursion depth equals the depth of the walk; very deep inputs
// may need a large stack.
long long DEF(int x, int px) {
	d[x] = 1;
	if (px) {
		if (p[px] < p[x]) {
			k[x] = p[x];
		}
		else if (k[px] > k[x]) {
			p[x] = k[x];
		}
		else {
			p[x] = p[px];
			k[x] = p[x];
		}
	}

	long long total = 0;
	for (std::size_t i = 0; i < v[x].size(); ++i) {
		const int child = v[x][i];
		if (d[child] == 2) {
			// Recurse first: the call fixes k[child] before it is compared.
			total += DEF(child, x);
			const long long diff = k[x] - k[child];
			total += (diff < 0) ? -diff : diff;
		}
	}
	return total;
}

// Entry point: reads n and m, then n-1 undirected edges, then one integer value
// for each of the nodes 1..m (stored in p[i] and copied to k[i]); runs ABC(1)
// followed by DEF(1, 0) and prints DEF's result.
//
// Returns 1 if the input ends early or cannot be parsed, 0 otherwise.
int main() {
	int n = 0, m = 0;
	if (scanf("%d%d", &n, &m) != 2) {
		return 1;
	}

	for (int i = 1; i < n; i ++) {
		int a = 0, b = 0;
		if (scanf("%d%d", &a, &b) != 2) {
			return 1;
		}
		v[a].push_back(b);
		v[b].push_back(a);
	}
	for (int i = 1; i <= m; i ++) {
		// p is an int array, so the conversion must be %d; %lld would store
		// a long long through an int* and overwrite the neighbouring element.
		if (scanf("%d", &p[i]) != 1) {
			return 1;
		}
		k[i] = p[i];
	}
	ABC(1);
	printf("%lld\n", DEF(1, 0));
	return 0;
}