1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
#include<algorithm>
#include<cstdio>
#include<queue>
#include<utility>
#include<vector>

// Capacity of every per-vertex array below, i.e. the largest n supported.
constexpr int VERTICES = 500000;

// Counting loops: FOR(i, a, b) visits a..b-1, REP(i, n) visits 0..n-1.
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define REP(i, n) FOR(i, 0, n)
using namespace std;

// Adjacency lists of the tree (vertex ids are stored 0-based).
std::vector<int> graph[VERTICES];
// Per-vertex counter of neighbours not yet taken off the queue in main().
int deg[VERTICES];
// FIFO of vertices waiting to be processed, leaves first.
std::queue<int> Q;
// Fixed values read for the first m vertices.
int r[VERTICES];
// Interval [first, second] computed for each processed vertex.
std::pair<int, int> intervals[VERTICES];
// parent[v] == v until v is processed; afterwards its remaining neighbour.
int parent[VERTICES];

int main() {
  int n, m, a, b;
  scanf("%d%d", &n, &m);
  REP(i, n-1) {
    scanf("%d%d", &a, &b);
    --a;
    --b;
    graph[a].push_back(b);
    graph[b].push_back(a);
  }
  REP(i, n) {
    deg[i] = graph[i].size();
    parent[i] = i;
  }
  REP(i, m) {
    Q.push(i);
    scanf("%d", r + i);
  }
  if ( ( n == 2) && ( m == 2) ) {
    printf("%d\n", (max(r[0], r[1]) - min(r[0], r[1])));
    return 0;
  }
  long long cost = 0;
  while (!Q.empty()) {
    int element = Q.front();
    Q.pop();
    
    for (int neighbour : graph[element]) {
      deg[neighbour]--;
      if (deg[neighbour] == 1) {
        Q.push(neighbour);
      }
      if (parent[neighbour] == neighbour) {
        parent[element] = neighbour;
        // printf("parent[%d] = %d\n", element + 1, neighbour + 1);
      }
    }
    
    if (element < m) {
      intervals[element] = make_pair(r[element], r[element]);
    } else {
      vector<int> points;
      for (int neighbour : graph[element]) {
        if (parent[neighbour] == element) {
          points.push_back(intervals[neighbour].first);
          points.push_back(intervals[neighbour].second);
        }
      }
      sort(points.begin(), points.end());
      intervals[element] =
          make_pair(points[(points.size() - 1) / 2], points[points.size() / 2]);
      //printf("Creating interval for element(%d): [%d, %d]\n",
      //    element + 1,
      //    points[(points.size() - 1) / 2],
      //    points[points.size() / 2]);
      
      for (int neighbour : graph[element]) {
        if (parent[neighbour] == element) {
          if (intervals[neighbour].second < intervals[element].first) {
            cost += intervals[element].first - intervals[neighbour].second;
            //printf("new cost: %lld: +%d - %d; comes from son(%d) of element(%d)\n",
            //    cost,
            //    intervals[element].first,
            //    intervals[neighbour].second,
            //    neighbour + 1,
            //    element + 1);
          }
          if (intervals[neighbour].first > intervals[element].first) {
            cost += intervals[neighbour].first - intervals[element].first;
            //printf("new cost: %lld: +%d - %d; comes from son(%d) of element(%d)\n",
            //    cost,
            //    intervals[neighbour].first,
            //    intervals[element].first,
            //    neighbour + 1,
            //    element + 1);
          }
        }
      }
    }
  }
  printf("%lld\n", cost);
  return 0;
}