1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <cstddef>
#include <cstdio>

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

// Capacity of the static tables below (vertex ids must be < MAX_N).
const int MAX_N = 500009;
const int MAX_M = MAX_N;

int n, m;                     // n = number of vertices, m = number of fixed-label leaves
std::vector<int> adj[MAX_N];  // adjacency lists with 0-based vertex ids
int r[MAX_M];                 // r[i] = fixed label of leaf i (vertices 0..m-1)

namespace {

// Closed range of optimal labels for a vertex: (low, high).
using Interval = std::pair<int, int>;
// One end of a child's optimal interval: (value, is_right_end).
using Endpoint = std::pair<int, bool>;

// Combines the optimal intervals of a free vertex's children.
//
// `ends` holds both endpoints of every child's interval (so its size is even).
// `cost` is the children's already-accumulated subtree cost.
// Returns the median range [l, h] of all endpoints, together with `cost` plus
// the summed distance from label l to each child's interval.
// `ends` is reordered in place.
//
// A vertex with no children below it is unconstrained. It reports
// [INT_MIN, INT_MAX] at zero extra cost. As a child, that interval adds one
// endpoint below and one above every other value, so it never shifts the
// parent's median and contributes no distance.
std::pair<Interval, long long> combine_children(std::vector<Endpoint>& ends,
                                                long long cost) {
  if (ends.empty()) {
    return std::make_pair(Interval(std::numeric_limits<int>::min(),
                                   std::numeric_limits<int>::max()),
                          cost);
  }
  const std::size_t half = ends.size() / 2;
  const auto low_it = ends.begin() + (half - 1);
  std::nth_element(ends.begin(), low_it, ends.end());
  const int l = low_it->first;
  // After nth_element everything past low_it is >= *low_it, so the next order
  // statistic is simply the minimum of that tail.
  const int h = std::min_element(low_it + 1, ends.end())->first;

  for (const Endpoint& e : ends) {
    // Widen before subtracting: labels far apart would overflow int.
    if (e.second) {
      if (e.first < l) cost += static_cast<long long>(l) - e.first;
    } else if (e.first > l) {
      cost += static_cast<long long>(e.first) - l;
    }
  }
  return std::make_pair(Interval(l, h), cost);
}

}  // namespace

// Solves the subtree hanging from vertex i, entered from neighbour p
// (p = -1 means i is the root of the walk).
//
// Vertices 0..m-1 are leaves with fixed labels r[]; every other vertex may be
// labelled freely. The objective is the sum of |label(u) - label(v)| over the
// subtree's edges.
//
// Returns ((l, h), cost):
// - A leaf yields ((r[i], r[i]), 0).
// - A free vertex yields the median range of its children's interval
//   endpoints, and the minimum subtree cost evaluated with label l.
//
// Implemented with an explicit stack rather than recursion: a path-shaped
// tree can be ~MAX_N deep, which would overflow the call stack.
std::pair<std::pair<int, int>, long long> dfs(int i = 0, int p = -1) {
  if (i < m) {
    return std::make_pair(std::make_pair(r[i], r[i]), 0ll);
  }

  // Pre-order walk. order[k] is a vertex and parent_slot[k] is the position
  // of its parent in `order` (-1 for i itself). Leaves are recorded but not
  // expanded, exactly as the recursive formulation stops at them.
  struct Pending {
    int v;       // vertex to visit
    int parent;  // vertex we came from (not to be revisited)
    int up;      // slot of `parent` in `order`
  };
  std::vector<int> order;
  std::vector<int> parent_slot;
  std::vector<Pending> pending;
  pending.push_back({i, p, -1});
  while (!pending.empty()) {
    const Pending cur = pending.back();
    pending.pop_back();
    const int slot = static_cast<int>(order.size());
    order.push_back(cur.v);
    parent_slot.push_back(cur.up);
    if (cur.v < m) continue;
    for (int w : adj[cur.v]) {
      if (w != cur.parent) pending.push_back({w, cur.v, slot});
    }
  }

  // Every descendant follows its ancestor in pre-order. Sweeping backwards
  // therefore finishes all children of a vertex before the vertex itself.
  std::vector<std::vector<Endpoint>> ends(order.size());
  std::vector<long long> child_cost(order.size(), 0);
  std::pair<Interval, long long> result;
  for (std::size_t k = order.size(); k-- > 0;) {
    const int v = order[k];
    if (v < m) {
      result = std::make_pair(Interval(r[v], r[v]), 0ll);
    } else {
      result = combine_children(ends[k], child_cost[k]);
    }
    std::vector<Endpoint>().swap(ends[k]);  // release this slot's memory now
    const int up = parent_slot[k];
    if (up >= 0) {
      ends[up].emplace_back(result.first.first, false);
      ends[up].emplace_back(result.first.second, true);
      child_cost[up] += result.second;
    }
  }
  return result;  // last processed slot is k == 0, i.e. vertex i
}

// Input format:
//   n m
//   n-1 lines "u v"   (1-based edge endpoints)
//   m labels          (labels of vertices 1..m, the fixed leaves)
//
// Prints the minimum sum of |label difference| over all edges when the
// remaining vertices may be labelled freely. Returns 1 on malformed input.
int main() {
  if (scanf("%d%d", &n, &m) != 2) return 1;
  for (int i = 0; i < n - 1; ++i) {
    int u, v;
    if (scanf("%d%d", &u, &v) != 2) return 1;
    adj[u - 1].push_back(v - 1);
    adj[v - 1].push_back(u - 1);
  }
  for (int i = 0; i < m; ++i) {
    if (scanf("%d", &r[i]) != 1) return 1;
  }

  // Two fixed leaves: the tree is a path between them, so the answer is the
  // gap between their labels. Computed in 64 bits; int could overflow.
  if (m == 2) {
    const long long gap = static_cast<long long>(r[0]) - r[1];
    printf("%lld\n", gap < 0 ? -gap : gap);
    return 0;
  }

  // No free vertex to start the search from (vertex m does not exist), so
  // every edge joins two fixed leaves. Sum the edge costs directly.
  if (m >= n) {
    long long total = 0;
    for (int u = 0; u < n; ++u) {
      for (int v : adj[u]) {
        if (u < v) {
          const long long d = static_cast<long long>(r[u]) - r[v];
          total += d < 0 ? -d : d;
        }
      }
    }
    printf("%lld\n", total);
    return 0;
  }

  // Vertex m is the first free vertex; root the search there.
  printf("%lld\n", dfs(m).second);
  return 0;
}