1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>
using namespace std;
// REP(i, n): for-loop with a fresh int i running over 0 .. n-1.
#define REP(i, n) for(int i = 0; i < (n); i++)
// SIZE(s): s.size() converted to int.
#define SIZE(s) ((int) (s).size())
// SCAND / SCANS: read one "%d" / "%s" item with scanf and assert that exactly
// one item was converted.
// NOTE(review): the scanf call is the assert's argument, so a build with
// NDEBUG defined removes the read itself, not only the check.
#define SCAND(x) assert(scanf("%d", x) == 1)
#define SCANS(x) assert(scanf("%s", x) == 1)
// ALL(s): expands to the two arguments s.begin(), s.end().
#define ALL(s) s.begin(), s.end()
// MP / ST / ND: short names for make_pair and the pair members first / second.
#define MP make_pair
#define ST first
#define ND second

using LL = long long;
using PII = pair<int, int>;

// Capacity of the per-node arrays below.
const int N = 5e5;
// The two integers read first by main(). Nodes are indexed 0 .. n-1 after the
// input's 1-based ids are decremented; indices below m are the nodes whose
// r[] value is read from input and which dfs1()/dfs2() do not descend into.
int n, m;
// Adjacency lists; main() stores every input edge in both directions.
vector<int> g[N];
// Per-node value: r[0 .. m-1] comes from input, the remaining entries are
// assigned by dfs2().
int r[N];

// Per-node (low, high) interval, written by dfs1() and read by dfs2().
PII ranges[N];

void dfs1(int u, int p = -1){
  if(u < m){
    ranges[u] = MP(r[u], r[u]);
    return;
  }
  vector<PII> events;
  for(int v : g[u]) if(v != p){
    dfs1(v, u);
    events.emplace_back(ranges[v].ST, -1);
    events.emplace_back(ranges[v].ND, +1);
  }
  int k = events.size() / 2;
  nth_element(events.begin(), events.begin() + k - 1, events.end());
  nth_element(events.begin() + k, events.begin() + k, events.end());
  ranges[u] = MP(events[k-1].ST, events[k].ST);
}

void dfs2(int u, int p = -1){
  if(u < m) return;
  if(p == -1){
    r[u] = ranges[u].ST;
  } else {
    r[u] = r[p];
    if(r[u] < ranges[u].ST) r[u] = ranges[u].ST;
    if(r[u] > ranges[u].ND) r[u] = ranges[u].ND;
  }
  for(int v : g[u]) if(v != p){
    dfs2(v, u);
  }
}

int main(){
  assert(scanf("%d%d", &n, &m) == 2);
  REP(i, n-1){
    int u, v;
    assert(scanf("%d%d", &u, &v) == 2);
    u--; v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  REP(i, m) assert(scanf("%d", &r[i]) == 1);
  if(m == n){
    printf("%d\n", abs(r[0] - r[1]));
    return 0;
  }
  dfs1(m);
  dfs2(m);
  LL res = 0;
  REP(u, n){
    for(int v : g[u]) if(u < v) res += abs(r[u] - r[v]);
  }
  printf("%lld\n", res);
}