#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used by the rest of this file.
// NOTE(review): the object-like macros below (pb, mp, st, nd, endl, vc, pii,
// pll, pulul) replace every matching token for the rest of the translation
// unit, so none of these names can be used as ordinary identifiers afterwards
// (e.g. `std::endl` would expand to `std::'\n'` and fail to compile).

// Integer type aliases.
typedef long long ll;
typedef unsigned long long ull;
// Forward / reverse iterator range of a container, e.g. sort(all(v)).
// NOTE(review): the argument is not parenthesized; pass a plain name, not an
// expression such as *p.
#define all(a) a.begin(),a.end()
#define rall(a) a.rbegin(),a.rend()
// Container size cast to a signed int.
#define sz(a) (int)a.size()
// Short names for push_back / make_pair and for pair members first / second.
#define pb push_back
#define mp make_pair
#define st first
#define nd second
// Every `endl` becomes a plain '\n' character, so output is not flushed per line.
#define endl '\n'
// Fast I/O: stop syncing C++ streams with C stdio and untie cin from cout.
// Expands to two statements; use it as a standalone statement (`fast;`).
#define fast ios_base::sync_with_stdio(0);cin.tie(0);
// Short names for vector and common pair types.
#define vc vector
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pulul pair<ull,ull>
// Computes the largest per-level weight total of a layered forest.
//
// Levels are numbered 1..k, and nodeCount[i] is the number of nodes on level i
// (nodes are 1-based). For i >= 2, parentOf[i][j] is the 1-based index, on
// level i-1, of node j's parent.
//
// Weights:
//   - Every node on level k weighs 1.
//   - A node on a higher level weighs the sum of its children's weights, but
//     never less than 1, so a childless node still counts once.
//
// Parent values outside 1..nodeCount[i-1] link to nothing. This covers the
// "<= 0" values the input uses and guards against out-of-range indices, which
// would otherwise write out of bounds.
//
// Preconditions (guaranteed by solve):
//   - nodeCount.size() >= 2, i.e. k >= 1.
//   - parentOf[i].size() == nodeCount[i] + 1 for every i >= 2.
//
// Returns the maximum, over all levels, of the sum of that level's weights.
static long long maxLevelWeight(const std::vector<int>& nodeCount,
                                const std::vector<std::vector<int>>& parentOf)
{
    const int k = static_cast<int>(nodeCount.size()) - 1;

    // Weights of the level currently being folded upward; index 0 is unused.
    std::vector<long long> weight(nodeCount[k] + 1, 1);
    long long best = nodeCount[k];  // level k's total: each node weighs 1

    for (int level = k - 1; level >= 1; --level)
    {
        const int width = nodeCount[level];

        // Accumulate each child's weight into its parent on this level.
        std::vector<long long> childSum(width + 1, 0);
        for (int j = 1; j <= nodeCount[level + 1]; ++j)
        {
            const int p = parentOf[level + 1][j];
            if (p >= 1 && p <= width)  // skip "no parent" and bad indices
                childSum[p] += weight[j];
        }

        // A node with no contributing children still weighs 1.
        long long levelTotal = 0;
        for (int j = 1; j <= width; ++j)
        {
            childSum[j] = std::max(1LL, childSum[j]);
            levelTotal += childSum[j];
        }

        best = std::max(best, levelTotal);
        weight.swap(childSum);  // this level becomes the child level of the next
    }
    return best;
}

// Reads one test case from std::cin and prints maxLevelWeight's answer
// followed by a newline to std::cout.
//
// Input (whitespace separated):
//   k                        number of levels
//   n_1                      node count of level 1
//   then, for i = 2..k:
//     n_i p_1 ... p_{n_i}    node count of level i, then each node's parent
//                            index on level i-1
//
// Malformed input handling:
//   - k < 1 prints 0.
//   - Negative node counts are treated as 0.
//   - Out-of-range parent indices are ignored (see maxLevelWeight).
void solve()
{
    int k = 0;
    std::cin >> k;
    if (k < 1)
    {
        std::cout << 0 << '\n';
        return;
    }

    std::vector<int> nodeCount(k + 1, 0);
    std::vector<std::vector<int>> parentOf(k + 1);  // parentOf[1] stays empty

    int cnt = 0;
    std::cin >> cnt;
    nodeCount[1] = std::max(0, cnt);

    for (int i = 2; i <= k; ++i)
    {
        cnt = 0;
        std::cin >> cnt;
        nodeCount[i] = std::max(0, cnt);
        parentOf[i].assign(nodeCount[i] + 1, 0);
        for (int j = 1; j <= nodeCount[i]; ++j)
            std::cin >> parentOf[i][j];
    }

    std::cout << maxLevelWeight(nodeCount, parentOf) << '\n';
}
// Entry point: switches iostreams to unsynchronized, untied mode for speed,
// then runs solve() once. Multi-test input is currently disabled; the test
// count is fixed at 1.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int tests = 1;
    // To read the number of test cases from input, uncomment:
    // cin >> tests;
    while (tests-- > 0)
        solve();

    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | #include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; #define all(a) a.begin(),a.end() #define rall(a) a.rbegin(),a.rend() #define sz(a) (int)a.size() #define pb push_back #define mp make_pair #define st first #define nd second #define endl '\n' #define fast ios_base::sync_with_stdio(0);cin.tie(0); #define vc vector #define pii pair<int,int> #define pll pair<ll,ll> #define pulul pair<ull,ull> void solve() { int k; cin >> k; vc<int> n(k+1); cin >> n[1]; vc<vc<int>> par(k+1); for(int i=2; i<=k; ++i) { cin >> n[i]; par[i].assign(n[i]+1,0); for(int j=1; j<=n[i]; ++j) cin >> par[i][j]; } vc<ll> dp(n[k]+1,1); ll ans=n[k]; for(int i=k-1; i>=1; --i) { vc<ll> sum(n[i]+1,0); for(int j=1; j<=n[i+1]; ++j) { if(par[i+1][j]>0) sum[par[i+1][j]]+=dp[j]; } vc<ll> ndp(n[i]+1,0); ll cur=0; for(int j=1; j<=n[i]; ++j) { ndp[j]=max(1LL, sum[j]); cur+=ndp[j]; } ans=max(ans,cur); dp.swap(ndp); } cout << ans; } int main() { fast; int t=1; //cin >> t; while(t--) solve(); return 0; } |
English