#include <bits/stdc++.h>
using namespace std;

// Short-hand macros used throughout this file.
#define fru(j, n) for (int j = 0; j < (n); ++j)                        // j = 0 .. n-1
#define tr(it, v) for (auto it = (v).begin(); it != (v).end(); ++it)  // walk container v
#define x first   // NOTE(review): these two rewrite EVERY later identifier x / y,
#define y second  //               so pair members are accessed as .x / .y
#define pb push_back
#define mp make_pair
#define ALL(G) (G).begin(), (G).end()

using ll = long long;
using D = double;
using pii = pair<int, int>;
using vi = vector<int>;

constexpr int inft = 1000000009;  // not referenced by the visible code
constexpr int MAXN = 1000006;     // capacity of the per-vertex arrays below

vi V[MAXN];   // adjacency lists (filled symmetrically by solve)
int A[MAXN];  // A[i] is the given value of vertex i, for i < m
int n, m;     // vertex count; vertices 0..m-1 have fixed values
ll ans;       // answer accumulated by dfs
/*
 * Computes the interval of optimal values for vertex `a`, whose subtree is
 * entered from neighbour `s` (pass -1 for the root). It also adds the minimum
 * total edge cost inside that subtree to the global `ans`.
 *
 * Vertices with index < m are fixed: their value is A[vertex], and they are
 * never expanded. Every other vertex v pays, for each child c, the cost
 *   dist(t, [lo_c, hi_c]) = (|t - lo_c| + |t - hi_c| - (hi_c - lo_c)) / 2,
 * where t is v's value and [lo_c, hi_c] is c's optimal interval. That sum is
 * minimised on the median interval of all the children's endpoints. The
 * median interval is returned to the parent, and the minimum is added to ans.
 *
 * The traversal is iterative: first BFS order, then children are processed
 * before their parents. This means a deep tree (e.g. a path of ~MAXN
 * vertices) cannot overflow the call stack, as the recursive version could.
 * Cost arithmetic is done in long long, so endpoint differences cannot
 * overflow int.
 *
 * @return [lo, hi] (as .x / .y): the values of `a` that minimise its subtree
 *         cost. A non-fixed vertex without children is unconstrained and
 *         yields [INT_MIN, INT_MAX] at zero cost. The old code indexed an
 *         empty vector in that case.
 */
pii dfs(int a, int s) {
    if (a < m) return pii(A[a], A[a]);

    // Phase 1: BFS over the subtree. The children of order[i] are appended in
    // one run, so they occupy positions [childStart[i], childStart[i + 1]).
    vi order(1, a), from(1, s), childStart;
    for (size_t i = 0; i < order.size(); ++i) {
        childStart.pb((int)order.size());
        int v = order[i];
        if (v < m) continue;  // fixed vertex: do not descend past it
        for (int w : V[v])
            if (w != from[i]) { order.pb(w); from.pb(v); }
    }
    childStart.pb((int)order.size());
    vi().swap(from);  // parents no longer needed; release before phase 2

    // Phase 2: reverse BFS order guarantees every child is finished before
    // its parent is combined.
    vector<pii> best(order.size());
    vi ends;  // scratch buffer reused for every vertex
    for (size_t i = order.size(); i-- > 0;) {
        int v = order[i];
        if (v < m) { best[i] = pii(A[v], A[v]); continue; }

        ends.clear();
        ll spread = 0;  // sum over children of (hi - lo)
        for (int j = childStart[i]; j < childStart[i + 1]; ++j) {
            ends.pb(best[j].x);
            ends.pb(best[j].y);
            spread += (ll)best[j].y - best[j].x;
        }
        if (ends.empty()) {  // unconstrained vertex: any value, zero cost
            best[i] = pii(INT_MIN, INT_MAX);
            continue;
        }

        sort(ALL(ends));
        int k = ends.size();
        pii u(ends[(k - 1) / 2], ends[k / 2]);  // median interval
        ll ret = 0;
        for (int p : ends) ret += llabs((ll)u.x - p);
        ans += (ret - spread) / 2;  // always even and non-negative
        best[i] = u;
    }
    return best[0];
}
void solve() {
scanf("%d%d",&n,&m);
fru(i,n-1){
int a,b;
scanf("%d%d",&a,&b);a--;b--;
V[a].pb(b);
V[b].pb(a);
}
fru(i,m)scanf("%d",&A[i]);
if(n==2 && m==2){
printf("%d\n",abs(A[1]-A[0]));return;
}
ans=0;
dfs(n-1,-1);
cout<<ans<<endl;
}
int main() {
// freopen("input.in", "r", stdin);
// freopen("output.out", "w", stdout);
int t=1;
// scanf("%d",&t);
fru(i,t) solve();
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | #include <bits/stdc++.h> using namespace std; #define fru(j,n) for(int j=0; j<(n); ++j) #define tr(it,v) for(auto it=(v).begin(); it!=(v).end(); ++it) #define x first #define y second #define pb push_back #define mp make_pair #define ALL(G) (G).begin(),(G).end() typedef long long ll; typedef double D; typedef pair<int,int> pii; typedef vector<int> vi; const int inft = 1000000009; const int MAXN = 1000006; vi V[MAXN]; int A[MAXN],n,m; ll ans; pii dfs(int a,int s){ if(a<m)return pii(A[a],A[a]); vi P; ll przedz=0; tr(it,V[a])if(*it!=s){ pii u=dfs(*it,a); P.pb(u.x);P.pb(u.y); przedz+=u.y-u.x; } int k=P.size(); pii u; sort(ALL(P)); u.x=P[(k-1)/2];u.y=P[k/2]; ll ret=0; tr(it,P)ret+=abs(u.x-*it); ret-=przedz; ret/=2; ans+=ret; return u; } void solve() { scanf("%d%d",&n,&m); fru(i,n-1){ int a,b; scanf("%d%d",&a,&b);a--;b--; V[a].pb(b); V[b].pb(a); } fru(i,m)scanf("%d",&A[i]); if(n==2 && m==2){ printf("%d\n",abs(A[1]-A[0]));return; } ans=0; dfs(n-1,-1); cout<<ans<<endl; } int main() { // freopen("input.in", "r", stdin); // freopen("output.out", "w", stdout); int t=1; // scanf("%d",&t); fru(i,t) solve(); return 0; } |
English