// Tree "fill in the internal values" solver.
//
// Input:  n m, then n-1 edges (a b), then m values r[1..m].
// Nodes 1..m are labelled with fixed values; all other nodes are free.
// The program roots the tree at node m+1 and, bottom-up, gives every free node
// the interval of values that minimises the sum of distances to its children's
// intervals, accumulating those minimum distances into the printed answer.
// NOTE(review): this is the standard greedy for "minimise the sum over edges of
// |value(u) - value(v)| when leaves 1..m are fixed" -- inferred from the
// algorithm, confirm against the task statement.

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <utility>
#include <vector>

using ll = long long;

namespace {

// Closed interval [lo, hi] of optimal values for a subtree's root.
struct Interval {
    ll lo;
    ll hi;
};

// Result of placing a single point x against a set of intervals.
struct Placement {
    ll cost;        // min over x of sum_i dist(x, interval_i)
    Interval best;  // every x in [best.lo, best.hi] attains that minimum
};

// Minimises f(x) = sum_i dist(x, [lo_i, hi_i]) over a NON-EMPTY set of intervals.
// f is convex and piecewise linear with breakpoints only at interval endpoints,
// so the minimum is attained at an endpoint, and the minimisers form the range
// between the first and the last minimising endpoint.
Placement place_point(const std::vector<Interval>& ivs) {
    std::vector<ll> pts;
    pts.reserve(2 * ivs.size());
    for (const Interval& iv : ivs) {
        pts.push_back(iv.lo);
        pts.push_back(iv.hi);
    }
    std::sort(pts.begin(), pts.end());
    pts.erase(std::unique(pts.begin(), pts.end()), pts.end());
    const std::size_t p = pts.size();

    auto index_of = [&pts](ll x) {
        return static_cast<std::size_t>(
            std::lower_bound(pts.begin(), pts.end(), x) - pts.begin());
    };

    // After the prefix/suffix sums below:
    //   ended[i]   = number of intervals with hi <= pts[i]
    //   started[i] = number of intervals with lo >= pts[i]
    std::vector<ll> ended(p, 0);
    std::vector<ll> started(p, 0);
    for (const Interval& iv : ivs) {
        ++ended[index_of(iv.hi)];
        ++started[index_of(iv.lo)];
    }
    for (std::size_t i = 1; i < p; ++i) ended[i] += ended[i - 1];
    for (std::size_t i = p - 1; i-- > 0;) started[i] += started[i + 1];

    // Sweep x across the breakpoints, maintaining (all in 64-bit):
    //   left  = sum over intervals with hi < x of (x - hi)
    //   right = sum over intervals with lo > x of (lo - x)
    std::vector<ll> cost(p);
    ll left = 0;
    ll right = 0;
    for (const Interval& iv : ivs) right += iv.lo - pts[0];
    for (std::size_t i = 0; i < p; ++i) {
        cost[i] = left + right;
        if (i + 1 < p) {
            const ll step = pts[i + 1] - pts[i];
            left += ended[i] * step;         // intervals already fully left of x
            right -= started[i + 1] * step;  // intervals still fully right of x
        }
    }

    const auto first = std::min_element(cost.begin(), cost.end());
    const ll best = *first;
    std::size_t last = p - 1;
    while (cost[last] != best) --last;  // terminates: cost[first] == best
    const std::size_t lo_idx = static_cast<std::size_t>(first - cost.begin());
    return Placement{best, Interval{pts[lo_idx], pts[last]}};
}

// Returns the total cost for the tree described by `adj` (1-based, size n+1),
// where nodes 1..m carry the fixed values val[1..m].
//
// If n <= m there is no free node, and the answer is the sum of the fixed edge
// costs. Otherwise the tree is rooted at m+1 and processed bottom-up
// (iteratively, so path-shaped trees with ~5e5 nodes do not overflow the call
// stack). Labelled nodes are never descended into. A subtree containing no
// labelled node imposes no constraint on its parent and is skipped.
ll solve(int n, int m, const std::vector<std::vector<int>>& adj,
         const std::vector<ll>& val) {
    if (n <= m) {
        ll total = 0;
        for (int a = 1; a <= n; ++a)
            for (int b : adj[a])
                if (a < b) total += std::llabs(val[a] - val[b]);
        return total;
    }

    // Pre-order walk from the root. Children always appear after their parent
    // in `order`, so walking it backwards visits children first.
    const int root = m + 1;
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;
    order.reserve(static_cast<std::size_t>(n));
    std::vector<int> todo{root};
    while (!todo.empty()) {
        const int a = todo.back();
        todo.pop_back();
        order.push_back(a);
        if (a <= m) continue;  // labelled node: do not descend past it
        for (int b : adj[a]) {
            if (b != parent[a]) {
                parent[b] = a;
                todo.push_back(b);
            }
        }
    }

    std::vector<Interval> iv(n + 1, Interval{0, 0});
    std::vector<char> constrained(n + 1, 0);  // subtree contains a labelled node
    std::vector<Interval> kids;               // reused buffer
    ll total = 0;
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int a = *it;
        if (a <= m) {
            iv[a] = Interval{val[a], val[a]};
            constrained[a] = 1;
            continue;
        }
        kids.clear();
        for (int b : adj[a])
            if (b != parent[a] && constrained[b]) kids.push_back(iv[b]);
        if (kids.empty()) continue;  // free subtree: can copy the parent's value
        const Placement pl = place_point(kids);
        total += pl.cost;
        iv[a] = pl.best;
        constrained[a] = 1;
    }
    return total;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2) return 1;

    std::vector<std::vector<int>> adj(static_cast<std::size_t>(n) + 1);
    for (int i = 0; i < n - 1; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2) return 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    std::vector<ll> val(static_cast<std::size_t>(m) + 1, 0);
    for (int i = 1; i <= m; ++i) {
        if (std::scanf("%lld", &val[i]) != 1) return 1;
    }

    std::printf("%lld\n", solve(n, m, adj, val));
    return 0;
}