1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Capacity of every per-vertex table below (vertex ids are used directly as
// indices, starting from 1).
const int kMaxVertices = 500005;

// Number of vertices in the tree, and number of vertices (ids 1..m) whose
// values are read from the input.
int n;
int m;

// Undirected adjacency lists of the tree.
std::vector<int> graph[kMaxVertices];

// Per-vertex closed interval [range_left[v], range_right[v]]. Seeded from the
// input for vertices 1..m and filled in by Dfs for the remaining vertices.
int range_left[kMaxVertices];
int range_right[kMaxVertices];

// Total accumulated by Dfs (or by main for the two-vertex case) and printed.
long long result;

// Returns the distance from `point` to the closed interval [left, right]:
// the gap to `left` when the point is at or below it, the gap to `right` when
// the point is at or above it, and 0 when the point lies strictly inside.
//
// The subtraction is done in 64-bit arithmetic: the difference of two
// arbitrary ints can exceed INT_MAX, which would be signed overflow (UB).
long long DistanceBetweenPointAndRange(int point, int left, int right) {
  if (point <= left) {
    return static_cast<long long>(left) - point;
  }
  if (right <= point) {
    return static_cast<long long>(point) - right;
  }
  return 0;
}

// Post-order pass over the subtree rooted at `w` (entered from `father`, or
// -1 for the root).
//
// Vertices of degree 1 are treated as leaves whose interval
// [range_left, range_right] has already been set by the caller. For every
// other vertex v, the children's interval endpoints are gathered (2k values
// for k children). v's interval becomes [k-th smallest, (k+1)-th smallest],
// which is the set of points minimizing the summed distance to the children's
// intervals. The distance from range_left[v] to each child's interval is then
// added to the global `result`.
//
// An explicit stack replaces recursion because the tree may have ~5e5
// vertices in a path-like shape, which would overflow the call stack.
void Dfs(int w, int father) {
  struct Frame {
    int vertex;
    int parent;
    bool children_done;  // true when revisiting after all children finished
  };
  std::vector<Frame> pending;
  pending.push_back({w, father, false});
  std::vector<int> track_gauges;  // reused across vertices to keep capacity
  while (!pending.empty()) {
    const Frame frame = pending.back();
    pending.pop_back();
    const int v = frame.vertex;
    const int parent = frame.parent;
    if (static_cast<int>(graph[v].size()) == 1) {
      continue;  // leaf: interval comes from the input
    }
    if (!frame.children_done) {
      // Schedule v again below its children so it is finalized after them.
      pending.push_back({v, parent, true});
      for (int son : graph[v]) {
        if (son != parent) {
          pending.push_back({son, v, false});
        }
      }
      continue;
    }
    track_gauges.clear();
    for (int son : graph[v]) {
      if (son != parent) {
        track_gauges.push_back(range_left[son]);
        track_gauges.push_back(range_right[son]);
      }
    }
    if (track_gauges.empty()) {
      // Isolated vertex (e.g. n == 1): nothing to combine, and the median
      // indices below would be out of range.
      continue;
    }
    const int above_median = static_cast<int>(track_gauges.size()) / 2;
    const int below_median = above_median - 1;
    std::nth_element(track_gauges.begin(), track_gauges.begin() + below_median,
                     track_gauges.end());
    range_left[v] = track_gauges[below_median];
    std::nth_element(track_gauges.begin(), track_gauges.begin() + above_median,
                     track_gauges.end());
    range_right[v] = track_gauges[above_median];
    for (int son : graph[v]) {
      if (son != parent) {
        result += DistanceBetweenPointAndRange(range_left[v],
                                               range_left[son], range_right[son]);
      }
    }
  }
}

// Reads n and m, then n - 1 undirected edges, then one value for each of the
// vertices 1..m (stored as the degenerate interval [value, value]). It then
// fills `result` and prints it as a 64-bit integer. `result` is presumably the
// minimal sum of value differences over the tree's edges; confirm against the
// problem statement.
// Returns 1 if the input is truncated or malformed.
int main() {
  if (scanf("%d%d", &n, &m) != 2) {
    return 1;
  }
  for (int i = 1; i < n; i++) {
    int a, b;
    if (scanf("%d%d", &a, &b) != 2) {
      return 1;
    }
    graph[a].push_back(b);
    graph[b].push_back(a);
  }
  for (int i = 1; i <= m; i++) {
    if (scanf("%d", &range_left[i]) != 1) {
      return 1;
    }
    range_right[i] = range_left[i];
  }
  if (n == 2) {
    // Both vertices have degree 1, so Dfs would treat root n as a leaf and
    // add nothing. The single edge's cost is the difference of the two values.
    // Widen before subtracting so the difference cannot overflow int.
    const long long diff =
        static_cast<long long>(range_left[1]) - range_left[2];
    result = diff < 0 ? -diff : diff;
  } else {
    Dfs(n, -1);
  }
  printf("%lld\n", result);
  return 0;
}