1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
// Pawel Kura
//
// D guards debug-only diagnostics (see the `if (D) fprintf(stderr, ...)`
// in doit). It is true only when the file is compiled with DEBUG defined,
// e.g. `-DDEBUG`; otherwise the guarded code is dead.
#if defined(DEBUG)
constexpr bool D = true;
#else
constexpr bool D = false;
#endif

#include <cstdio>
#include <vector>
#include <cstdlib>
#include <algorithm>
using namespace std;
using LL = long long;  // wide integer type used for accumulated costs

// n: number of vertices read by go(); m: number of vertices with fixed
// input values (0-based indices 0..m-1 after go() converts the input).
int n = 0;
int m = 0;

// Sweep-line helper used by dfs.
//
// For the given closed intervals [lo_i, hi_i] it minimises
//     f(x) = sum_i dist(x, [lo_i, hi_i])
// where dist is 0 inside an interval and the gap to the nearer end outside.
// Returns {{a, b}, f(a)}: every x in [a, b] attains the minimum reached by
// the sweep (a == b when the sweep stops on a single point).
// An empty input yields {{0, 0}, 0}. O(k log k) for k intervals (the sort).
pair<pair<int,int>, LL> doit(const vector<pair<int,int>> & intervals) {
    // Endpoint events: (lo, +1) marks an interval start, (hi, -1) an end.
    vector<pair<int, int>> sweep;
    sweep.reserve(2 * intervals.size());

    // Left of every interval f is linear with slope -k; extrapolating that
    // piece to x = 0 gives f(0) = sum of the lower ends (valid for any sign).
    LL cost = 0;
    for (const auto &iv : intervals) {
        sweep.push_back({iv.first, 1});
        sweep.push_back({iv.second, -1});
        cost += iv.first;
    }
    sort(sweep.begin(), sweep.end());

    int x = 0;                                          // current sweep position
    int ended = 0;                                      // intervals wholly left of (just right of) x
    int pending = static_cast<int>(intervals.size());   // intervals wholly right of x
    if (D)
        fprintf(stderr, "%d %lld\n", x, cost);

    size_t k = 0;
    while (k < sweep.size()) {
        const int pos = sweep[k].first;
        // The slope of f on the open segment right of x is (ended - pending).
        if (ended > pending)
            return {{x, x}, cost};      // f only grows from here: x is optimal
        if (ended == pending)
            return {{x, pos}, cost};    // f is flat on [x, pos]
        cost += 1LL * (pos - x) * (ended - pending);
        x = pos;
        // Consume every event at this position so the counters describe the
        // segment immediately to the right of pos.
        for (; k < sweep.size() && sweep[k].first == pos; ++k) {
            if (sweep[k].second == 1)
                --pending;
            else
                ++ended;
        }
    }
    return {{x, x}, cost};
}


// Capacity of the per-vertex arrays below; every vertex index must be < MAXN.
constexpr int MAXN = 500000;
vector<int> G[MAXN];         // undirected adjacency lists
bool vis[MAXN];              // set by dfs for every vertex it reaches
pair<int, int> ints[MAXN];   // per-vertex value interval: {v, v} for input
                             // vertices, doit's result for the others
LL RES;                      // total cost accumulated by dfs, printed by go

void dfs(int x, int p = -1) {
    vis[x] = true;
    for (auto y: G[x]) {
        if (y != p) {
            dfs(y, x);
        }
    }

    if (x < m) return;

    vector<pair<int,int>> intervals;

    for (auto y: G[x]) {
        if (y != p) intervals.push_back(ints[y]);
    }
    auto r = doit(intervals);
    ints[x] = r.first;
    RES += r.second;
}

void go() {
    scanf("%d %d", &n, &m);
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        scanf("%d %d", &x, &y);
        x--; y--;
        G[x].push_back(y);
        G[y].push_back(x);
    }
    for (int i = 0; i < m; i++) {
        int x;
        scanf("%d", &x);
        ints[i] = {x, x};
    }

    if (n == 2 && m == 2) {
        printf("%d\n", abs(ints[0].first - ints[1].first));
        return;
    }
    dfs(m);

    printf("%lld\n", RES);
}

// Entry point: all work (reading stdin, solving, printing) happens in go().
int main(int argc, char **argv) {
    // Command-line arguments are not used.
    (void)argc;
    (void)argv;
    go();
    // Falling off the end of main returns 0.
}