1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

// Capacity of the per-node arrays below; every node index must stay below it.
#define MAXN 500007

// Uncomment to enable debug tracing on stderr (DEBPRINT) and internal
// consistency checks (ASSERT). Both expand to nothing when disabled.
//#define DEBUGPRINT
//#define DEBUGASSERT

#ifdef DEBUGPRINT
#define DEBPRINT(...) fprintf(stderr, __VA_ARGS__)
#else
#define DEBPRINT(...)
#endif
#ifdef DEBUGASSERT
#define ASSERT(x) assert(x)
#else
#define ASSERT(x)
#endif

// 64-bit type used for accumulated costs.
typedef long long slng;

// N: number of nodes; M: number of leaves, which are nodes 0..M-1 (0-based).
int N, M;
// Undirected adjacency lists of the tree, 0-based node ids.
vector<int> E[MAXN];
// [minr[v], maxr[v]]: interval of optimal values for node v as computed by
// solve(); for a leaf both ends hold its fixed input value.
int minr[MAXN], maxr[MAXN];

// Which end of an interval an endpoint is; at equal coordinates a TYPEMIN
// endpoint sorts before a TYPEMAX one.
enum typ { TYPEMIN = 0, TYPEMAX = 1 };

// One endpoint of a child's optimal-value interval.
struct pp {
	int r;     // coordinate of the endpoint
	typ type;  // lower (TYPEMIN) or upper (TYPEMAX) end
};

// Strict weak ordering: by coordinate first, then by endpoint kind.
bool operator<(const pp& lhs, const pp& rhs) {
	const bool same_coord = (lhs.r == rhs.r);
	return same_coord ? (lhs.type < rhs.type) : (lhs.r < rhs.r);
}

// Distance from x to the closed interval [lo, hi]; 0 when x lies inside it.
// Takes 64-bit arguments so that differences of two ints cannot overflow.
static slng dist_to_range(slng x, slng lo, slng hi) {
	if (x < lo)
		return lo - x;
	if (x > hi)
		return x - hi;
	return 0;
}

// Combines the already-settled children of internal node n (all neighbours
// except `parent`) into n's optimal interval [minr[n], maxr[n]] and returns
// the cost of the edges between n and its children for a value in that range.
// `p` is scratch space, reused across calls to avoid reallocating.
//
// With child intervals [l_k, r_k], the cost of giving n the value x is
// sum_k dist(x, [l_k, r_k]) = (sum over all 2k endpoints e of |x - e|) / 2
// minus a constant, so it is minimised exactly between the k-th and the
// (k+1)-th smallest endpoints.
static slng settle_node(int n, int parent, std::vector<pp>& p) {
	p.clear();
	for (int k : E[n]) {
		if (k != parent) {
			p.push_back({minr[k], TYPEMIN});
			p.push_back({maxr[k], TYPEMAX});
		}
	}
	if (p.empty())
		return 0; // no children: nothing to combine, minr/maxr left as they are

	std::sort(p.begin(), p.end());
	const auto half = p.size() / 2; // number of children, >= 1 here
	minr[n] = p[half - 1].r;
	maxr[n] = p[half].r;
	DEBPRINT("\nminr,maxr[%d] = [%d, %d]\n", n, minr[n], maxr[n]);

	slng cost = 0;
	for (int k : E[n]) {
		if (k != parent)
			cost += dist_to_range(minr[n], minr[k], maxr[k]);
	}
#ifdef DEBUGASSERT
	// The cost function is flat on [minr[n], maxr[n]], so evaluating it at
	// the upper end must give the same value.
	slng cost_at_max = 0;
	for (int k : E[n]) {
		if (k != parent)
			cost_at_max += dist_to_range(maxr[n], minr[k], maxr[k]);
	}
	ASSERT(cost == cost_at_max);
#endif
	return cost;
}

// Returns the minimum total cost of the subtree entered at node n from
// `parent` (-1 for the root), where nodes 0..M-1 are leaves with fixed values
// (minr == maxr) and every other node may take any value; the cost of an edge
// is the absolute difference of its endpoint values. As a side effect fills
// minr/maxr for every internal node of the subtree.
//
// Uses an explicit stack instead of recursion: path-like trees can be up to
// MAXN deep, which would overflow the call stack.
slng solve(int n, int parent) {
	if (n < M)
		return 0; // leaf: its value is fixed, no cost inside its subtree

	// Collect the internal nodes (with their parents) in DFS pre-order.
	// Leaves are not expanded, matching the leaf cut-off above.
	std::vector<std::pair<int, int> > order;
	std::vector<std::pair<int, int> > todo;
	todo.push_back(std::make_pair(n, parent));
	while (!todo.empty()) {
		const std::pair<int, int> cur = todo.back();
		todo.pop_back();
		if (cur.first < M)
			continue;
		order.push_back(cur);
		for (int k : E[cur.first]) {
			if (k != cur.second)
				todo.push_back(std::make_pair(k, cur.first));
		}
	}

	// Reverse pre-order settles every child before its parent.
	slng total = 0;
	std::vector<pp> scratch;
	for (auto it = order.rbegin(); it != order.rend(); ++it)
		total += settle_node(it->first, it->second, scratch);
	return total;
}

// Reads the tree and the fixed leaf values from stdin and prints the minimum
// total cost computed by solve().
// Input as parsed below: "N M", then N-1 edges "a b" (1-based node ids),
// then M values for nodes 1..M (0-based 0..M-1), which solve() treats as leaves.
// Returns 1 if the input cannot be parsed or does not fit the fixed arrays.
int main() {
	// All per-node arrays share one fixed capacity.
	const int capacity = static_cast<int>(sizeof(minr) / sizeof(minr[0]));

	if (scanf("%d%d", &N, &M) != 2)
		return 1;
	// Node M is used as the root below, so it must also be a valid index.
	if (N < 1 || N > capacity || M < 0 || M > N || M >= capacity)
		return 1;

	for (int i = 0; i < N - 1; i++) {
		int a, b;
		if (scanf("%d%d", &a, &b) != 2)
			return 1;
		if (a < 1 || a > N || b < 1 || b > N)
			return 1;
		E[a - 1].push_back(b - 1);
		E[b - 1].push_back(a - 1);
	}
	for (int i = 0; i < M; i++) {
		if (scanf("%d", &minr[i]) != 1)
			return 1;
		maxr[i] = minr[i]; // a leaf's value is fixed: degenerate interval
	}

	printf("%lld\n", solve(M, -1));

	return 0;
}