1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>

using namespace std;

const int N = 500005;

// n: number of vertices; m: number of vertices with a fixed value (1..m).
// a and b are scratch variables used by main() while reading edges.
int n, m, a, b;
// value[v]: fixed value of vertex v; 0 means "not fixed, to be chosen".
// prz[v]: parent of v in the traversal rooted by DFS (0 for the root).
int value[N], prz[N];
// [intervalBegin[v], intervalEnd[v]]: the range of values for v that minimises
// the sum of distances to its children's ranges (a single point when fixed).
int intervalBegin[N], intervalEnd[N];
// Adjacency lists of the tree.
std::vector<int> V[N];
// Sum of the minimal per-vertex costs accumulated by DFS.
long long ans;

// Resolves vertex w once all of its children (every neighbour except prz[w])
// have been resolved. A fixed vertex gets the single-point range
// [value[w], value[w]]. Otherwise, with k children, it minimises
//     f(x) = sum over children u of dist(x, [intervalBegin[u], intervalEnd[u]]),
// stores the closed range of minimisers in intervalBegin[w]/intervalEnd[w]
// and adds min f to ans. A vertex without children gets [1, 1] at cost 0.
static void resolveVertex(int w) {
    if (value[w] != 0) {
        intervalBegin[w] = value[w];
        intervalEnd[w] = value[w];
        return;
    }

    // Each child contributes two breakpoints to the convex, piecewise-linear f:
    // max(0, begin - x) stops decreasing at its begin, and max(0, x - end)
    // starts increasing at its end. The slope of f is -k left of all
    // breakpoints and rises by exactly one at each of the 2k breakpoints, so f
    // is minimal on the closed range from the k-th to the (k+1)-th smallest
    // breakpoint.
    std::vector<int> breakpoints;
    for (int u : V[w]) {
        if (u != prz[w]) {
            breakpoints.push_back(intervalBegin[u]);
            breakpoints.push_back(intervalEnd[u]);
        }
    }

    if (breakpoints.empty()) {
        intervalBegin[w] = 1;
        intervalEnd[w] = 1;
        return;
    }

    const std::size_t k = breakpoints.size() / 2;
    const auto kth = breakpoints.begin() + static_cast<std::ptrdiff_t>(k - 1);
    std::nth_element(breakpoints.begin(), kth, breakpoints.end());
    const int lo = *kth;
    // Everything after kth is >= *kth, so the (k+1)-th smallest is their minimum.
    const int hi = *std::min_element(kth + 1, breakpoints.end());

    // Cost of placing w at lo (any point of [lo, hi] gives the same cost).
    long long cost = 0;
    for (int u : V[w]) {
        if (u == prz[w]) {
            continue;
        }
        if (lo < intervalBegin[u]) {
            cost += static_cast<long long>(intervalBegin[u]) - lo;
        } else if (lo > intervalEnd[u]) {
            cost += static_cast<long long>(lo) - intervalEnd[u];
        }
    }

    intervalBegin[w] = lo;
    intervalEnd[w] = hi;
    ans += cost;
}

// Roots the tree at w (recording parents in prz) and resolves every vertex
// reachable from w, children before parents. Fixed vertices are not expanded.
// An explicit stack replaces recursion: on a path-shaped tree the recursion
// depth would reach the number of vertices (arrays are sized for N), which can
// exceed a default-sized call stack.
void DFS(int w) {
    std::vector<int> order;          // vertices in discovery order
    std::vector<int> pending(1, w);  // explicit traversal stack
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        if (value[x] != 0) {
            continue;
        }
        for (int u : V[x]) {
            if (u != prz[x]) {
                prz[u] = x;
                pending.push_back(u);
            }
        }
    }
    // Every vertex is discovered after its parent, so walking the discovery
    // order backwards resolves all children before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        resolveVertex(*it);
    }
}

int main() {
    scanf("%d %d", &n, &m);

    // Read the n - 1 tree edges into the adjacency lists.
    for (int i = 1; i < n; i++) {
        scanf("%d %d", &a, &b);
        V[a].push_back(b);
        V[b].push_back(a);
    }

    // Vertices 1..m have fixed values.
    for (int i = 1; i <= m; i++) {
        scanf("%d", &value[i]);
    }

    if (n == m) {
        // Every vertex is fixed, so nothing is left to choose: output the sum
        // of |value[u] - value[v]| over all edges. For a two-vertex tree this
        // is |value[1] - value[2]|; for a single vertex it is 0. Each edge is
        // counted once, from its smaller endpoint.
        long long total = 0;
        for (int v = 1; v <= n; v++) {
            for (int u : V[v]) {
                if (u > v) {
                    total += std::llabs(static_cast<long long>(value[v]) - value[u]);
                }
            }
        }
        printf("%lld\n", total);
        return 0;
    }

    // Root the traversal at the first non-fixed vertex.
    DFS(m + 1);

    printf("%lld\n", ans);

    return 0;
}