#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex table; vertices are indexed from 1
// (presumably n <= 500000 — confirm against the task limits).
const int MAXN = 500001;

vector<int> kraw[MAXN];      // adjacency lists of the tree
long long wagi[MAXN];        // fixed values of vertices 1..m (read in main)
int odw[MAXN];               // visited flags used by dfs
int ojc[MAXN];               // not referenced anywhere in this file
long long pr[MAXN][2];       // per vertex: [lowest, highest] minimising value found by dfs
long long wyn = 0;           // accumulated minimal total cost
void dfs(int x)
{
odw[x]=1;
long long sum=0;
long long no=0,za=0, po=0, ko=0, naj;
if(kraw[x].size()==1)
{
pr[x][0]=wagi[x];
pr[x][1]=wagi[x];
return;
}
vector<pair <long long , int> > tab;
for(vector<int> :: iterator it=kraw[x].begin(); it!=kraw[x].end(); it++)
if(odw[*it]==0)
{
dfs(*it);
sum+=pr[*it][0];
no++;
tab.push_back(make_pair(pr[*it][0], 1));
tab.push_back(make_pair(pr[*it][1], -1));
}
sort(tab.begin(), tab.end());
naj=sum;
long long ost=0;
for(int i=0; i<tab.size(); i++){
long long y=tab[i].first-ost;
//if(x==6)
// printf("? %lld %d\n", tab[i].first, tab[i].second);
ost=tab[i].first;
sum+=-no*y+za*y;
if(sum<naj)
{
naj=sum;
po=ost;
}
if(sum<=naj)
ko=ost;
if(tab[i].second == 1)
{
no--;
}
if(tab[i].second == -1)
{
za++;
}
//if(x==6)
// printf("%lld %lld %lld %lld\n", ost, sum, no, za);
}
wyn+=naj;
pr[x][0]=po;
pr[x][1]=ko;
return;
}
// Reads the tree and the fixed values of vertices 1..m, then prints the
// minimal total cost computed by dfs.
//
// Input: n m, then n-1 edges "x y", then m values.
int main()
{
    int n, m;
    if (scanf("%d%d", &n, &m) != 2)
        return 0;
    for (int i = 1; i < n; i++)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        kraw[x].push_back(y);
        kraw[y].push_back(x);
    }
    for (int i = 1; i <= m; i++)
        scanf("%lld", &wagi[i]);

    if (n == 2)
    {
        // Both vertices have degree 1, so the single edge costs |w1 - w2|.
        printf("%lld\n", abs(wagi[1] - wagi[2]));
        return 0;
    }

    // dfs treats every degree-1 vertex as a fixed leaf, so it has to start at a
    // vertex of degree != 1. Vertex n is preferred (the original choice). If it
    // happens to be a leaf, any vertex of degree >= 2 is used instead; one
    // always exists when n >= 3.
    int root = n;
    if (kraw[root].size() == 1)
    {
        for (int v = 1; v <= n; v++)
        {
            if (kraw[v].size() >= 2)
            {
                root = v;
                break;
            }
        }
    }
    dfs(root);
    printf("%lld\n", wyn);
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | #include<bits/stdc++.h> using namespace std; vector<int> kraw[500001]; long long wagi[500001]; int odw[500001]; int ojc[500001]; long long pr[500001][2]; long long wyn=0; void dfs(int x) { odw[x]=1; long long sum=0; long long no=0,za=0, po=0, ko=0, naj; if(kraw[x].size()==1) { pr[x][0]=wagi[x]; pr[x][1]=wagi[x]; return; } vector<pair <long long , int> > tab; for(vector<int> :: iterator it=kraw[x].begin(); it!=kraw[x].end(); it++) if(odw[*it]==0) { dfs(*it); sum+=pr[*it][0]; no++; tab.push_back(make_pair(pr[*it][0], 1)); tab.push_back(make_pair(pr[*it][1], -1)); } sort(tab.begin(), tab.end()); naj=sum; long long ost=0; for(int i=0; i<tab.size(); i++){ long long y=tab[i].first-ost; //if(x==6) // printf("? %lld %d\n", tab[i].first, tab[i].second); ost=tab[i].first; sum+=-no*y+za*y; if(sum<naj) { naj=sum; po=ost; } if(sum<=naj) ko=ost; if(tab[i].second == 1) { no--; } if(tab[i].second == -1) { za++; } //if(x==6) // printf("%lld %lld %lld %lld\n", ost, sum, no, za); } wyn+=naj; pr[x][0]=po; pr[x][1]=ko; return; } int main() { int n,m; scanf("%d%d", &n, &m); for(int i=1; i<n; i++) { int x,y; scanf("%d%d", &x, &y); kraw[x].push_back(y); kraw[y].push_back(x); } for(int i=1; i<=m; i++) scanf("%lld", &wagi[i]); if(n==2) { if(wagi[1]<wagi[2]) swap(wagi[1], wagi[2]); printf("%lld\n", wagi[1]-wagi[2]); return 0; } dfs(n); printf("%lld\n", wyn); } |
polski