#include <algorithm>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>
using namespace std;
// Adjacency lists of the tree; main adds every edge in both directions.
// NOTE(review): every array below assumes vertex ids are smaller than 500005.
std::vector<int> kr[500005];
// n: number of vertices (main roots the tree at vertex n); m: number of vertices whose value
// is read from input. rozstaw[v] is that value for v in 1..m and is used as the fixed value
// of every degree-1 (leaf) vertex.
int n, m, rozstaw[500005];
// postulat[v] = [first, second]: the range of values for v that minimise the cost of v's
// subtree. main seeds it as [rozstaw[v], rozstaw[v]] for v in 1..m; dfs fills it for the
// vertices of degree > 1 that it visits.
std::pair<int, int> postulat[500005];

// Median range of the 2k endpoints of k child intervals [a_c, b_c].
// Since dist(x, [a, b]) = (|x - a| + |x - b| - (b - a)) / 2, the x minimising
// sum_c dist(x, [a_c, b_c]) are exactly the medians of the 2k endpoints, i.e. the closed
// range between the k-th and (k+1)-th smallest endpoint.
// Sorts `endpoints` in place; an empty input yields the empty range {INT_MAX, INT_MIN}.
static std::pair<int, int> endpoint_median(std::vector<int>& endpoints){
    if (endpoints.empty()){
        return std::make_pair(std::numeric_limits<int>::max(),
                              std::numeric_limits<int>::min());
    }
    std::sort(endpoints.begin(), endpoints.end());
    const std::vector<int>::size_type half = endpoints.size() / 2;
    return std::make_pair(endpoints[half - 1], endpoints[half]);
}

// Fills postulat[] bottom-up for the subtree of `start`, entered from `pop` (-1 for the
// root), and returns the range computed for `start`.
// A degree-1 vertex is a leaf with range [rozstaw[v], rozstaw[v]]; it is returned directly
// and postulat[] is not written for it (main seeds leaves 1..m).
// Uses an explicit stack instead of recursion so deep trees cannot overflow the call stack.
std::pair<int, int> dfs(int start, int pop){
    if (kr[start].size() == 1){
        return std::make_pair(rozstaw[start], rozstaw[start]);
    }
    // Pre-order list of (vertex, parent) over non-leaf vertices. Each vertex appears after
    // its parent, so walking the list backwards handles all children before their parent.
    std::vector< std::pair<int, int> > order;
    std::vector< std::pair<int, int> > todo(1, std::make_pair(start, pop));
    while (!todo.empty()){
        const std::pair<int, int> cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        for (int next : kr[cur.first]){
            if (next != cur.second && kr[next].size() != 1){
                todo.push_back(std::make_pair(next, cur.first));
            }
        }
    }
    std::vector<int> endpoints;  // reused buffer holding one vertex's child endpoints
    for (auto it = order.rbegin(); it != order.rend(); ++it){
        const int vertex = it->first;
        const int parent = it->second;
        endpoints.clear();
        for (int child : kr[vertex]){
            if (child == parent){
                continue;
            }
            if (kr[child].size() == 1){
                endpoints.push_back(rozstaw[child]);
                endpoints.push_back(rozstaw[child]);
            } else {
                endpoints.push_back(postulat[child].first);
                endpoints.push_back(postulat[child].second);
            }
        }
        postulat[vertex] = endpoint_median(endpoints);
    }
    return postulat[start];
}

// Assigns values top-down starting with `value` at `start` (entered from `pop`) and returns
// the summed |parent value - child value| over the edges below `start`. Each child takes
// the parent's value clamped into postulat[child]. Requires postulat[] to be filled by dfs
// (and seeded for leaves). Degree-1 vertices contribute no further edges.
// Uses an explicit stack instead of recursion; costs are accumulated in long long.
long long dfs2(int start, int pop, int value){
    struct Frame { int vertex, parent, value; };
    long long total = 0;
    std::vector<Frame> todo;
    const Frame root = {start, pop, value};
    todo.push_back(root);
    while (!todo.empty()){
        const Frame cur = todo.back();
        todo.pop_back();
        if (kr[cur.vertex].size() == 1){
            continue;
        }
        for (int child : kr[cur.vertex]){
            if (child == cur.parent){
                continue;
            }
            const std::pair<int, int>& range = postulat[child];
            int chosen = cur.value;
            if (cur.value > range.second){
                chosen = range.second;
                total += static_cast<long long>(cur.value) - range.second;
            } else if (cur.value < range.first){
                chosen = range.first;
                total += static_cast<long long>(range.first) - cur.value;
            }
            if (kr[child].size() != 1){  // leaves have nothing below them
                const Frame next = {child, cur.vertex, chosen};
                todo.push_back(next);
            }
        }
    }
    return total;
}
// Reads n and m, then n-1 undirected edges, then the values of vertices 1..m.
// Prints the total cost computed by dfs (bottom-up ranges) and dfs2 (top-down assignment)
// for the tree rooted at vertex n.
int main(){
    scanf("%d%d", &n, &m);
    for (int edge = 1; edge < n; ++edge){
        int u, v;
        scanf("%d%d", &u, &v);
        kr[u].push_back(v);
        kr[v].push_back(u);
    }
    for (int leaf = 1; leaf <= m; ++leaf){
        scanf("%d", &rozstaw[leaf]);
        postulat[leaf] = std::make_pair(rozstaw[leaf], rozstaw[leaf]);
    }
    // With n == 2 the root n has degree 1, so dfs/dfs2 would treat it as a leaf and report 0;
    // the answer is simply the gap between the two given values.
    if (n == 2 && m == 2){
        const int gap = rozstaw[1] > rozstaw[2] ? rozstaw[1] - rozstaw[2]
                                                : rozstaw[2] - rozstaw[1];
        printf("%d\n", gap);
        return 0;
    }
    dfs(n, -1);  // result unused: the call fills postulat[] for the non-leaf vertices
    const long long answer = dfs2(n, -1, postulat[n].first);
    printf("%lld\n", answer);
    return 0;
}