1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <list>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
#include <cmath>
#include <cstring>
#include <string>
#include <iostream>
#include <complex>
#include <sstream>
#include <cassert>
using namespace std;
 
// Short type aliases used throughout the solution.
typedef long long LL;
typedef unsigned long long ULL;
typedef long double LD;
typedef vector<int> VI;
typedef pair<int,int> PII;

// Counting loops: REP is [0, cnt), FOR is [lo, hi), FORD walks [lo, hi) backwards.
#define REP(idx, cnt) for (int idx = 0; idx < (cnt); ++idx)
#define FOR(idx, lo, hi) for (int idx = (lo); idx < (hi); ++idx)
#define FORD(idx, hi, lo) for (int idx = (hi) - 1; idx >= (lo); --idx)
// Iterator loop over any container exposing begin()/end() (GNU __typeof deduces the iterator type).
#define FOREACH(iter, coll) for (__typeof((coll).begin()) iter = (coll).begin(); iter != (coll).end(); ++iter)

// Container helpers: signed size and a full begin/end range.
#define SIZE(coll) ((int)((coll).size()))
#define ALL(coll) (coll).begin(), (coll).end()

// Abbreviations for push_back / make_pair and pair members.
#define pb push_back
#define mp make_pair
#define st first
#define nd second

// Number of nodes and number of leaves. Nodes are 0-based after input;
// nodes 0..M-1 are leaves whose values are fixed by R[], nodes M..N-1 are free.
int N, M;
// Adjacency lists. A contiguous vector per node is far cheaper to build and
// iterate than a node-per-element std::list for ~1e6 edge endpoints.
vector<int> adj[500005];
// Fixed value of each leaf (only indices < M are read).
int R[500005];

// Filled by go(): mini[v] is the minimum total edge cost inside v's subtree,
// attained when v's own value lies in the closed interval
// [range[v].first, range[v].second].
LL mini[500005];
PII range[500005];

void go(int v, int p = -1) {
  if (v < M) {
    mini[v] = 0;
    range[v].st = range[v].nd = R[v];
    return;
  }

  FOREACH(it, adj[v]) {
    if (*it == p) continue;
    go(*it, v);
  }
  
  vector<int> pts;
  FOREACH(it, adj[v]) {
    if (*it == p) continue;
    pts.pb(range[*it].st);
    pts.pb(range[*it].nd);
  }
  
  sort(pts.begin(), pts.end());
  int K = pts.size() / 2;
  int val = pts[K-1];
  mini[v] = 0;
  FOREACH(it, adj[v]) {
    if (*it == p) continue;
    mini[v] += mini[*it] + max(0, range[*it].st - val) + max(0, val - range[*it].nd);
  }
  range[v].st = pts[K-1];
  range[v].nd = pts[K];
}

// Reads N (nodes) and M (leaves), then N-1 edges as 1-based node pairs, then
// the M fixed leaf values R[0..M-1]. After the shift to 0-based indices,
// nodes 0..M-1 are the leaves. Node N-1 is used as the root, and the answer
// printed is mini[N-1] as computed by go(). Exits with status 1 on malformed input.
int main() {
  if (scanf("%d%d", &N, &M) != 2) return 1;
  REP(i,N-1) {
    int a, b;
    if (scanf("%d%d", &a, &b) != 2) return 1;
    --a, --b;  // input is 1-based
    adj[a].pb(b);
    adj[b].pb(a);
  }
  REP(i,M) {
    if (scanf("%d", &R[i]) != 1) return 1;
  }

  // Every node is a leaf. In a tree this is only possible for N <= 2:
  // a lone node (no edges, cost 0) or one edge between two fixed values.
  // It must be handled here because go() gives a leaf root cost 0 without
  // looking at its edge. The difference is taken in 64 bits to avoid int overflow.
  if (M == N) {
    long long ans = 0;
    if (N >= 2) {
      ans = (long long)R[1] - R[0];
      if (ans < 0) ans = -ans;
    }
    printf("%lld\n", ans);
    return 0;
  }

  go(N-1);
  printf("%lld\n", mini[N-1]);
  return 0;
}