1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <cstdio>
#include <iostream>
#include <set>
#include <map>
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>
#include <cmath>
#include <iomanip>
#include <deque>
#include <cassert>
using namespace std;
// Short type aliases (declared before `int` is redefined below, so PII is a
// pair of plain ints).
typedef long long LL;
typedef long double LD;
typedef pair<int,int> PII;
// Competitive-programming shorthand macros used by the rest of the file.
// MP(a,b) expands to make_pair(a,b).
#define MP make_pair
// FOR(v,p,k): v runs from p up to k, inclusive.
#define FOR(v,p,k) for(int v=p;v<=k;++v)
// FORD(v,p,k): v runs from p down to k, inclusive.
#define FORD(v,p,k) for(int v=p;v>=k;--v)
// REP(i,n): i runs from 0 to n-1.
#define REP(i,n) for(int i=0;i<(n);++i)
// VAR(v,i): declares v with the type of expression i (GCC __typeof extension).
#define VAR(v,i) __typeof(i) v=(i)
// FOREACH(i,c): iterator loop over the whole container c.
#define FOREACH(i,c) for(VAR(i,(c).begin());i!=(c).end();++i)
#define PB push_back
// ST / ND name the .first / .second members of a pair.
#define ST first
#define ND second
// NOTE(review): SIZE and ALL do not parenthesise their argument, so pass a
// plain name or member access, never a compound expression.
#define SIZE(x) (int)x.size()
#define ALL(c) c.begin(),c.end()
// Redundant repeat of the using-directive at the top of this section.
using namespace std;
// From here on every `int` token is textually replaced by `long long`. Macros
// are rescanned on expansion, so the `int` inside FOR/FORD/REP becomes
// `long long` too. main() must #undef it around its own signature because
// main has to return int.
#define int long long

// Capacity of the per-node tables below (node ids are 0-based).
const int MAXNODES = 1000005;

// n: number of nodes in the tree; m: number of fixed-value nodes (ids 0..m-1).
int n = 0, m = 0;
// sas[v]: neighbours of node v (main() stores every edge in both directions).
vector<int> sas[MAXNODES];
// odw[v]: visited flags maintained by dfs().
bool odw[MAXNODES] = {};
// wart[v]: the given value of fixed node v, read in main().
int wart[MAXNODES] = {};
// suma: running total of edge costs accumulated by dfs() and printed by main().
int suma(0);
// Settles one node once the optimal-value intervals of its constraining
// children are known. A child interval [a, b] is either a fixed leaf value
// (a == b == wart[leaf]) or the set of optimal values of a processed subtree.
//
// For a free node (v >= m), the cost sum_i dist(x, [a_i, b_i]) is convex and
// piecewise linear in x. After crossing j of the 2k sorted endpoints its slope
// is j - k, so it is minimised exactly on [e[k-1], e[k]] (0-indexed). That
// minimal cost is added to `suma` and the interval is returned.
//
// A fixed node (v < m, which only reaches here as the traversal root) is
// pinned to wart[v]. Its distances to the children's intervals are charged
// and its interval collapses to that single value.
//
// Precondition: kids is non-empty when v >= m (the caller guarantees this).
static pair<int,int> settle(int v, vector<pair<int,int> >& kids){
    int lo, hi;
    if(v < m){
        lo = hi = wart[v];
    } else{
        vector<int> ends;
        ends.reserve(2 * kids.size());
        REP(i, kids.size()){
            ends.PB(kids[i].ST);
            ends.PB(kids[i].ND);
        }
        sort(ALL(ends));
        // k = kids.size(); the optimum lies between the k-th and (k+1)-th
        // smallest endpoints (1-indexed).
        lo = ends[kids.size() - 1];
        hi = ends[kids.size()];
    }
    // Any value in [lo, hi] gives the same cost; charge it using lo.
    REP(i, kids.size()){
        if(lo < kids[i].ST){
            suma += kids[i].ST - lo;
        } else if(lo > kids[i].ND){
            suma += lo - kids[i].ND;
        }
    }
    return MP(lo, hi);
}

// Computes the minimal sum of |value differences| over all edges of the tree
// hanging from `w`. Nodes 0..m-1 have the fixed values wart[]; every other
// node may take any value. The minimum is accumulated into the global `suma`.
// Returns the interval of optimal values for `w` itself, or (0,0) if nothing
// constrains it.
//
// The traversal is iterative (BFS order, then processed in reverse so that
// children are settled before their parents). A recursive DFS on a
// path-shaped tree of ~10^6 nodes would overflow the call stack. As in the
// original recursion, non-fixed nodes are marked in odw[] and fixed children
// are never marked.
pair<int,int> dfs(int w){
    vector<int> node;                      // nodes to settle, parents before children
    vector<int> up;                        // up[p]: index in `node` of p's parent, -1 for the root
    vector<vector<pair<int,int> > > kids;  // kids[p]: intervals of p's constraining children
    node.PB(w);
    up.PB(-1);
    kids.PB(vector<pair<int,int> >());
    odw[w] = true;

    // Phase 1: discover the tree. Fixed children contribute their value directly.
    for(int p = 0; p < static_cast<int>(node.size()); ++p){
        int v = node[p];
        REP(i, sas[v].size()){
            int u = sas[v][i];
            if(odw[u]) continue;
            if(u < m){
                kids[p].PB(MP(wart[u], wart[u]));
            } else{
                odw[u] = true;
                node.PB(u);
                up.PB(p);
                kids.PB(vector<pair<int,int> >());
            }
        }
    }

    // Phase 2: settle bottom-up and hand each interval to the parent.
    pair<int,int> rootRange = MP(0, 0);
    for(int p = static_cast<int>(node.size()) - 1; p >= 0; --p){
        int v = node[p];
        if(v >= m && kids[p].empty()){
            // No fixed node below v: it can copy its parent's value at zero
            // cost, so it places no constraint on the parent.
            continue;
        }
        pair<int,int> r = settle(v, kids[p]);
        vector<pair<int,int> >().swap(kids[p]);  // free the child list early
        if(up[p] >= 0){
            kids[up[p]].PB(r);
        } else{
            rootRange = r;
        }
    }
    return rootRange;
}

#undef int
int main() {
#define int long long
    cin>>n>>m;
    REP(i, n - 1){
        int t1,t2;
        cin>>t1>>t2;
        t1--;
        t2--;
        if(t1 > t2){
            swap(t1,t2);
        }
        sas[t1].PB(t2);
        sas[t2].PB(t1);
    }
    REP(i, m){
        cin>>wart[i];
    }
    dfs(n - 1);
    cout<<suma<<endl;
    return 0;
}