1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <math.h>
#include <sys/time.h>

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <vector>

// Loop shorthands. FOR iterates i over [a, b) with i declared as the type of b
// (GCC/Clang __typeof extension); REP(i, a) is FOR(i, 0, a); FOREACH walks a
// container c with an iterator x from c.begin() to c.end().
#define FOR(i,a,b)  for(__typeof(b) i=(a);i<(b);++i)
#define REP(i,a)    FOR(i,0,a)
#define FOREACH(x,c)   for(__typeof(c.begin()) x=c.begin();x != c.end(); x++)
// Expands to the begin/end iterator pair expected by <algorithm> calls.
#define ALL(c)      c.begin(),c.end()
// Zeroes sizeof(c) bytes at c: intended for fixed-size raw arrays, where
// sizeof(c) covers the whole array.
#define CLEAR(c)    memset(c,0,sizeof(c))
// Container size converted to int.
#define SIZE(c) (int) ((c).size())

// Member/function name shorthands (e.g. v.PB(x), p.X, MP(a, b)).
#define PB          push_back
#define MP          make_pair
#define X           first
#define Y           second

// Type aliases spelled as macros.
#define ULL         unsigned long long
#define LL          long long
#define LD          long double
#define II         pair<int, int>
#define DD         pair<double, double>

// Container aliases built on VC (= vector).
#define VC	    vector
#define VI          VC<int>
#define VVI         VC<VI>
#define VD          VC<double>
#define VS          VC<string>
#define VII         VC<II>
#define VDD         VC<DD>

// Debug print: writes "<expression text>: <value>" to stderr followed by endl.
// The expansion already ends with ';'.
#define DB(a)       cerr << #a << ": " << a << endl;

// The macros above expand to unqualified standard names (vector, pair,
// make_pair, cerr, endl, ...), which this directive brings into scope.
using namespace std;

// Debug helper: writes v to stderr as "[a,b,c]" followed by a newline
// ("[]" for an empty vector). Elements are formatted with operator<<.
// Takes the vector by const reference so printing never copies its contents.
template<class T> void print(const std::vector<T> &v) {
    std::cerr << "[";
    for (std::size_t i = 0; i < v.size(); ++i) {
        if (i != 0) std::cerr << ",";
        std::cerr << v[i];
    }
    std::cerr << "]\n";
}
// Converts any value with an operator<< overload to its stream text.
// The parameter is a const reference so temporaries, literals and const
// objects are accepted as well as plain lvalues.
template<class T> std::string i2s(const T &x) {
    std::ostringstream o;
    o << x;
    return o.str();
}
// Splits s on every occurrence of c (a space by default) and returns the
// non-empty pieces in order. Leading, trailing and repeated separators
// produce no empty strings.
// Positions are kept as std::string::size_type and compared against npos,
// rather than squeezing find()'s result into an int and testing for -1.
std::vector<std::string> split(const std::string &s, char c = ' ') {
    std::vector<std::string> all;
    std::string::size_type p = 0;
    std::string::size_type np;
    while ((np = s.find(c, p)) != std::string::npos) {
        if (np != p) all.push_back(s.substr(p, np - p));
        p = np + 1;
    }
    if (p < s.size()) all.push_back(s.substr(p));
    return all;
}

// Number of tree nodes, and number of nodes (0..m-1) whose values are read
// from the input by readData().
int n;
int m;

// 0-based adjacency lists of the tree.
std::vector<std::vector<int> > G;
// Per-node neighbour counter: filled with degrees by readData() and
// decremented by solve() as neighbours get processed.
std::vector<int> deg;
// Per-node value interval (low, high); nodes 0..m-1 start as (value, value),
// all others as (0, 0) until solve() assigns them.
std::vector<std::pair<int, int> > w;

// Accumulated edge cost, printed by main().
long long total = 0;

void readData(){
    scanf("%d %d",&n,&m);
    G.resize(n);
    deg.resize(n,0);
    int u,v;
    REP(i,n-1){
        scanf("%d %d",&u,&v);
        G[u-1].PB(v-1);
        G[v-1].PB(u-1);
        deg[u-1]++; deg[v-1]++;
    }
    w.resize(n,MP(0,0));
    REP(i,m){
        scanf("%d",&(w[i].X));
        w[i].Y = w[i].X;
    }
}

void solve(){
    if (m == 2 && n == 2){
        total = abs(w[0].X-w[1].X);
        return;
    }
    VI stack;
    REP(i,m) stack.PB(i);
    VII events;
    REP(i,n){
        int v = stack.back(); stack.pop_back();
        LL lSum=0, rSum=0;
        LL lCnt=0, rCnt=0;
        events.clear();
        for(auto ngb : G[v]) 
            if (w[ngb].X == 0){
                deg[ngb]--;
                if (deg[ngb] == 1)
                    stack.PB(ngb);
            } else{
                events.PB(MP(w[ngb].X,-1)); 
                rSum += w[ngb].X;
                rCnt++;
                events.PB(MP(w[ngb].Y,1));
            }
        sort(ALL(events));
        LL bestPen = LLONG_MAX;
        if (SIZE(G[v]) != 1){
            int bestL=0,bestR=0;
            for(auto &e : events){
                if (e.Y == -1){
                    rSum -= e.X;
                    rCnt--;
                }else{
                    lSum += e.X;
                    lCnt++;
                }
                LL pen = (e.X*lCnt-lSum) + (rSum-e.X*rCnt);
                if (pen < bestPen){
                    bestPen = pen;
                    bestL = bestR = e.X;
                } else if (pen == bestPen){
                    bestR = e.X;
                }
            }
            total += bestPen;
            w[v] = MP(bestL,bestR);
        }
    }
}

int main(){
    readData();
    solve();
    printf("%lld\n",total);
    return 0;
}