1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
using i64 = long long;

// Assigns integer values to the unlabelled vertices of a tree so that the sum
// of |value(u) - value(v)| over all edges is minimised.
//
// Vertices 0..m-1 are treated as leaves: the caller writes their values into
// r_[i] (as {x, x}) before calling solve(). Every other vertex starts with the
// placeholder {-1, -1} and is filled in by solve().
struct Problem
{
  // n: number of vertices; m: number of pre-labelled leaf vertices (0..m-1).
  Problem(int n, int m) : n_(n), m_(m), adj_(n), r_(n,{-1,-1}) {}

  // Adds an undirected edge between 0-based vertices u and v.
  void addEdge(int u, int v) { adj_[u].push_back(v); adj_[v].push_back(u); }

  // Chooses a value for every non-leaf vertex (written back into r_ as {x, x})
  // and returns the resulting minimum total edge cost.
  i64 solve();

  int n_,m_;
  vector<vector<int>> adj_;   // adjacency lists
  vector<pair<int,int>> r_;   // per-vertex value range [first, second]
};

i64 Problem::solve()
{
  // Leaves are the first m_ vertices; clamp in case m_ exceeds n_.
  const int leaves = min(m_, n_);

  // known[v] is set once r_[v] holds a valid range. A dedicated flag is used
  // instead of testing r_[v].first != -1 because -1 is a legitimate value: a
  // leaf (or a median derived from one) equal to -1 would otherwise be taken
  // for "unresolved".
  vector<char> known(n_, 0);
  for (int i = 0; i < leaves; ++i)
    known[i] = 1;

  // deg[v] = number of neighbours of v that have not yet been resolved.
  vector<int> deg(n_);
  for (int i = 0; i < n_; ++i)
    deg[i] = static_cast<int>(adj_[i].size());

  // Peel the tree from the leaves inward: a vertex becomes ready once at most
  // one of its neighbours is still unresolved.
  queue<int> q;
  for (int i = 0; i < leaves; ++i)
  {
    if (adj_[i].empty())   // isolated vertex (n == 1): nothing to peel
      continue;
    if (--deg[adj_[i][0]] == 1)
      q.push(adj_[i][0]);
  }

  vector<int> order;        // vertices in the order they were resolved
  order.reserve(n_);
  vector<int> ends;         // scratch buffer of neighbour range endpoints
  while (!q.empty())
  {
    const int v = q.front();
    q.pop();
    order.push_back(v);

    ends.clear();
    bool all = true;        // true when every neighbour of v is resolved
    for (int u : adj_[v])
    {
      if (!known[u])
      {
        // u is v's single remaining unresolved neighbour.
        all = false;
        if (--deg[u] == 1)
          q.push(u);
        continue;
      }
      ends.push_back(r_[u].first);
      ends.push_back(r_[u].second);
    }
    // v's optimal range, ignoring its unresolved neighbour, is the median
    // range of the collected endpoints. ends has an even, non-zero size: two
    // entries per resolved neighbour, and v was only enqueued after a
    // resolved neighbour decremented deg[v].
    sort(ends.begin(), ends.end());
    const size_t mid = ends.size() / 2;
    if (all)
      r_[v] = {ends[mid], ends[mid]};
    else
      r_[v] = {ends[mid - 1], ends[mid]};
    known[v] = 1;
  }

  // Second pass in reverse resolution order: each vertex's later-resolved
  // neighbour has already been pinned to a single value, so taking the median
  // again fixes this vertex to one optimal value.
  for (auto it = order.rbegin(); it != order.rend(); ++it)
  {
    const int v = *it;
    ends.clear();
    for (int u : adj_[v])
    {
      ends.push_back(r_[u].first);
      ends.push_back(r_[u].second);
    }
    sort(ends.begin(), ends.end());
    const int x = ends[ends.size() / 2];
    r_[v] = {x, x};
  }

  // Count each edge once, from its larger endpoint. Widen before subtracting:
  // two ints of opposite sign can differ by more than INT_MAX.
  i64 result = 0;
  for (int i = 0; i < n_; ++i)
    for (int v : adj_[i])
      if (r_[i].first > r_[v].first)
        result += static_cast<i64>(r_[i].first) - r_[v].first;
  return result;
}

// Reads "n m", then n-1 edges as 1-based vertex pairs, then the m values of
// leaf vertices 1..m. Prints the minimum total edge cost.
// Exits with status 1 on malformed or out-of-range input instead of
// proceeding with uninitialised values or out-of-bounds indices.
int main()
{
  int n, m;
  if (scanf("%d%d", &n, &m) != 2 || n < 0 || m < 0 || m > n)
    return 1;
  Problem p(n, m);
  for (int i = 0; i < n - 1; ++i)
  {
    int u, v;
    if (scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n)
      return 1;
    p.addEdge(u - 1, v - 1);
  }
  for (int i = 0; i < m; ++i)
  {
    int x;
    if (scanf("%d", &x) != 1)
      return 1;
    p.r_[i] = {x, x};
  }
  printf("%lld\n", p.solve());
  return 0;
}