// Tree edge-cost minimisation ("tory" = tracks).
//
// Vertices 1..m carry fixed integer values read from the input; vertices
// m+1..n are free. Assign values to the free vertices so that the sum over all
// edges (u, v) of |value(u) - value(v)| is minimal, and print that minimum.
// NOTE(review): task statement inferred from the original median-of-children
// solution; confirm against the official problem statement.
//
// Input : n m, then n-1 edges "a b", then m values for vertices 1..m.
// Output: the minimal total cost, printed with "%lld" (no trailing newline).
//
// Method (bottom-up DP): for the subtree rooted at c, let f_c(x) be its
// cheapest cost when c takes value x. f_c is convex and piecewise linear with
// integer slopes, so min_y f_c(y) + |x - y| == min f_c + dist(x, [lo_c, hi_c]),
// where [lo_c, hi_c] is the set of minimisers of f_c. A free parent therefore
// only needs (min f_c, [lo_c, hi_c]) from each child, and
// sum_c dist(x, [lo_c, hi_c]) is minimised on the median interval of the 2k
// child endpoints. Keeping the whole interval per child (instead of collapsing
// each child to one median value) is what makes the answer optimal.
#include <cstdio>
#include <list>
#include <queue>
#include <vector>
#include <algorithm>
#include <cmath>
#include <utility>

namespace {

// DP summary of one rooted subtree.
struct Subtree {
    bool constrained = false;  // true iff the subtree contains a fixed vertex
    long long lo = 0;          // optimal values of the subtree root form [lo, hi]
    long long hi = 0;          //   (only meaningful when constrained)
    long long cost = 0;        // minimal total edge cost inside the subtree
};

// Distance from x to the closed interval [lo, hi]; 0 when x lies inside it.
long long distToInterval(long long x, long long lo, long long hi) {
    if (x < lo) return lo - x;
    if (x > hi) return x - hi;
    return 0;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d %d", &n, &m) != 2 || n < 1) {
        std::printf("%lld", 0LL);
        return 0;
    }

    // Adjacency lists (vector instead of list: contiguous, cache-friendly scans).
    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d %d", &a, &b) != 2) break;
        if (a < 1 || a > n || b < 1 || b > n) continue;  // ignore out-of-range ids
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Fixed values of vertices 1..m, read as 64-bit so differences cannot overflow.
    std::vector<long long> value(n + 1, 0);
    for (int i = 1; i <= m; ++i) {
        long long t = 0;
        if (std::scanf("%lld", &t) != 1) break;
        if (i <= n) value[i] = t;
    }

    // Iterative DFS with an explicit stack: recursion on a path of ~5e5
    // vertices would overflow the call stack. `order` lists every vertex after
    // its parent; parent[root] == 0 (there is no vertex 0).
    std::vector<int> parent(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack;
    for (int root = 1; root <= n; ++root) {
        if (seen[root]) continue;
        seen[root] = 1;
        stack.push_back(root);
        while (!stack.empty()) {
            const int v = stack.back();
            stack.pop_back();
            order.push_back(v);
            for (const int w : adj[v]) {
                if (!seen[w]) {
                    seen[w] = 1;
                    parent[w] = v;
                    stack.push_back(w);
                }
            }
        }
    }

    // Bottom-up pass: walking `order` backwards visits children before parents.
    std::vector<Subtree> st(n + 1);
    std::vector<std::pair<long long, long long>> kids;  // reused per vertex
    std::vector<long long> ends;                         // reused per vertex
    long long total = 0;
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        Subtree& cur = st[v];
        long long cost = 0;
        kids.clear();
        for (const int w : adj[v]) {
            if (parent[w] != v) continue;  // only children, not the parent edge
            cost += st[w].cost;
            if (st[w].constrained) kids.push_back(std::make_pair(st[w].lo, st[w].hi));
        }

        if (v <= m) {
            // Fixed vertex: its value is forced; each child pays its distance to it.
            const long long t = value[v];
            for (const auto& kid : kids) cost += distToInterval(t, kid.first, kid.second);
            cur.constrained = true;
            cur.lo = t;
            cur.hi = t;
        } else if (!kids.empty()) {
            // Free vertex: optimum is the median interval of the 2k child endpoints.
            ends.clear();
            for (const auto& kid : kids) {
                ends.push_back(kid.first);
                ends.push_back(kid.second);
            }
            std::sort(ends.begin(), ends.end());
            const std::size_t k = kids.size();
            cur.constrained = true;
            cur.lo = ends[k - 1];
            cur.hi = ends[k];
            for (const auto& kid : kids) cost += distToInterval(cur.lo, kid.first, kid.second);
        }
        // Otherwise: a free vertex with no fixed vertex below it; any value is
        // optimal and it imposes no constraint on its parent.
        cur.cost = cost;

        if (parent[v] == 0) total += cost;  // root of a component
    }

    std::printf("%lld", total);
    return 0;
}