#include<bits/stdc++.h> using namespace std; typedef pair<int,int> PI; typedef long long LL; typedef double D; #define FI first #define SE second #define MP make_pair #define PB push_back #define R(I,N) for(int I=0;I<N;I++) #define F(I,A,B) for(int I=A;I<B;I++) #define FD(I,N) for(int I=N-1;I>=0;I--) #define make(A) scanf("%d",&A) #define make2(A,B) scanf("%d%d",&A,&B) #define ALL(x) (x).begin(), (x).end() #define SZ(x) ((int)(x).size()) #define db if(1)printf template<typename C> void MA(C& a,C b){if(a<b)a=b;} template<typename C> void MI(C& a,C b){if(a>b)a=b;} #define MAX 501001 int n,m; vector<int> d[MAX]; int kol[MAX]; LL wyn = 0; pair<int,int> dfs(int nr,int oj){ if(nr < m && nr != 0)return {kol[nr],kol[nr]}; vector<int> pkt; for(int x:d[nr]){ if(x != oj){ PI pom = dfs(x,nr); pkt.PB(pom.FI); pkt.PB(pom.SE); } } int d = SZ(pkt)/2; nth_element(pkt.begin(),pkt.begin()+d-1,pkt.end()); R(i,d)wyn += pkt[SZ(pkt)-i-1] - pkt[i]; int pom = *min_element(pkt.begin()+d,pkt.end()); wyn -= pom - pkt[d-1]; return {pkt[d-1],pom}; } main(){ make2(n,m); R(i,n-1){ int a,b; make2(a,b); a--;b--; d[a].PB(b); d[b].PB(a); } R(i,m)make(kol[i]); PI x = dfs(0,-1); wyn += abs(x.FI - kol[0]); wyn += abs(x.SE - kol[0]); printf("%lld\n",wyn/2); }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | #include<bits/stdc++.h> using namespace std; typedef pair<int,int> PI; typedef long long LL; typedef double D; #define FI first #define SE second #define MP make_pair #define PB push_back #define R(I,N) for(int I=0;I<N;I++) #define F(I,A,B) for(int I=A;I<B;I++) #define FD(I,N) for(int I=N-1;I>=0;I--) #define make(A) scanf("%d",&A) #define make2(A,B) scanf("%d%d",&A,&B) #define ALL(x) (x).begin(), (x).end() #define SZ(x) ((int)(x).size()) #define db if(1)printf template<typename C> void MA(C& a,C b){if(a<b)a=b;} template<typename C> void MI(C& a,C b){if(a>b)a=b;} #define MAX 501001 int n,m; vector<int> d[MAX]; int kol[MAX]; LL wyn = 0; pair<int,int> dfs(int nr,int oj){ if(nr < m && nr != 0)return {kol[nr],kol[nr]}; vector<int> pkt; for(int x:d[nr]){ if(x != oj){ PI pom = dfs(x,nr); pkt.PB(pom.FI); pkt.PB(pom.SE); } } int d = SZ(pkt)/2; nth_element(pkt.begin(),pkt.begin()+d-1,pkt.end()); R(i,d)wyn += pkt[SZ(pkt)-i-1] - pkt[i]; int pom = *min_element(pkt.begin()+d,pkt.end()); wyn -= pom - pkt[d-1]; return {pkt[d-1],pom}; } main(){ make2(n,m); R(i,n-1){ int a,b; make2(a,b); a--;b--; d[a].PB(b); d[b].PB(a); } R(i,m)make(kol[i]); PI x = dfs(0,-1); wyn += abs(x.FI - kol[0]); wyn += abs(x.SE - kol[0]); printf("%lld\n",wyn/2); } |