1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

long long int T[500001][4];

int main()
{
    int m,n;
    cin >> n >> m;
    vector < vector <int> >sasiedzi(n+1,vector<int>() );
    int x,y;
    for (int i=0; i<n-1; i++){
        cin >> x >> y;
        sasiedzi[x].push_back(y);
        sasiedzi[y].push_back(x);
        T[x][2]++;
        T[y][2]++;
    }

    int r;
    for (int i=0; i<m; i++){
        cin >> r;
        T[i+1][3] = 1;
        T[i+1][0] = r;
        T[i+1][1] = r;
    }

    vector <int>liscie;
    vector <int>rodz;
    vector <int>rodz2;
    vector <int>rodz3;
    for (int i=1; i<n+1; i++){
        if (T[i][2] == 1 && T[sasiedzi[i][0]][2] > 1){
            rodz2.push_back(sasiedzi[i][0]);
            T[sasiedzi[i][0]][2]--;
            T[i][2] = 0;
        }
    }

    for (int i=0; i<rodz2.size(); i++){
        if (T[rodz2[i]][2] == 1  && T[rodz2[i]][3] == 0){
             rodz.push_back(rodz2[i]);
             T[rodz2[i]][3] = 1;
        }
    }

    long long int suma = 0;
    rodz3 = rodz;


    int liczba = n-m;
    while (liczba > 0){
        vector <int> dzie;
        rodz2.clear();
        rodz3 = rodz;
        for (int i=0; i<rodz.size(); i++){
            for (int j=0; j<sasiedzi[rodz[i]].size(); j++){
                if (T[sasiedzi[rodz[i]][j]][2] == 0){
                    dzie.push_back(sasiedzi[rodz[i]][j]);
                }
            }
            int pocz[dzie.size()];
            int kon[dzie.size()];
            int all[2*dzie.size()];
            for (int j=0; j<dzie.size(); j++){
                pocz[j] = T[dzie[j]][0];
                kon[j] = T[dzie[j]][1];
                all[2*j] = T[dzie[j]][0];
                all[2*j+1] = T[dzie[j]][1];
            }
            sort(&pocz[0],&pocz[dzie.size()]);
            sort(&kon[0],&kon[dzie.size()]);
            sort(&all[0],&all[2*dzie.size()]);

            int p,k;
            int mn[2*dzie.size()][2];
            int wsp = 0;
            int wsk = 0;

            for (int j=0; j<2*dzie.size(); j++){
                while(wsk < dzie.size() && kon[wsk] <= all[j]) wsk++;
                mn[j][0] = wsk;
            }
            for (int j=2*dzie.size()-1; j>=0; j--){
                while (wsp < dzie.size() && pocz[dzie.size()-1-wsp] >= all[j]) wsp++;
                mn[j][1] = wsp;
            }

            p = 0;
            k = 2*dzie.size()-1;
            while(k-p>1){
                if (mn[p][0] > mn[k][1]) k--;
                if (mn[p][0] < mn[k][1]) p++;
                if (mn[p][0] == mn[k][1]){
                    k--;
                    p++;
                }
            }
            if (mn[p][0] > mn[k][1]) k--;
            if (mn[p][0] > mn[k][1]) p++;
            T[rodz[i]][0] = all[p];
            T[rodz[i]][1] = all[k];
            T[rodz[i]][2] = 1;
            for (int j=0; j<dzie.size(); j++){
                if (!(all[p] >= T[dzie[j]][0] && all[p] <= T[dzie[j]][1])){
                    suma += min(abs(all[p] - T[dzie[j]][0]),abs(all[p] - T[dzie[j]][1]));
                }
            }

            dzie.clear();
        }

        liczba -= rodz.size();
        for (int i=0; i<rodz.size(); i++){
            for (int j=0; j<sasiedzi[rodz[i]].size(); j++){
                if (T[sasiedzi[rodz[i]][j]][2] > 1){
                    rodz2.push_back(sasiedzi[i][0]);
                    T[sasiedzi[rodz[i]][j]][2]--;
                }
            }
        }
        for (int j=0; j<rodz.size(); j++) T[rodz[j]][2] = 0;
        rodz.clear();
        for (int i=0; i<rodz2.size(); i++){
            if (T[rodz2[i]][2] == 1  && T[rodz2[i]][3] == 0){
                 rodz.push_back(rodz2[i]);
                 T[rodz2[i]][3] = 1;
            }
        }
        rodz = rodz2;
    }
    if (rodz3.size() > 1){
        int x1 = rodz3[0];
        int x2 = rodz3[1];
        if (T[x1][0] > T[x2][1]) suma += T[x1][0] - T[x2][1];
        if (T[x2][0] > T[x1][1]) suma += T[x2][0] - T[x1][1];
    }

    cout << suma;
}