1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <cstdio>
#include <list>
#include <queue>
#include <vector>
#include <algorithm>
#include <cmath>

// Largest vertex id the fixed-size global arrays can hold (they have MAX+1 slots).
constexpr int MAX = 500001;

// Globals read by main(): n vertices, of which ids 1..m carry fixed input values.
// a and b are scratch variables for scanf; d is not referenced in this file.
int n;
int m;
int d;
int a;
int b;

struct Node; // forward declaration; not referenced in this file

/*
 * Directed half of an undirected tree edge, leading from vertex v to vertex w.
 * main() stores every input edge twice, once in each endpoint's adjacency list.
 */
struct Edge{
	int v; // source vertex
	int w; // target vertex
	Edge(int v_in, int w_in) : v(v_in), w(w_in) {}
};

/*
 * Per-vertex state.
 *   x       - vertex id (0 for a default-constructed placeholder)
 *   tory    - value of the vertex: read from input for ids 1..m,
 *             written by dfs() for the remaining vertices
 *   visited - traversal flag, initially false
 */
struct Vertex{
	int x = 0;
	int tory = 0;
	bool visited = false;
	Vertex(int x_in = 0) : x(x_in) {}
};

// Per-vertex state indexed by vertex id (main() fills ids 1..n).
Vertex V[MAX + 1];

// Adjacency lists: E[u] holds one Edge(u, w) for every neighbour w of u.
std::list<Edge> E[MAX + 1];

// Global scratch buffer of MAX+1 ints. NOTE(review): drop it if no function still references it.
int vec[MAX + 1];


long long int dfs(int v, int size){
	V[v].visited = true;
	int mysize = size;
	long long int sum = 0;
	if(v > m){
	  	for (std::list<Edge>::iterator j = E[v].begin(); j != E[v].end(); ++j) {
	    	if(!(V[j->w].visited)){ 
	    		sum+=dfs(j->w, mysize+1);
	      		vec[mysize] = V[j->w].tory;
	      		mysize++;
	    	}
	  	}
		std::sort(vec+size,vec+mysize);
   // 	printf("\n");
    	int med = 0;
    	int vsize = (mysize-size);
    	if(vsize % 2){
    		med = vec[size+vsize/2];
    //		printf("index: %d, med1: %d\n", index, med);
		}else{
			med = vec[size+(vsize/2)-1];
	//		printf("index: %d, med2: %d\n", index, med);
		}
    //	printf("index: %d, med: %d\n", index, med);
    	for(int j = size; j < mysize; ++j){
    		if(med > vec[j]){
    			sum+= med-vec[j];
			}else{
				sum+= vec[j]-med;
			}
    		
		}
		V[v].tory = med;
	} 
	//printf("index: , suma");
	return sum;
}


int main(){
	scanf("%d %d", &n, &m);
	for(int i = 1; i <= n; ++i){
		V[i] = Vertex(i);
	}
	for(int i = 1; i < n; ++i){
		scanf("%d %d", &a, &b);
		E[a].push_back(Edge(a,b));
		E[b].push_back(Edge(b,a));
	//	E[a].back().rev = &E[b].back();
    //  	E[b].back().rev = &E[a].back();
	}
	for(int i = 1; i <= m; ++i){
		scanf("%d", &a);
		V[i].tory = a;
	}
	
	long long int suma = 0;
	if(n>m){
		suma = dfs(m+1,0);
	}
	for(int i = 1; i <= m; ++i){
		if(!V[i].visited){
			long long int min = V[i].tory;
			long long int max = V[i].tory;
			for (std::list<Edge>::iterator j = E[i].begin(); j != E[i].end(); ++j) {
	    		if(!(V[j->w].visited)){ 
	    			if(min > V[j->w].tory ) min = V[j->w].tory;
	    			if(max < V[j->w].tory ) max = V[j->w].tory;
	    			V[j->w].visited = true;
	    		}
    		}
    		suma+=max-min;
		}
	}
	
	
	printf("%lld", suma);
	
	//printf("\n");
	//for(int i = 1; i <= n; ++i){
//		printf(" i: %d, bfs: %d\n", i, V[i].bfs);
//	}
}