1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <cstdio>
#include <vector>
#include <cstdlib>
#include <algorithm>
using std::vector;
using std::sort;
#define ll long long

// Upper bound (exclusive) on vertex indices; vertices are numbered from 1.
const int maxN = 5e5 + 10;

// Value ranges. For a leaf i, main() stores its fixed value in both from[i]
// and to[i]. For an internal vertex v, dfs() overwrites them with the
// interval [from[v], to[v]] of values for v that minimise the cost computed
// for v's subtree.
int from[maxN];
int to[maxN];

// Adjacency lists of the tree.
std::vector<int> nbh[maxN];

// Vertex count and leaf count, read by main().
int n, m;

// Total minimal cost accumulated by dfs().
long long res;

// One endpoint of a child's optimal interval [from[c], to[c]].
struct Endpoint {
    long long pos; // coordinate of the endpoint
    bool isEnd;    // false: interval start (from[c]); true: interval end (to[c])
};

// Sort order for endpoints: by coordinate, and at equal coordinates starts
// before ends. Using an explicit flag (instead of encoding ends as negative
// numbers) keeps zero and negative values unambiguous.
static bool endpointLess(const Endpoint &a, const Endpoint &b) {
    if (a.pos != b.pos)
        return a.pos < b.pos;
    return !a.isEnd && b.isEnd;
}

// Chooses the value range of vertex v from the ranges of its children
// (all neighbours except `parent`).
//
// With x the value of v, the cost of the edges to the children is
//   f(x) = sum over children c of  max(0, from[c] - x) + max(0, x - to[c]),
// a convex piecewise-linear function whose breakpoints are the children's
// interval endpoints. f is evaluated at every breakpoint from left to right;
// the leftmost and rightmost breakpoints attaining the minimum are stored in
// from[v] and to[v].
// Returns min f, or 0 (leaving from[v]/to[v] untouched) when v has no children.
static long long combineChildren(int v, int parent) {
    std::vector<Endpoint> pts;
    pts.reserve(2 * nbh[v].size());
    for (int c : nbh[v]) {
        if (c == parent)
            continue;
        pts.push_back({from[c], false});
        pts.push_back({to[c], true});
    }
    if (pts.empty())
        return 0;
    std::sort(pts.begin(), pts.end(), endpointLess);

    // f at the leftmost breakpoint: every start to its right must be reached,
    // and no end can lie to its left.
    const long long left = pts[0].pos;
    long long cost = 0;
    long long startsAhead = 0; // starts with index >= i
    for (const Endpoint &p : pts) {
        if (!p.isEnd) {
            cost += p.pos - left;
            ++startsAhead;
        }
    }
    long long endsBehind = 0; // ends with index < i

    long long best = cost;
    long long lo = left;
    long long hi = left;
    for (int i = 1; i < static_cast<int>(pts.size()); ++i) {
        // The previous breakpoint is now behind the sweep position.
        if (pts[i - 1].isEnd)
            ++endsBehind;
        else
            --startsAhead;
        // Slope of f on (pts[i-1], pts[i]): +1 per interval already passed,
        // -1 per interval whose start still lies ahead.
        cost += (pts[i].pos - pts[i - 1].pos) * (endsBehind - startsAhead);
        if (cost < best) {
            best = cost;
            lo = hi = pts[i].pos;
        } else if (cost == best) {
            hi = pts[i].pos;
        }
    }
    from[v] = static_cast<int>(lo);
    to[v] = static_cast<int>(hi);
    return best;
}

// Processes the tree rooted at v: every non-leaf vertex reachable from v gets
// its optimal range in from[]/to[], and its minimal cost is added to res.
// A vertex other than v with exactly one neighbour is treated as a leaf whose
// from/to must already be set; if v itself has exactly one neighbour nothing
// is done. The traversal is iterative, so deep (path-like) trees cannot
// overflow the call stack.
void dfs(int v) {
    if (nbh[v].size() == 1)
        return;

    struct Visit {
        int vertex;
        int parent; // -1 for the root
    };

    // Pre-order listing: every vertex appears after its parent.
    std::vector<Visit> order;
    std::vector<Visit> pending;
    pending.push_back({v, -1});
    while (!pending.empty()) {
        const Visit cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int a : nbh[cur.vertex])
            if (a != cur.parent)
                pending.push_back({a, cur.vertex});
    }

    // Reverse pre-order handles every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const bool isLeaf = it->vertex != v && nbh[it->vertex].size() == 1;
        if (!isLeaf)
            res += combineChildren(it->vertex, it->parent);
    }
}

// Reads n, m, the n - 1 tree edges and the fixed values of vertices 1..m,
// then prints the total cost accumulated by dfs(n).
// NOTE(review): assumes vertices 1..m are exactly the leaves and vertex n is
// internal unless n == m == 2 — confirm against the problem statement.
// Returns 1 on malformed or truncated input instead of computing on garbage.
int main() {
    if (scanf("%d%d", &n, &m) != 2)
        return 1;
    for (int i = 1; i < n; ++i) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2)
            return 1;
        nbh[a].push_back(b);
        nbh[b].push_back(a);
    }
    for (int i = 1; i <= m; ++i) {
        if (scanf("%d", &from[i]) != 1)
            return 1;
        to[i] = from[i];
    }
    if (n == 2 && m == 2) {
        // Both vertices are leaves joined by one edge. Subtract in
        // long long so values of opposite sign cannot overflow int.
        res = std::llabs(static_cast<long long>(from[1]) - from[2]);
    } else {
        dfs(n);
    }
    printf("%lld\n", res);
    return 0;
}