1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef pair<int,int> pii;

// File-scope state shared by main() and dfs().
//   n, m   : read by main(); dfs() treats every vertex with index <= m as a
//            leaf whose value is fixed to a[index].
//   x, y, i: scratch globals (kept for other code in this file).
//   a[]    : fixed leaf values, indexed from 1.
//   le[]/ri[]: per-vertex closed interval [le, ri] filled in by dfs().
//   g[]    : adjacency lists of the undirected graph.
//   res    : running total cost accumulated by dfs().
int n,m,x,y,i,a[500500],le[500500],ri[500500];
vector<int> g[500500];
long long res;

/**
 * Fills le[]/ri[] for the subtree hanging below vertex `i` (entered from
 * parent `p`) and adds that subtree's cost to the global `res`.
 *
 * For a leaf (index <= m) the interval is the single point a[index].
 * For any other vertex the interval is found by a sweep over the sorted
 * endpoints of its children's intervals: it stops at the first value where
 * the number of child intervals ending at or before it is no smaller than
 * the number starting strictly after it (a median-style choice; presumably
 * the set of values minimising the summed distance to the child intervals).
 * The cost added for the vertex is the distance from le[vertex] to each
 * child's interval.
 *
 * A non-leaf vertex without children keeps whatever le/ri it already had
 * (zero for the zero-initialised globals).
 *
 * The traversal is iterative so that very deep (path-like) trees cannot
 * exhaust the call stack; results are identical to a recursive post-order
 * walk because `res` is a plain sum.
 *
 * @param i  vertex to start from (the subtree root)
 * @param p  vertex to treat as the parent of `i` (neighbours equal to it are skipped)
 */
void dfs(int i, int p) {
  // Discovery list of (vertex, parent). A parent is always stored before
  // its children, so walking the list backwards visits children first.
  vector<pii> order;
  order.push_back(make_pair(i,p));
  for (size_t head=0; head<order.size(); head++) {
    // Copy out before pushing: push_back may reallocate the vector.
    const int v=order[head].first;
    const int par=order[head].second;
    if (v<=m) continue;  // leaves are never expanded
    for (size_t j=0; j<g[v].size(); j++) {
      const int k=g[v][j];
      if (k!=par) order.push_back(make_pair(k,v));
    }
  }

  vector<pii> all;  // reused endpoint buffer: (value, 0 = left end / 1 = right end)
  for (size_t idx=order.size(); idx-- > 0; ) {
    const int v=order[idx].first;
    const int par=order[idx].second;
    if (v<=m) {
      le[v]=a[v];
      ri[v]=a[v];
      continue;
    }

    all.clear();
    int ended=0;       // child intervals whose right end is <= current value
    int notStarted=0;  // child intervals whose left end is > current value
    for (size_t j=0; j<g[v].size(); j++) {
      const int k=g[v][j];
      if (k==par) continue;
      all.push_back(make_pair(le[k],0));
      all.push_back(make_pair(ri[k],1));
      notStarted++;
    }
    sort(all.begin(),all.end());

    size_t pos=0;
    while (pos<all.size()) {
      const int value=all[pos].first;
      // Consume every endpoint sharing this value in one step.
      size_t next=pos;
      while (next<all.size() && all[next].first==value) {
        if (all[next].second) ended++; else notStarted--;
        next++;
      }
      if (ended==notStarted && next<all.size()) {
        // Balanced: the whole stretch up to the next endpoint is kept.
        le[v]=value;
        ri[v]=all[next].first;
        break;
      }
      if (ended>=notStarted) {
        le[v]=ri[v]=value;
        break;
      }
      pos=next;
    }

    for (size_t j=0; j<g[v].size(); j++) {
      const int k=g[v][j];
      if (k==par) continue;
      // Widen before subtracting so extreme values cannot overflow int.
      if (ri[k]<le[v]) res+=static_cast<long long>(le[v])-ri[k];
      if (le[k]>le[v]) res+=static_cast<long long>(le[k])-le[v];
    }
  }
}
/**
 * Reads the problem from stdin and prints the total cost.
 *
 * Input layout: "n m", then n-1 edges "u v", then m integers stored in
 * a[1..m]. The answer is printed as a single line.
 *
 * @return 0 always
 */
int main() {
  scanf("%d%d",&n,&m);
  for (int e=1; e<n; e++) {
    // Zero-initialised so a failed read behaves like the zeroed globals did.
    int u=0,v=0;
    scanf("%d%d",&u,&v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  for (int leaf=1; leaf<=m; leaf++) scanf("%d",&a[leaf]);
  if (n==2) {
    // Special case: with two vertices the answer is taken directly as
    // |a[1]-a[2]| (dfs(n,0) would stop at once if vertex n is a leaf).
    // Computed in long long so the subtraction cannot overflow int, and
    // without abs() so no extra header is required.
    long long diff=static_cast<long long>(a[1])-a[2];
    if (diff<0) diff=-diff;
    printf("%lld\n",diff);
    return 0;
  }
  dfs(n,0);
  printf("%lld\n",res);
  return 0;
}