1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <bits/stdc++.h>

using namespace std;

// Shorthand for vector::push_back.
#define PB push_back
// Iterates iterator `i` over every element of container `t` (pre-C++11 style).
#define FORE(i, t) for(__typeof(t.begin())i=t.begin();i!=t.end();++i)
// Container size converted to a signed int.
#define SZ(x) int((x).size())
// Loops i = 0 .. n-1; the bound is evaluated once and cached in `_`.
#define REP(i, n) for(int i=0,_=(n);i<_;++i)
// Loops i = a .. b inclusive, ascending; the bound is evaluated once.
#define FOR(i, a, b) for(int i=(a),_=(b);i<=_;++i)
// Loops i = a down to b inclusive, descending; the bound is evaluated once.
#define FORD(i, a, b) for(int i=(a),_=(b);i>=_;--i)

typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pii;

// Large sentinel value that still fits in an int.
const int INF = 1e9 + 9;
// Capacity of the per-vertex arrays declared below (vertices are indexed from 1).
const int MAX_N = 500003;

//struct Range {
////    ll cost;
//    int a, b;
//};

// Returns the absolute value of (a - b).
ll diff(ll a, ll b) {
    const ll delta = a - b;
    return delta < 0 ? -delta : delta;
}

// n and m are the first two integers read by one(); vertices 1..m get a
// value read from input, vertices m+1..n do not.
int n, m;

// r[i]: value read for vertex i when i <= m, otherwise -1 (set in one()).
int r[MAX_N];
// children[x]: number of not-yet-visited neighbours f() recursed into from x.
int children[MAX_N];
// adj[x]: neighbours of x, filled symmetrically by add_edge().
vi adj[MAX_N];
// vis[x]: visited flag shared by the traversals in f() and solve();
// one() resets it before each traversal.
bool vis[MAX_N];
// ranges[x]: (first, second) pair stored by f() and consulted by solve().
pii ranges[MAX_N];


// Records an undirected edge by appending each endpoint to the other's
// adjacency list.
void add_edge(int a, int b) {
    adj[b].push_back(a);
    adj[a].push_back(b);
}

// Depth-first pass that fills ranges[] for x and everything reachable from x
// through unvisited vertices, and returns ranges[x].
//
//  - For x <= m the pair is the single point (r[x], r[x]).
//  - Otherwise both endpoints of every child's pair are collected and sorted,
//    and the two middle elements (the median interval of those endpoints)
//    become ranges[x].
//
// Side effects: marks visited vertices in vis[], increments children[x] once
// per child, writes ranges[] for every vertex visited.
pii f(int x) {
    vis[x] = true;
    if (x <= m) {
        return ranges[x] = pii(r[x], r[x]);
    }
    vi points;
    FORE(yt, adj[x]) {
        int y = *yt;
        if (vis[y]) {
            continue;
        }
        ++children[x];
        pii range = f(y);
        points.PB(range.first);
        points.PB(range.second);
    }
    if (points.empty()) {
        // A vertex above m with no children would otherwise index points[-1].
        // Report the widest interval instead: -INF and INF sort to opposite
        // ends of the parent's list, so the parent's middle pair is the same
        // as if this vertex were absent.
        // NOTE(review): assumes every r[] value lies within [-INF, INF] - confirm.
        return ranges[x] = pii(-INF, INF);
    }
    sort(points.begin(), points.end());
    // points always holds an even number of entries (two per child).
    const int mid = SZ(points) / 2;
    return ranges[x] = pii(points[mid - 1], points[mid]);
}

// Walks the tree from x, carrying the value `val` chosen for the vertex we
// came from. `val` is first clamped into [ranges[x].first, ranges[x].second];
// the distance it had to move is the cost charged at x. The clamped value is
// then passed on to every unvisited neighbour, and the sum of all charges is
// returned. Marks every vertex it reaches in vis[].
ll solve(int x, int val) {
    vis[x] = true;

    const pii &bounds = ranges[x];
    ll total = 0;
    if (val > bounds.second) {
        total += diff(val, bounds.second);
        val = bounds.second;
    }
    if (val < bounds.first) {
        total += diff(val, bounds.first);
        val = bounds.first;
    }

    const vi &neighbours = adj[x];
    REP (k, SZ(neighbours)) {
        const int next = neighbours[k];
        if (!vis[next]) {
            total += solve(next, val);
        }
    }
    return total;
}

// Reads one test case from stdin and prints its answer.
//
// Input layout as consumed here: n and m, then n-1 edges (pairs of vertex
// indices), then m integers that become r[1..m]. Vertices m+1..n get r = -1.
// The answer is printed as a single integer followed by a newline.
void inline one() {
    scanf("%d%d", &n, &m);
    REP (i, n - 1) {
        int a, b;
        scanf("%d%d", &a, &b);
        add_edge(a, b);
    }
    FOR (i, 1, n) {
        if (i <= m) {
            int x;
            scanf("%d", &x);
            r[i] = x;
        } else {
            r[i] = -1;
        }
        vis[i] = false;
        children[i] = 0;
    }
    if (n == m) {
        // Every vertex has a fixed value, so there is nothing to choose.
        // NOTE(review): presumably n == m implies at most two vertices, since
        // f() treats vertices 1..m as leaves - confirm against the task.
        // With a single vertex there is no edge and r[2] was never read, so
        // the cost is 0 rather than |r[1] - r[2]|.
        ll result = (n >= 2) ? diff(r[1], r[2]) : 0;
        printf("%lld\n", result);
        return;
    }
    // Vertex n is above m here, so it has no fixed value and serves as root.
    int root = n;
    pii range = f(root);
    // f() left every reached vertex marked; clear the flags for the second pass.
    FOR (i, 1, n) {
        vis[i] = false;
    }
    // Start the cost pass from the lower end of the root's interval.
    ll result = solve(root, range.first);
    printf("%lld\n", result);
}

// Program entry point: processes exactly one test case.
int main() {
    one();
    return 0;
}