1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//Przemysław Jakub Kozłowski
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#define FI first
#define SE second
#define MP make_pair
using namespace std;
typedef long long LL;
const int N = 500005;

int n,m;
vector<int> V[N];
int ojc[N];
pair<int, int> przedzial[N];
int wart[N];
vector<int> tmptab;
LL wyn;

void DFSA(int x)
{
    if(V[x].size() == 1) return;

    for(int i = 0;i < V[x].size();i++)
        if(V[x][i] != ojc[x])
        {
            ojc[V[x][i]] = x;
            DFSA(V[x][i]);
        }

    tmptab.clear();
    for(int i = 0;i < V[x].size();i++)
        if(V[x][i] != ojc[x])
        {
            tmptab.push_back(przedzial[V[x][i]].FI);
            tmptab.push_back(przedzial[V[x][i]].SE);
        }

    nth_element(tmptab.begin(), tmptab.begin()+tmptab.size()/2-1, tmptab.end());
    int l = tmptab[tmptab.size()/2-1];
    int p = tmptab[tmptab.size()/2];
    for(int i = tmptab.size()/2;i < tmptab.size();i++)
        p = min(p, tmptab[i]);
    przedzial[x] = MP(l, p);
}

void DFSB(int x)
{
    for(int i = 0;i < V[x].size();i++)
        if(V[x][i] != ojc[x])
        {
            if(przedzial[V[x][i]].SE < wart[x]) wart[V[x][i]] = przedzial[V[x][i]].SE;
            else if(wart[x] < przedzial[V[x][i]].FI) wart[V[x][i]] = przedzial[V[x][i]].FI;
            else wart[V[x][i]] = wart[x];
            wyn += abs(wart[x]-wart[V[x][i]]);
            DFSB(V[x][i]);
        }
}

int main()
{
    scanf("%d%d", &n, &m);
    if(n == m)
    {
        scanf("%*d%*d");
        int a,b;
        scanf("%d%d", &a, &b);
        printf("%d\n", abs(b-a));
        return 0;
    }

    for(int i = 1;i <= n-1;i++)
    {
        int a,b;
        scanf("%d%d", &a, &b);
        V[a].push_back(b);
        V[b].push_back(a);
    }

    for(int i = 1;i <= m;i++)
    {
        int a;
        scanf("%d", &a);
        przedzial[i] = MP(a,a);
    }

    ojc[n] = -1;
    DFSA(n);
    wart[n] = przedzial[n].FI;
    DFSB(n);

    printf("%lld\n", wyn);

    return 0;
}