1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <bits/stdc++.h>
// Shorthand macros used throughout this file.
//
// Loop helpers (the loop variable `a` is declared as int):
//   REP(a,b)      a = 0, 1, ..., b-1
//   FWD(a,b,c)    a = b, b+1, ..., c-1
//   FWDS(a,b,c,d) a = b, b+d, b+2d, ... while a < c  (d is not parenthesised)
//   BCK(a,b,c)    a = b, b-1, ..., c+1
#define REP(a,b) for(int a=0; a<(b); ++a)
#define FWD(a,b,c) for(int a=(b); a<(c); ++a)
#define FWDS(a,b,c,d) for(int a=(b); a<(c); a+=d)
#define BCK(a,b,c) for(int a=(b); a>(c); --a)
// ALL(a) expands to the begin/end iterator pair for <algorithm> calls;
// SIZE(a) is the container size converted to int;
// SQ(a) squares its argument, which is evaluated twice (avoid side effects).
#define ALL(a) (a).begin(), (a).end()
#define SIZE(a) ((int)(a).size())
#define SQ(a) ((a)*(a))
// Debug helper for stream output: VAR(foo) expands to  "foo: " << foo << " "
#define VAR(x) #x ": " << x << " "
// Aliases for GCC/Clang builtins. Note that `gcd` is redirected to the
// non-standard __gcd, so any gcd(...) call below does not reach std::gcd.
#define popcount __builtin_popcount
#define popcountll __builtin_popcountll
#define gcd __gcd
// Pair member aliases: p.x / p.st -> p.first, p.y / p.nd -> p.second.
// NOTE: these rewrite EVERY later token spelled x, y, st or nd, so none of
// those names can be used as identifiers anywhere below this point.
#define x first
#define y second
#define st first
#define nd second
// Shorthand for vector/deque push_back.
#define pb push_back

using namespace std;

// Debug printing for vectors: prints "{e0, e1, ..., }" (every element,
// including the last, is followed by ", ").
template<typename T> ostream& operator<<(ostream &out, const vector<T> &v){
	out << "{";
	for(size_t i = 0; i < v.size(); ++i)
		out << v[i] << ", ";
	return out << "}";
}

// Debug printing for pairs: prints "(first, second)".
template<typename S, typename T> ostream& operator<<(ostream &out, const pair<S,T> &p){
	return out << "(" << p.first << ", " << p.second << ")";
}

// Short type aliases used by the rest of the file.
using int64 = long long;
using PII = pair<int, int>;
using PLL = pair<int64, int64>;
using K = long double;
using VI = vector<int>;

// Paired step offsets (dx[i], dy[i]); presumably 4-neighbour grid moves.
// The diagonal variants are kept in the trailing comments. Not referenced
// by the code in this file.
constexpr int dx[4] = {0, 0, -1, 1};  // diagonals: 1,1,-1,1
constexpr int dy[4] = {-1, 1, 0, 0};  // diagonals: 1,-1,1,-1

// Large sentinel: (10^9)^2 = 10^18, still representable in int64.
constexpr int64 INF = 1000000000LL * 1000000000LL;

// n: total number of vertices; m: vertices 1..m have fixed input values
// (dfs treats them as terminals and never expands them).
int n, m;

// Capacity of the per-vertex arrays below (vertices are 1-indexed).
constexpr int MAXN = 500010;

// Adjacency lists of the tree.
vector<int> edges[MAXN];
// par[v]: vertex from which dfs reached v (0 = not reached; main sets the root to -1).
int par[MAXN];
// R[v]: interval [st, nd] computed for v; for input vertices st == nd == the given value.
PII R[MAXN];

// Accumulated answer printed by main.
int64 cost;

// Scratch buffers reused by dfs for each vertex it finalises.
vector<PII> intervals;
vector<PII> events;
	
void dfs(int u){
	if(u <= m) return;

	for(int v : edges[u])
		if(!par[v]){
			par[v] = u;
			dfs(v);
		}

	intervals.clear();
	events.clear();

	for(int v : edges[u])
		if(par[v] == u){
			intervals.push_back(R[v]);
			events.push_back(PII(R[v].st, 0));
			events.push_back(PII(R[v].nd, 1));
		}

	int lo = 500010, hi = -1;
	sort(events.begin(), events.end());
	int left = 0, right = SIZE(intervals);
	for(PII e : events){
		if(left == right)
			hi = max(hi, e.st);
		
		if(e.nd == 0){
			--right;
		}else{
			++left;
		}
		
		if(left == right)
			lo = min(lo, e.st);
	}

	for(PII in : intervals){
		if(in.nd < lo)
			cost += lo - in.nd;
		else if(in.st > lo)
			cost += in.st - lo;
	}

	R[u] = PII(lo, hi);
	return;
}

int main(){
	scanf("%d %d", &n, &m);
	FWD(i,1,n){
		int a, b;
		scanf("%d %d", &a, &b);
		edges[a].push_back(b);
		edges[b].push_back(a);
	}
	FWD(i,1,m+1){
		scanf("%d", &R[i].st);
		R[i].nd = R[i].st;
	}
	if(m == n){
		printf("%d\n", abs(R[1].st - R[2].st));
	}else{
		par[m+1] = -1;
		dfs(m+1);
		printf("%lld\n", cost);
	}
	return 0;
}