// Tree-labelling solver.
//
// Input (in the order read by solve()):
//   n m                 -- number of vertices, number of fixed vertices
//   n-1 lines "a b"     -- tree edges, 1-based vertex numbers
//   m integers          -- A[0..m-1], the fixed values of vertices 0..m-1
// Output: the minimum, over integer values for the non-fixed vertices, of the
// sum of |value(u) - value(v)| over all edges.
//
// Method: for every subtree, the set of optimal values of its root is a closed
// interval [lo, hi]. A parent collects both endpoints of each child's interval.
// The median pair of those endpoints is the parent's interval, and the cost of
// the parent-child edges is accumulated into `ans`.
#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

typedef long long ll;
typedef std::pair<int, int> pii;
typedef std::vector<int> vi;

const int MAXN = 1000006;

vi V[MAXN];        // adjacency lists (0-based vertices)
int A[MAXN], n, m; // A[i] is the fixed value of vertex i, for i < m
ll ans;            // minimal total cost, accumulated by dfs()

/// Adds to `ans` the minimal cost of the edges inside the subtree rooted at `a`.
///
/// The subtree is entered from `s`, which is -1 for the root. Fixed vertices
/// (index < m) are not expanded; this assumes they are tree leaves, so verify
/// against the problem statement. The traversal is iterative so that deep
/// trees (up to MAXN vertices) cannot overflow the call stack.
///
/// @return the interval [lo, hi] of optimal values for `a`:
///   - (A[a], A[a]) for a fixed vertex;
///   - (INT_MIN, INT_MAX) if the subtree contains no fixed vertex at all,
///     in which case any value is optimal and the cost is 0.
pii dfs(int a, int s) {
    if (a < m) return pii(A[a], A[a]);

    // Pass 1: pre-order walk recording each vertex's parent.
    // Read backwards, `order` visits every child before its parent.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> todo(1, a);
    parent[a] = s;
    while (!todo.empty()) {
        const int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        if (v < m) continue; // fixed vertices are not expanded
        for (int w : V[v]) {
            if (w != parent[v]) {
                parent[w] = v;
                todo.push_back(w);
            }
        }
    }

    // Pass 2: combine children into parents, bottom-up.
    std::vector<pii> best(n);                // optimal-value interval per vertex
    std::vector<char> unconstrained(n, 0);   // subtree holds no fixed vertex
    std::vector<int> P;                      // children's interval endpoints (reused)
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        if (v < m) {
            best[v] = pii(A[v], A[v]);
            continue;
        }

        P.clear();
        ll spread = 0; // sum of (hi - lo) over the constrained children
        for (int w : V[v]) {
            // Skip the parent. Also skip unconstrained subtrees: they can copy
            // v's value at zero cost, so they must not shift the median.
            if (w == parent[v] || unconstrained[w]) continue;
            P.push_back(best[w].first);
            P.push_back(best[w].second);
            spread += static_cast<ll>(best[w].second) - best[w].first;
        }
        if (P.empty()) {
            // Without this guard, P[(k-1)/2] on an empty P would be UB.
            unconstrained[v] = 1;
            best[v] = pii(INT_MIN, INT_MAX);
            continue;
        }

        std::sort(P.begin(), P.end());
        const int k = static_cast<int>(P.size());
        best[v] = pii(P[(k - 1) / 2], P[k / 2]);

        // With v at value lo, each child edge costs dist(lo, [l, r]),
        // which equals (|lo-l| + |lo-r| - (r-l)) / 2. Summing the numerators
        // first keeps the arithmetic exact: the total is always even.
        ll cost = -spread;
        for (int p : P) cost += std::llabs(static_cast<ll>(best[v].first) - p);
        ans += cost / 2;
    }
    return best[a];
}

/// Reads one test case from stdin and prints its answer to stdout.
void solve() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n - 1; ++i) {
        int a, b;
        scanf("%d%d", &a, &b);
        --a; // input is 1-based
        --b;
        V[a].push_back(b);
        V[b].push_back(a);
    }
    for (int i = 0; i < m; ++i) scanf("%d", &A[i]);

    ans = 0;
    if (n - 1 < m) {
        // Every vertex, including the would-be root n-1, is fixed, so each
        // edge's cost is known directly. This covers the single-edge case
        // n == m == 2, and n <= 1, where there are no edges.
        for (int a = 0; a < n; ++a)
            for (int b : V[a])
                if (a < b) ans += std::llabs(static_cast<ll>(A[a]) - A[b]);
    } else {
        dfs(n - 1, -1); // vertex n-1 is not fixed, so root the tree there
    }
    printf("%lld\n", ans);
}

int main() {
    // freopen("input.in", "r", stdin);
    // freopen("output.out", "w", stdout);
    int t = 1;
    // scanf("%d", &t);
    for (int i = 0; i < t; ++i) solve();
    return 0;
}