1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <iomanip>
#include <iostream>
#include <utility>
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>
#include <set>
#include <map>

using namespace std;

// Competitive-programming shorthand used throughout this file.
// Every macro argument is parenthesized so that expression arguments such as
// ALL(*p) or FORE(it, *p) expand to (*p).begin() rather than *p.begin().
#define ALL(x) (x).begin(), (x).end()                      // full iterator range of a container
#define VAR(a,b) __typeof (b) a = (b)                       // declare a with the (GNU __typeof) type of b
#define IN(a) int a; cin >> a                               // declare an int and read it from cin
#define IN2(a,b) int a, b; cin >> a >> b                    // declare two ints and read them from cin
#define REP(i,n) for (int _n=(n), i=0; i<_n; ++i)           // i = 0 .. n-1 (bound evaluated once)
#define FOR(i,a,b) for (int _b=(b), i=(a); i<=_b; ++i)      // i = a .. b inclusive
#define FORD(i,a,b) for (int _b=(b), i=(a); i>=_b; --i)     // i = a down to b inclusive
#define FORE(i,a) for (VAR(i,(a).begin()); i!=(a).end(); ++i)  // iterator i over container a
#define PB push_back
#define MP make_pair
#define ST first
#define ND second

typedef std::vector<int> VI;
typedef long long LL;
typedef std::pair<int,int> PII;
typedef double LD;  // NOTE(review): despite the name this is plain double, not long double; unused in the visible code

const int DBG = 0, INF = int(1e9);  // NOTE(review): neither constant is referenced in the visible code

// Adjacency lists of the tree, 0-based node ids (filled by main).
std::vector<std::vector<int> > v;
// r[i] is the value read for node i, for i in [0, m).
std::vector<int> r;
// n: number of nodes; m: nodes 0..m-1 carry the fixed values in r.
int n, m;

// Running total of the minimal costs that dfs() adds for every node >= m.
long long res = 0;
// res_range[u] = (lo, hi): the closed range of values for u at which the cost
// computed by dfs() for u is minimal (the single point r[u] for u < m).
std::vector<std::pair<int, int> > res_range;

// Finalizes `node`. Every child (neighbour other than `parent`) must already
// have been finalized, because its res_range entry is read here.
//
// Nodes < m are fixed: their range is the single point r[node] and they add
// nothing to res. For any other node, with children ranges [lo_i, hi_i], this
// minimizes the convex piecewise-linear function
//     f(x) = sum_i dist(x, [lo_i, hi_i])
// by sweeping over the range endpoints (the only breakpoints of f), adds
// min f to res, and stores the range of minimizing x in res_range[node].
// NOTE(review): presumably this is the classic "assign values to inner tree
// nodes minimizing sum |x_u - x_v| over edges" task; confirm against the spec.
//
// A node >= m with no children keeps the original fallback: cost 0, range (0, 0).
void dfs(int node, int parent) {
  // Fixed node: its value is pinned to r[node] and it contributes no cost.
  if (node < m) {
    res_range[node] = std::make_pair(r[node], r[node]);
    return;
  }

  // Breakpoints of f: each child range contributes a start (type 1) and an
  // end (type 0). Both contribute zero to f at their own position, so the
  // tie order produced by sorting does not change any evaluated value.
  const std::vector<int>& adj = v[node];
  std::vector<std::pair<int, int> > events;
  events.reserve(2 * adj.size());
  long long beg_sum = 0;  // sum of range starts not yet passed by the sweep
  int begs_left = 0;      // number of range starts not yet passed
  for (std::size_t i = 0; i < adj.size(); ++i) {
    if (adj[i] == parent) continue;
    const std::pair<int, int>& child = res_range[adj[i]];
    events.push_back(std::make_pair(child.first, 1));
    events.push_back(std::make_pair(child.second, 0));
    beg_sum += child.first;
    ++begs_left;
  }
  std::sort(events.begin(), events.end());

  long long end_sum = 0;  // sum of range ends already passed
  int ends_cnt = 0;       // number of range ends already passed

  // f is convex, so its minimizers form one contiguous range whose endpoints
  // are breakpoints. The minimum is seeded from the first breakpoint instead
  // of from beg_sum: beg_sum equals f(0) only when no endpoint is negative,
  // and seeding with it produced a too-small (even negative) cost otherwise.
  bool found = false;
  long long best_val = 0;
  int best_beg = 0, best_end = 0;
  for (std::size_t i = 0; i < events.size(); ++i) {
    const int pos = events[i].first;
    const bool is_start = events[i].second == 1;
    // f(pos) = sum over unpassed starts of (start - pos)
    //        + sum over passed ends of (pos - end).
    const long long cur = (beg_sum - static_cast<long long>(begs_left) * pos) +
                          (static_cast<long long>(ends_cnt) * pos - end_sum);
    if (!found || cur < best_val) {
      found = true;
      best_val = cur;
      best_beg = best_end = pos;
    } else if (cur == best_val) {
      best_end = pos;  // still on the flat minimum: extend the range rightwards
    }
    if (is_start) {
      beg_sum -= pos;
      --begs_left;
    } else {
      end_sum += pos;
      ++ends_cnt;
    }
  }

  res += best_val;
  res_range[node] = std::make_pair(best_beg, best_end);
}

// Visited flags for the explicit-stack traversal in main (main sizes it to n).
std::vector<int> vis;

// One frame of main's explicit DFS stack: `node` was reached from `parent`.
// first_vis != 0: the node's children have not been pushed yet.
// first_vis == 0: the children were pushed and the node is ready for dfs().
struct state {
  int node;
  int parent;
  int first_vis;

  state(int node_id, int parent_id, int first_visit)
      : node(node_id), parent(parent_id), first_vis(first_visit) {}
};

int main() {
   ios_base::sync_with_stdio(0);
   cout.setf(ios::fixed);
   cin >> n >> m;
   v.resize(n);
   REP(i, n  - 1) {
     IN2(a, b);
     --a;
     --b;
     v[a].PB(b);
     v[b].PB(a);
   }
   r.resize(m);
   REP(i,m) cin >> r[i];
   if (n == 2) {
     assert(m == 2);
     cout << abs(r[0] - r[1]) << endl;
     return 0;
   }
   vis = VI(n, 0);
   vector<state> st;
   st.PB(state(m, -1, 1));
   vis[m] = 1;
   res_range.resize(n);
   while (!st.empty()) {
     int nxt = st.back().node, parent = st.back().parent, first_vis = st.back().first_vis;
     st.pop_back();
     if (first_vis) {
       st.PB(state(nxt, parent, 0));
       FORE(it, v[nxt])
         if (*it != parent && !vis[*it]) {
           vis[*it] = 1;
           st.PB(state(*it, nxt, 1));
         }
     }
     else {
       dfs(nxt, parent);
     }
   }
   cout << res << endl;
   return 0;
}