1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include<bits/stdc++.h>
using namespace std;
// Adjacency lists of the tree (vertices are numbered from 1).
vector<int> kraw[500001];
// Fixed value of each leaf; main reads wagi[1..m].
long long wagi[500001];
// Visited flags used by dfs (1 = already reached).
int odw[500001];
// Parent of each vertex in the traversal started by dfs (0 for the start vertex).
int ojc[500001];
// pr[v][0] .. pr[v][1]: the interval of values for vertex v at which the cost
// accumulated for v's subtree is minimal. A leaf (degree-1 vertex) has the
// single-point interval [wagi[v], wagi[v]].
long long pr[500001][2];
// Running total of the minimal cost found at every processed internal vertex.
long long wyn=0;

// Computes pr[x] for a non-leaf vertex x and adds its minimal cost to wyn.
// Requires pr[c] to be known already for every neighbour c of x except `parent`.
//
// When x takes the value t, child c costs dist(t, [pr[c][0], pr[c][1]]).
// The sum over children is convex and piecewise linear, with breakpoints at
// the interval ends. We sweep the sorted breakpoints while tracking the slope
// (za - no). po is the leftmost breakpoint where the minimum is reached and
// ko is the rightmost one.
// `tab` is scratch storage, passed in so one buffer is reused for all vertices.
static void merge_children(int x, int parent, vector<pair<long long,int> >& tab)
{
    tab.clear();
    long long no=0;     // children whose interval lies to the right of t (slope -1 each)
    long long za=0;     // children whose interval lies to the left of t (slope +1 each)
    long long lo_sum=0; // sum of the children's left interval ends
    for(int c : kraw[x])
    {
        if(c==parent)
            continue;
        lo_sum+=pr[c][0];
        no++;
        tab.push_back(make_pair(pr[c][0], 1));  // left end: this child stops pulling right
        tab.push_back(make_pair(pr[c][1], -1)); // right end: this child starts pulling left
    }
    if(tab.empty())
    {
        // No children (only possible for a single-vertex tree): zero cost.
        pr[x][0]=0;
        pr[x][1]=0;
        return;
    }
    sort(tab.begin(), tab.end());

    // Start at the smallest breakpoint t0. Every child interval lies at or to
    // the right of it, so the cost there is sum(pr[c][0] - t0). Starting here
    // instead of at 0 keeps the sweep correct for negative values as well.
    long long t=tab[0].first;
    long long sum=lo_sum-no*t;
    long long naj=sum, po=t, ko=t;
    for(size_t i=0; i<tab.size(); i++)
    {
        long long y=tab[i].first-t;
        t=tab[i].first;
        sum+=(za-no)*y;
        if(sum<naj)
        {
            // Strictly better: a new (leftmost) start of the optimal interval.
            naj=sum;
            po=t;
        }
        if(sum<=naj)
            ko=t; // still optimal: extend the interval to the right
        if(tab[i].second==1)
            no--;
        else
            za++;
    }
    wyn+=naj;
    pr[x][0]=po;
    pr[x][1]=ko;
}

// Processes the tree containing x, treated as the root.
// Fills pr[] for every reachable vertex and adds every internal vertex's
// minimal cost to wyn. A degree-1 vertex counts as a leaf with fixed value
// wagi[v]; this also applies to x itself.
// The traversal is iterative, so deep (e.g. path-shaped) trees with hundreds
// of thousands of vertices cannot overflow the call stack.
void dfs(int x)
{
    vector<int> order; // preorder: every vertex appears after its parent
    vector<int> stos;
    odw[x]=1;
    ojc[x]=0;
    stos.push_back(x);
    while(!stos.empty())
    {
        int v=stos.back();
        stos.pop_back();
        order.push_back(v);
        if(kraw[v].size()==1)
            continue; // leaf: fixed value, nothing below it to visit
        for(int u : kraw[v])
            if(odw[u]==0)
            {
                odw[u]=1;
                ojc[u]=v;
                stos.push_back(u);
            }
    }

    // Walking the preorder backwards finishes all children before their parent.
    vector<pair<long long,int> > tab;
    for(size_t i=order.size(); i-- > 0; )
    {
        int v=order[i];
        if(kraw[v].size()==1)
        {
            pr[v][0]=wagi[v];
            pr[v][1]=wagi[v];
        }
        else
            merge_children(v, ojc[v], tab);
    }
}
// Reads a tree with `vertices` vertices (vertices-1 undirected edges),
// followed by the fixed values of vertices 1..leaves, and prints the minimal
// total cost computed by dfs rooted at the last vertex.
int main()
{
    int vertices, leaves;
    scanf("%d%d", &vertices, &leaves);

    for(int e=1; e<vertices; e++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        kraw[a].push_back(b);
        kraw[b].push_back(a);
    }

    for(int v=1; v<=leaves; v++)
        scanf("%lld", &wagi[v]);

    // Two vertices joined by one edge: both have degree 1, so dfs would treat
    // the root as a leaf. The answer is simply the gap between the two values.
    if(vertices==2)
    {
        long long gap = (wagi[1]<wagi[2]) ? wagi[2]-wagi[1] : wagi[1]-wagi[2];
        printf("%lld\n", gap);
        return 0;
    }

    dfs(vertices);
    printf("%lld\n", wyn);
}