1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

// A plain pair of ints (p, v) with a lexicographic ordering.
// Both constructors fully initialise the members (the original default
// constructor left p and v indeterminate).
// NOTE(review): this type is not referenced anywhere else in the visible file.
class sect_end{
    public:
    sect_end() : p(0), v(0) {}
    sect_end( int _p, int _v ) : p(_p), v(_v) {}
    int p;
    int v;
};

// Strict weak ordering: compare by p first, break ties by v.
bool operator<( const sect_end &s1, const sect_end &s2 ) {
    if ( s1.p != s2.p ) return s1.p < s2.p;
    return s1.v < s2.v;
}

// Adjacency lists of the tree (0-based vertex ids, filled by main).
vector<int> KR[500005];
// T[v] for v < m: the fixed value of vertex v (read in main).
int T[500005];
// ODW[v] != 0 once internal vertex v has been visited by DFS.
int ODW[500005];
// Accumulated cost: DFS adds the minimum cost of every internal vertex it processes.
long long w;
// n = number of vertices; vertices 0..m-1 carry fixed values and are treated as leaves.
int n,m;

// Merge step for one internal vertex.
// `ends` holds the endpoints of the children's optimal intervals (two entries
// per child, any order; it is sorted in place) and `s` is the sum of the
// children's left endpoints. Computes the minimum of
//   f(x) = sum over children of dist(x, [a_i, b_i])
// and writes the interval of minimisers [r.first, r.second] into r.
// Returns that minimum, or 0 (leaving r untouched) when there are no children.
static long long merge_children( vector<int> &ends, long long s, pair<int,int> &r ) {
    if ( ends.empty() ) return 0;
    sort( ends.begin(), ends.end() );
    // f is piecewise linear: slope -k to the left of every endpoint (k = number
    // of children), and each endpoint raises the slope by one. For x <= min a_i,
    // f(x) = sum a_i - k*x, so starting the sweep at x = 0 with s = sum a_i
    // yields the exact value f(p) at the first endpoint p, whatever its sign.
    long long h = -(long long)( ends.size() / 2 );
    long long last_p = 0;
    long long best = (long long)1e18;
    for ( int p : ends ) {
        // 64-bit gap: p - last_p can exceed INT_MAX when values straddle zero.
        s += h * ( (long long)p - last_p );
        h++;
        if ( s < best ) {
            best = s;
            r.first = r.second = p;
        } else if ( s == best ) {
            r.second = p;   // f is convex, so equal values extend the flat minimum
        }
        last_p = p;
    }
    return best;
}

// Solves the part of the tree reachable from v through vertices not yet
// marked in ODW.
// For v < m, returns [T[v], T[v]] and touches nothing.
// Otherwise it:
//  - marks every internal vertex it reaches in ODW;
//  - adds each such vertex's minimum merge cost to w;
//  - returns v's interval of optimal values.
// Uses an explicit stack rather than recursion, so deep trees (e.g. a path of
// ~500000 vertices) cannot overflow the call stack.
pair<int,int> DFS( int v ) {
    if ( v<m ) return make_pair( T[v], T[v] );

    // Per discovered internal vertex, indexed by discovery position:
    vector<int> parent_pos;          // position of the parent, -1 for v itself
    vector< vector<int> > ends;      // endpoints of the children's intervals
    vector<long long> left_sum;      // sum of the children's left endpoints

    // Phase 1: discover the subtree. A parent always gets a smaller position
    // than its children. Leaf children (x < m) are resolved immediately.
    vector< pair<int,int> > pending; // (vertex, parent position)
    ODW[v] = 1;
    pending.push_back( make_pair( v, -1 ) );
    while ( !pending.empty() ) {
        int u = pending.back().first;
        int up = pending.back().second;
        pending.pop_back();
        int pos = (int)parent_pos.size();
        parent_pos.push_back( up );
        ends.push_back( vector<int>() );
        left_sum.push_back( 0 );
        for ( int x : KR[u] ) {
            if ( ODW[x] ) continue;              // parent / already visited
            if ( x < m ) {
                ends[pos].push_back( T[x] );
                ends[pos].push_back( T[x] );
                left_sum[pos] += T[x];
            } else {
                ODW[x] = 1;
                pending.push_back( make_pair( x, pos ) );
            }
        }
    }

    // Phase 2: process positions in reverse so children precede parents;
    // each vertex's optimal interval is handed up to its parent.
    pair<int,int> result( 0, 0 );
    for ( int i = (int)parent_pos.size() - 1; i >= 0; i-- ) {
        pair<int,int> r( 0, 0 );
        w += merge_children( ends[i], left_sum[i], r );
        vector<int>().swap( ends[i] );           // release memory early
        int par = parent_pos[i];
        if ( par < 0 ) {
            result = r;
        } else {
            ends[par].push_back( r.first );
            ends[par].push_back( r.second );
            left_sum[par] += r.first;
        }
    }
    return result;
}

// Reads the input and prints the minimum total cost computed by DFS.
// Input as read here: "n m", then n-1 edges given as 1-based vertex pairs,
// then m integers stored in T[0..m-1] (vertices 0..m-1 are treated as leaves).
// Exits with status 1 if the input is truncated or malformed.
int main() {
    if ( scanf("%d%d",&n,&m) != 2 ) return 1;
    for ( int i=0; i<n; i++ ) ODW[i]=0;
    for ( int i=0; i<n-1; i++ ) {
        int a,b;
        if ( scanf("%d%d",&a,&b) != 2 ) return 1;
        a--; b--;
        KR[a].push_back(b);
        KR[b].push_back(a);
    }
    for ( int i=0; i<m; i++ ) if ( scanf( "%d", T+i ) != 1 ) return 1;
    if ( n == m ) {
        // Every vertex already has a fixed value, so nothing is chosen: the
        // cost is the sum of |T[a] - T[b]| over all edges. This covers n == 2
        // (|T[0]-T[1]|) and n == 1 (0; the old code printed T[0] here).
        // 64-bit arithmetic avoids int overflow and the need for <cstdlib> abs.
        long long total = 0;
        for ( int a=0; a<n; a++ ) {
            for ( int b : KR[a] ) {
                if ( a < b ) {   // count each undirected edge once
                    long long d = (long long)T[a] - T[b];
                    total += d < 0 ? -d : d;
                }
            }
        }
        printf("%lld\n",total);
        return 0;
    }
    w=0;
    DFS( m );
    printf("%lld\n",w);
    return 0;
}