1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>
#include <set>
using namespace std;

// Capacity of every per-vertex array below (maximum number of vertices).
constexpr int MAX = 500000;

// Shorthand for vector::push_back used by the functions in this file.
#define PB push_back

int n, k;               // vertex count; values of vertices 0..k-1 are read from input
vector<int> G[MAX];     // adjacency lists, 0-based vertex ids
int W[MAX];             // value of each vertex (given for 0..k-1, computed for the rest)
pair<int, int> P[MAX];  // interval produced by dfs1 for each non-root vertex
vector<int> tmp;        // scratch buffer reused by dfs1

// Bottom-up pass over the subtree of u, entered from parent p (-1 at the root).
//
// A vertex of degree 1 is a leaf and contributes the degenerate interval
// {W[v], W[v]}. Every other vertex v gets the two middle elements of the sorted
// list of its children's interval endpoints. Any value in that range minimises
// the total distance from v to its children's intervals (the median of a sum of
// |x - a_i| terms).
//
// Side effects: P[v] is written for every vertex of the subtree except u itself,
// and tmp is left holding u's sorted endpoint list.
// Returns u's own interval.
//
// Implemented with an explicit stack instead of recursion, so a path-like tree
// with hundreds of thousands of vertices cannot overflow the call stack.
pair<int, int> dfs1(int u, int p) {
  if (G[u].size() == 1) {
    return {W[u], W[u]};
  }

  // Pre-order of the subtree as (vertex, parent) pairs: every vertex appears
  // before all of its descendants.
  vector<pair<int, int> > order;
  vector<pair<int, int> > pending(1, make_pair(u, p));
  while (!pending.empty()) {
    const pair<int, int> cur = pending.back();
    pending.pop_back();
    order.push_back(cur);
    if (G[cur.first].size() == 1) continue;  // leaf: nothing below it
    for (int v : G[cur.first])
      if (v != cur.second) pending.push_back(make_pair(v, cur.first));
  }

  // Walk the pre-order backwards so each child is finished before its parent.
  pair<int, int> result(W[u], W[u]);
  for (size_t i = order.size(); i-- > 0;) {
    const int x = order[i].first;
    const int px = order[i].second;
    pair<int, int> range(W[x], W[x]);
    if (G[x].size() != 1) {
      tmp.clear();
      for (int v : G[x])
        if (v != px) {
          tmp.push_back(P[v].first);
          tmp.push_back(P[v].second);
        }
      // An isolated root has no children. Keep its own value rather than
      // indexing an empty buffer.
      if (!tmp.empty()) {
        sort(tmp.begin(), tmp.end());
        range = make_pair(tmp[(tmp.size() - 1) / 2], tmp[tmp.size() / 2]);
      }
    }
    if (x == u)
      result = range;
    else
      P[x] = range;
  }
  return result;
}

// Top-down pass below u, entered from parent p (-1 at the root).
// Requires W[u] to be set and P to have been filled by dfs1 for the subtree.
// Every non-leaf descendant v is given the value in [P[v].first, P[v].second]
// closest to its parent's value. Vertices of degree 1 keep their W.
//
// Implemented with an explicit stack instead of recursion, so deep trees cannot
// overflow the call stack. Each W[v] depends only on its parent's value, which
// is already final when v is assigned, so visiting order does not change the
// result.
void dfs2(int u, int p) {
  vector<pair<int, int> > pending(1, make_pair(u, p));  // (vertex, parent)
  while (!pending.empty()) {
    const int x = pending.back().first;
    const int px = pending.back().second;
    pending.pop_back();
    for (int v : G[x])
      if (v != px && G[v].size() != 1) {
        // Clamp the parent's value into v's interval (first <= second).
        W[v] = max(P[v].first, min(W[x], P[v].second));
        pending.push_back(make_pair(v, x));
      }
  }
}

// Reads "n k", then n-1 edges "a b" (1-based), then the values of vertices
// 0..k-1. Fills in the remaining vertices with dfs1/dfs2 and prints the sum
// over all edges of |W[u] - W[v]|.
// Returns 0 on success, or 1 if the input cannot be read or is out of range.
int main() {
  // Number of slots in the per-vertex arrays. Larger inputs would write out of bounds.
  const int capacity = static_cast<int>(sizeof(W) / sizeof(W[0]));

  if (scanf("%d %d", &n, &k) != 2) return 1;
  if (n < 1 || n > capacity || k < 0 || k > n) return 1;
  for (int i = 1; i < n; i++) {
    int a, b;
    if (scanf("%d %d", &a, &b) != 2) return 1;
    if (a < 1 || a > n || b < 1 || b > n) return 1;
    a--, b--;
    G[a].PB(b);
    G[b].PB(a);
  }
  for (int i = 0; i < k; i++)
    if (scanf("%d", W + i) != 1) return 1;

  // Vertex k is the first vertex whose value is not given, and it roots both
  // passes. If k == n (e.g. a single edge between two given vertices) or the
  // tree is a single vertex, there is nothing to fill in. The edge sum below is
  // then already the answer.
  if (k < n && n > 1) {
    W[k] = dfs1(k, -1).first;
    dfs2(k, -1);
  }

  long long res = 0;
  for (int u = 0; u < n; u++)
    for (int v : G[u])
      if (v < u) {
        // Widen before subtracting: the difference of two ints can exceed INT_MAX.
        const long long d = static_cast<long long>(W[u]) - W[v];
        res += d < 0 ? -d : d;
      }
  printf("%lld\n", res);
  return 0;
}