1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
// Artur Kraska, II UWr

#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <cmath>
#include <list>
#include <set>
#include <map>

// Shorthand macros from a competitive-programming template.

// Loop `i` over 0 .. n-1.
#define forr(i, n)                  for(int i=0; i<n; i++)
// Forward iteration over a container with an iterator named `iter`.
#define FOREACH(iter, coll)         for(auto iter = coll.begin(); iter != coll.end(); ++iter)
// Reverse iteration over a container with a reverse iterator named `iter`.
#define FOREACHR(iter, coll)        for(auto iter = coll.rbegin(); iter != coll.rend(); ++iter)
// Binary search over [P, R): yields the smallest X for which PRED (an expression
// that may refer to X) is true, or R if it is never true. PRED must be monotone
// (false ... false, true ... true). Relies on the GNU statement-expression and
// `typeof` extensions, so it is not standard C++; P is evaluated more than once.
#define lbound(P,R,PRED)            ({typeof(P) X=P,RRR=(R), PPP = P; while(PPP<RRR) {X = (PPP+(RRR-PPP)/2); if(PRED) RRR = X; else PPP = X+1;} PPP;})
// Reads a test count and loops `_test` over 1 .. _tests (expands to the FOR macro below).
#define testy()                     int _tests; scanf("%d", &_tests); FOR(_test, 1, _tests)
// Zeroes `tab` with memset; sizeof must see a real array, not a decayed pointer.
#define CLEAR(tab)                  memset(tab, 0, sizeof(tab))
// True when coll.find(el) locates `el` (set element / map key present).
#define CONTAIN(el, coll)           (coll.find(el) != coll.end())
// Inclusive ascending loop: i = a, a+1, ..., b.
#define FOR(i, a, b)                for(int i=a; i<=b; i++)
// Inclusive descending loop: i = a, a-1, ..., b.
#define FORD(i, a, b)               for(int i=a; i>=b; i--)
#define MP                          make_pair
#define PB                          push_back
// Debug hook: expands to the statement `X;` (presumably redefined to nothing to silence debug code).
#define deb(X)                      X;

// Both are 1e9+7; M is presumably a modulus and INF an "infinity" sentinel.
#define M 1000000007
#define INF 1000000007

using namespace std;

// Global program state shared by main() and dzialaj().
//   n    - number of nodes in the tree
//   m    - nodes 0 .. m-1 have fixed values read in main() (presumably the leaves)
//   a, b - scratch variables main() reads each edge into
//   res  - answer accumulator, printed by main()
int n, m, a, b;
long long res = 0;

// One tree node ("miasto" = city). For a fixed-value node, pocz == kon == its value.
// For every other node, dzialaj() stores the interval [pocz, kon] chosen for it
// ("pocz" = start, "kon" = end).
struct miasto
{
    int pocz, kon;
    std::vector<int> l;  // adjacency list: 0-based neighbour indices
};
miasto tab[1000007], e;  // node table indexed 0 .. n-1; `e` is not referenced in this block

// Settles a single node `nr` whose parent is `o` (-1 for the root). Every child
// (neighbour other than `o`) must already have its interval [pocz, kon].
//
// The children's 2*k interval endpoints are sorted. nr receives the median range
// [v[k-1], v[k]]. `res` gains the distance from the lower end of that range to
// each child's interval:
//   - intervals that end earlier pay (x - end);
//   - intervals that start later pay (start - x).
// `v` is scratch storage reused across calls to avoid reallocating per node.
static void mergeChildren(int nr, int o, std::vector<std::pair<int, bool> >& v)
{
    v.clear();
    for (int u : tab[nr].l)
        if (u != o)
        {
            v.push_back(std::make_pair(tab[u].pocz, false));  // interval start
            v.push_back(std::make_pair(tab[u].kon, true));    // interval end
        }

    // A node without children has nothing to merge. The original code indexed
    // v[-1] in this case, which is undefined behaviour.
    if (v.empty())
        return;

    // Sorting by (value, is_end) puts starts before ends at equal values.
    std::sort(v.begin(), v.end());
    const std::size_t ile = v.size() / 2;   // number of children ("ile" = how many)
    const std::size_t index = ile - 1;      // lower median position among the endpoints
    const long long x = v[index].first;     // widened so the differences below cannot overflow int

    tab[nr].pocz = v[index].first;
    tab[nr].kon = v[index + 1].first;

    // Intervals that end before the median position: pay the gap up to x.
    for (std::size_t i = 0; i < index; ++i)
        if (v[i].second)
            res += x - v[i].first;
    // Intervals that start at or after the median position: pay the gap down to x.
    for (std::size_t i = index; i < v.size(); ++i)
        if (!v[i].second)
            res += v[i].first - x;
}

// Processes the subtree rooted at `nr`, entered from parent `o` (pass -1 for the
// root). Nodes with index < m keep the values main() gave them and are not
// descended into.
//
// Every other node is settled by mergeChildren() after all of its children, as in
// a recursive post-order. An explicit stack is used instead of recursion, so very
// deep trees (e.g. a long path) cannot overflow the call stack.
void dzialaj(int nr, int o)
{
    if (nr < m)
        return;

    // Pre-order list of (node, parent) pairs for nodes >= m. A node is appended
    // before any of its descendants, so walking this list backwards settles every
    // child before its parent.
    std::vector<std::pair<int, int> > order;
    std::vector<std::pair<int, int> > stos;  // DFS stack ("stos" = stack)
    stos.push_back(std::make_pair(nr, o));
    while (!stos.empty())
    {
        const std::pair<int, int> cur = stos.back();
        stos.pop_back();
        order.push_back(cur);
        for (int u : tab[cur.first].l)
            if (u != cur.second && u >= m)
                stos.push_back(std::make_pair(u, cur.first));
    }

    std::vector<std::pair<int, bool> > buf;
    for (auto it = order.rbegin(); it != order.rend(); ++it)
        mergeChildren(it->first, it->second, buf);
}

int main()
{
    scanf("%d %d", &n, &m);
    forr(i, n-1)
    {
        scanf("%d %d", &a, &b);
        a--;
        b--;
        tab[a].l.PB(b);
        tab[b].l.PB(a);
    }
    forr(i, m)
    {
        scanf("%d", &tab[i].pocz);
        tab[i].kon = tab[i].pocz;
    }

    if(m == n)
    {
        forr(i, n)
            FOREACH(it, tab[i].l)
                res += abs(tab[i].pocz - tab[*it].pocz);
        res /= 2;
    }
    else
        dzialaj(m, -1);

    printf("%lld\n", res);

    return 0;
}