#include <bits/stdc++.h> using namespace std; typedef pair<int,int> pii; const int MX=200200; const long long inf=-2100100100100100100LL; int n,i,x,y,z,cnt[MX],dep[MX],mid[MX]; vector<pii> g[MX]; vector<long long> dp[MX][2]; long long f[MX][4]; void upd(long long& f, long long v) { f=max(f,v); } void dfs(int i, int p) { dep[i]=1; for (int j=0; j<g[i].size(); j++) { int k=g[i][j].first; if (k==p) continue; dfs(k,i); dep[i]=max(dep[i],dep[k]+1); if (dep[k]>=3) ++cnt[i]; } } void solve(int i, int p, int prw) { int prv=0; int rst=g[i].size(); if (p>0) --rst; for (int j=0; j<g[i].size(); j++) { int k=g[i][j].first; if (k==p) continue; int w=g[i][j].second; solve(k,i,w); --rst; int lft=-mid[prv]; int rgh=int(dp[prv][0].size())-1-mid[prv]; int nlft=max(lft-int(dep[k]>=3),-rst-1); int nrgh=min(rgh+1,rst+1); mid[k]=-nlft; dp[k][0].assign(nrgh-nlft+1,inf); dp[k][1].assign(nrgh-nlft+1,inf); for (int c=0; c<2; ++c) { auto it=dp[prv][c].begin(); for (int d=lft; d<=rgh; ++d, ++it) { if ((*it)==inf) continue; if (d>=nlft && d<=nrgh) { upd(dp[k][c][-nlft+d],(*it)+f[k][0]); if (f[k][2]!=inf) upd(dp[k][c^1][-nlft+d],(*it)+f[k][2]); } if (d<nrgh) upd(dp[k][c][-nlft+d+1],(*it)+f[k][1]); if (d>nlft && f[k][3]!=inf) upd(dp[k][c][-nlft+d-1],(*it)+f[k][3]); } } prv=k; } f[i][0]=f[i][1]=dp[prv][0][mid[prv]]; if (mid[prv]>=1 && dp[prv][0][mid[prv]-1]!=inf && p>0) f[i][0]=max(f[i][0],dp[prv][0][mid[prv]-1]+prw); if (mid[prv]+1<dp[prv][0].size()) f[i][2]=dp[prv][0][mid[prv]+1]; else f[i][2]=inf; f[i][3]=dp[prv][1][mid[prv]]; if (p>0) for (int j=1; j<=3; j++) if (f[i][j]!=inf) f[i][j]+=prw; //printf("%d:\n",i); //for (int j=0; j<=3; j++) printf("f %d %d = %lld\n",i,j,f[i][j]); } bool cmp(const pii& x, const pii& y) { return dep[x.first]<dep[y.first]; } int main() { scanf("%d",&n); for (i=1; i<n; i++) { scanf("%d%d%d",&x,&y,&z); g[x].emplace_back(y,z); g[y].emplace_back(x,z); } for (x=i=1; i<n; i++) if (g[i].size()>g[x].size()) x=i; dfs(x,0); for (i=1; i<=n; i++) sort(g[i].begin(),g[i].end(),cmp); mid[0]=0; dp[0][0].assign(1,0); dp[0][1].assign(1,inf); solve(x,0,0); printf("%lld\n",f[x][0]); return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | #include <bits/stdc++.h> using namespace std; typedef pair<int,int> pii; const int MX=200200; const long long inf=-2100100100100100100LL; int n,i,x,y,z,cnt[MX],dep[MX],mid[MX]; vector<pii> g[MX]; vector<long long> dp[MX][2]; long long f[MX][4]; void upd(long long& f, long long v) { f=max(f,v); } void dfs(int i, int p) { dep[i]=1; for (int j=0; j<g[i].size(); j++) { int k=g[i][j].first; if (k==p) continue; dfs(k,i); dep[i]=max(dep[i],dep[k]+1); if (dep[k]>=3) ++cnt[i]; } } void solve(int i, int p, int prw) { int prv=0; int rst=g[i].size(); if (p>0) --rst; for (int j=0; j<g[i].size(); j++) { int k=g[i][j].first; if (k==p) continue; int w=g[i][j].second; solve(k,i,w); --rst; int lft=-mid[prv]; int rgh=int(dp[prv][0].size())-1-mid[prv]; int nlft=max(lft-int(dep[k]>=3),-rst-1); int nrgh=min(rgh+1,rst+1); mid[k]=-nlft; dp[k][0].assign(nrgh-nlft+1,inf); dp[k][1].assign(nrgh-nlft+1,inf); for (int c=0; c<2; ++c) { auto it=dp[prv][c].begin(); for (int d=lft; d<=rgh; ++d, ++it) { if ((*it)==inf) continue; if (d>=nlft && d<=nrgh) { upd(dp[k][c][-nlft+d],(*it)+f[k][0]); if (f[k][2]!=inf) upd(dp[k][c^1][-nlft+d],(*it)+f[k][2]); } if (d<nrgh) upd(dp[k][c][-nlft+d+1],(*it)+f[k][1]); if (d>nlft && f[k][3]!=inf) upd(dp[k][c][-nlft+d-1],(*it)+f[k][3]); } } prv=k; } f[i][0]=f[i][1]=dp[prv][0][mid[prv]]; if (mid[prv]>=1 && dp[prv][0][mid[prv]-1]!=inf && p>0) f[i][0]=max(f[i][0],dp[prv][0][mid[prv]-1]+prw); if (mid[prv]+1<dp[prv][0].size()) f[i][2]=dp[prv][0][mid[prv]+1]; else f[i][2]=inf; f[i][3]=dp[prv][1][mid[prv]]; if (p>0) for (int j=1; j<=3; j++) if (f[i][j]!=inf) f[i][j]+=prw; //printf("%d:\n",i); //for (int j=0; j<=3; j++) printf("f %d %d = %lld\n",i,j,f[i][j]); } bool cmp(const pii& x, const pii& y) { return dep[x.first]<dep[y.first]; } int main() { scanf("%d",&n); for (i=1; i<n; i++) { scanf("%d%d%d",&x,&y,&z); g[x].emplace_back(y,z); g[y].emplace_back(x,z); } for (x=i=1; i<n; i++) if (g[i].size()>g[x].size()) x=i; dfs(x,0); for (i=1; i<=n; i++) sort(g[i].begin(),g[i].end(),cmp); mid[0]=0; dp[0][0].assign(1,0); dp[0][1].assign(1,inf); solve(x,0,0); printf("%lld\n",f[x][0]); return 0; } |