1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
#include <bits/stdc++.h>

using namespace std;

typedef long long LL;

// Terminal case of the debug(...) printer: `data` holds the remaining
// stringized argument text (e.g. " b" for the last of "a, b"). It is written
// verbatim, followed by "=", the value and a newline.
template<typename TH>
void debug_vars(const char* data, TH head){
    cerr << data;
    cerr << "=" << head << "\n";
}

// Recursive step: echo the characters of `data` up to the first ',' as the
// label of `head`, print "=<head>,", then resume just past that comma with the
// remaining values.
// NOTE: splitting stops at the first comma, so an argument expression that
// itself contains a comma (e.g. f(x, y)) gets mislabeled.
template<typename TH, typename... TA>
void debug_vars(const char* data, TH head, TA... tail){
    const char* cursor = data;
    for(; *cursor != ','; ++cursor) cerr << *cursor;
    cerr << "=" << head << ",";
    debug_vars(cursor + 1, tail...);
}

// debug(expr, ...): in LOCAL builds, prints each argument's source text and
// value to stderr via debug_vars. In every other build it expands to the
// parenthesized argument list: the expressions are still evaluated, but
// nothing is printed.
#if defined(LOCAL)
#define debug(...) debug_vars(#__VA_ARGS__, __VA_ARGS__)
#else
#define debug(...) (__VA_ARGS__)
#endif

/////////////////////////////////////////////////////////


// Capacity of every per-vertex array; vertex ids are used as direct indices.
constexpr int MaxN = 500005;

// Undirected adjacency lists, filled by input().
vector<int> adj[MaxN];
// value[1..M]: the fixed values read by input().
int value[MaxN];
// Set for every vertex dfs() has reached.
bool visited[MaxN];
// N: number of vertices (input() reads N-1 edges). M: number of fixed values.
int N, M;

// Per-vertex summary produced by dfs(): modifiers[v] ends as a two-element
// pair and increase[v] holds the accumulated cost term for v's subtree.
vector<int> modifiers[MaxN];
LL increase[MaxN];

// Reads the instance from stdin into the globals:
//   line 1:   N M
//   N-1 rows: u v          (undirected edge, stored in both adjacency lists)
//   M values: value[1..M]  (1-indexed)
// scanf return values are not checked; well-formed input is assumed.
void input(){
    scanf("%d%d", &N, &M);

    for(int edge = 1; edge < N; ++edge){
        int a, b;
        scanf("%d%d", &a, &b);
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    for(int idx = 1; idx <= M; ++idx) scanf("%d", value + idx);
}

void dfs(int v){
    visited[v] = true;
    increase[v] = 0;
    modifiers[v].clear();

    if(v <= M){
        modifiers[v] = {value[v], value[v]};
        return;
    }

    for(int s : adj[v]){
        if(visited[s]) continue;
        dfs(s);
        increase[v] += increase[s];
        assert(modifiers[s].size() == 2);
        for(int x : modifiers[s]) modifiers[v].push_back(x);
    }

    sort(modifiers[v].begin(), modifiers[v].end());
    int pos1 = (int)modifiers[v].size() / 2 - 1,
        pos2 = pos1+1,
        mid1 = modifiers[v][pos1],
        mid2 = modifiers[v][pos2];

    LL prevSumDist = 0;
    for(int x : modifiers[v]) prevSumDist += abs(x-mid1);
    LL nowSumDist = abs(mid1-mid2);

    increase[v] += prevSumDist-nowSumDist;
    modifiers[v].clear();
    modifiers[v].shrink_to_fit();
    modifiers[v] = {mid1, mid2};

//    debug(v);
//    for(int x : modifiers[v]) debug(x);
//    debug(increase[v]);
}

int main(){
    input();
    if(N == 2){
        printf("%d\n", abs(value[1]-value[2]));
        return 0;
    }
    assert(N > M);
    dfs(N);

    assert(modifiers[N].size() == 2);
    printf("%lld\n", (abs(modifiers[N][0]-modifiers[N][1]) + increase[N])/2);
}