1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <cstdio>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;

using ll = long long;        // 64-bit type used for accumulated costs
constexpr ll inf = 1e12;     // large sentinel value (not referenced by the code shown)
ll ans = 0;                  // total cost; dfs() adds each internal vertex's minimum here

constexpr int N = 500001;    // capacity: vertex ids must lie in [0, N)
vector<int> g[N];            // undirected adjacency lists, filled by main()
int r[N];                    // r[i] is the fixed value of leaf i (1..m); 0 for all other ids

pair<int, int> dfs(int v, int prev)
{
    if(r[v]) return make_pair(r[v], r[v]);
    vector<pair<int, bool>> vec;
    ll current = 0;
    int point = 0;
    for(int u: g[v])
	if(u != prev)
	{
	    auto p = dfs(u, v);
	    vec.emplace_back(p.first, 0);
	    vec.emplace_back(p.second, 1);
	    current += p.first;
	}
    sort(vec.begin(), vec.end());
    ll ans = current;
    int first = 0, last = 0, pre = 0, past = vec.size() / 2;
    for(auto p: vec)
    {
	current += (ll)(pre - past) * (p.first - point);
	point = p.first;
	if(current < ans)
	{
	    ans = current;
	    first = point;
	}
	if(current == ans) last = point;
	if(p.second) pre++;
	else past--;
    }
    ::ans += ans;
    return make_pair(first, last);
}

int main()
{
    int n, m;
    scanf("%d %d", &n, &m);
    for(int i = 1; i < n; i++)
    {
	int a, b;
	scanf("%d %d", &a, &b);
	g[a].push_back(b);
	g[b].push_back(a);
    }
    for(int i = 1; i <= m; i++)
	scanf("%d", r + i);
    if(n == m)
	printf("%d\n", abs(r[1] - r[2]));
    else
    {
	dfs(m + 1, 0);
	printf("%lld\n", ans);
    }
}