// Counts unordered triples of distinct contiguous subarrays whose sums add up
// to zero.
//   Input  (stdin):  n, followed by n integers.
//   Output (stdout): the number of such triples.
//
// Approach: enumerate all O(n^2) subarray sums, then count
//   * three zero sums:                          C(z, 3)
//   * one zero sum plus a pair {+v, -v}:        z * cnt(+v) * cnt(-v)
//   * two sums of one sign + one of the other:  via an FFT self-convolution of
//     the magnitude histogram of each sign.
// The histograms are sized by the largest relevant sum magnitude, so this
// assumes subarray sums stay small in absolute value.
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

typedef std::complex<double> base;

const double pi = std::acos(-1.0);

/// Returns C(x, 3) = x(x-1)(x-2)/6, the number of 3-element subsets of x items.
long long cubic(long long x) { return x * (x - 1) * (x - 2) / 6LL; }

/// Rounds the real part of an FFT output to the nearest integer.
long long llv(const base& x) { return std::llround(x.real()); }

/// In-place iterative radix-2 FFT. a.size() must be a power of two (or <= 1).
/// When inv is true the inverse transform is computed, including the 1/n scale.
void fft(std::vector<base>& a, bool inv) {
    const std::size_t n = a.size();
    if (n <= 1) return;

    // Bit-reversal permutation.
    for (std::size_t i = 1, j = 0; i < n; ++i) {
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    // Twiddle factors come straight from cos/sin instead of being accumulated
    // by repeated multiplication: the accumulated rounding error grows with
    // the stage length and can corrupt the rounded integer counts.
    const double sign = inv ? -1.0 : 1.0;
    std::vector<base> root(n / 2);
    for (std::size_t k = 0; k < n / 2; ++k) {
        const double ang =
            sign * 2.0 * pi * static_cast<double>(k) / static_cast<double>(n);
        root[k] = base(std::cos(ang), std::sin(ang));
    }

    for (std::size_t len = 2; len <= n; len <<= 1) {
        const std::size_t half = len / 2;
        const std::size_t step = n / len;  // root[j*step] == exp(+-2*pi*i*j/len)
        for (std::size_t i = 0; i < n; i += len) {
            for (std::size_t j = 0; j < half; ++j) {
                const base x = a[i + j];
                const base y = root[j * step] * a[i + j + half];
                a[i + j] = x + y;
                a[i + j + half] = x - y;
            }
        }
    }

    if (inv) {
        for (base& v : a) v /= static_cast<double>(n);
    }
}

/// Polynomial product. On return `a` holds the linear convolution of the
/// original `a` and `b`, zero-padded to a power-of-two length. Both vectors
/// are resized, and `b` is overwritten with its forward transform.
void mul(std::vector<base>& a, std::vector<base>& b) {
    std::size_t n = 1;
    while (n < a.size() || n < b.size()) n <<= 1;
    n <<= 1;  // room for the whole product, so no circular wrap-around
    a.resize(n);
    b.resize(n);
    fft(a, false);
    fft(b, false);
    for (std::size_t i = 0; i < n; ++i) a[i] *= b[i];
    fft(a, true);
}

/// hist[k] = number of subarrays (all of one sign) whose sum has magnitude k.
/// Returns p, where p[s] = number of unordered pairs of *distinct* such
/// subarrays whose magnitudes add up to s, for 0 <= s <= 2*(hist.size()-1).
std::vector<long long> unorderedPairCounts(const std::vector<long long>& hist) {
    if (hist.empty()) return {};

    std::vector<base> a(hist.size());
    for (std::size_t k = 0; k < hist.size(); ++k)
        a[k] = base(static_cast<double>(hist[k]), 0.0);
    std::vector<base> b = a;
    mul(a, b);  // a[s] = ordered pairs (i, j) with |v_i| + |v_j| = s, i == j included

    std::vector<long long> pairs(2 * hist.size() - 1);
    for (std::size_t s = 0; s < pairs.size(); ++s) pairs[s] = llv(a[s]);
    // Drop the i == j terms: every subarray of magnitude k adds one to index 2k.
    for (std::size_t k = 0; k < hist.size(); ++k) pairs[2 * k] -= hist[k];
    // (i, j) and (j, i) describe the same unordered pair.
    for (long long& p : pairs) p /= 2;
    return pairs;
}

/// Number of unordered triples of distinct contiguous subarrays of `values`
/// whose sums add up to zero.
long long countZeroSumTriples(const std::vector<long long>& values) {
    const std::size_t n = values.size();
    std::vector<long long> prefix(n + 1, 0);
    for (std::size_t i = 0; i < n; ++i) prefix[i + 1] = prefix[i] + values[i];

    // Collect every nonzero subarray sum; zero sums are only counted.
    std::vector<long long> sums;
    sums.reserve(n * (n + 1) / 2);
    long long zeros = 0, mn = 0, mx = 0;
    for (std::size_t r = 1; r <= n; ++r) {
        for (std::size_t l = 0; l < r; ++l) {
            const long long v = prefix[r] - prefix[l];
            if (v == 0) {
                ++zeros;
                continue;
            }
            sums.push_back(v);
            mn = std::min(mn, v);
            mx = std::max(mx, v);
        }
    }

    // A magnitude above min(-mn, mx) can never be matched by the opposite
    // sign (alone or as one of a pair), so the histograms stop there.
    const long long lim = std::min(-mn, mx);
    std::vector<long long> negCnt(static_cast<std::size_t>(lim) + 1, 0);
    std::vector<long long> posCnt(static_cast<std::size_t>(lim) + 1, 0);
    for (long long v : sums) {
        if (v < 0 && -v <= lim) ++negCnt[-v];
        else if (v > 0 && v <= lim) ++posCnt[v];
    }

    // Triples {0, 0, 0} and {0, +v, -v}.
    long long res = cubic(zeros);
    for (long long k = 1; k <= lim; ++k) res += zeros * posCnt[k] * negCnt[k];

    // Triples with two sums of one sign and one of the other. Pair sums reach
    // 2*lim, so lookups are bounded by the pair table, not by lim.
    const std::vector<long long> negPairs = unorderedPairCounts(negCnt);
    const std::vector<long long> posPairs = unorderedPairCounts(posCnt);
    for (long long v : sums) {
        const std::vector<long long>& pairs = v > 0 ? negPairs : posPairs;
        const long long mag = v > 0 ? v : -v;
        if (mag < static_cast<long long>(pairs.size())) res += pairs[mag];
    }
    return res;
}

int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1 || n < 0) n = 0;
    std::vector<long long> values(static_cast<std::size_t>(n), 0);
    for (long long& v : values) {
        int x = 0;
        if (std::scanf("%d", &x) == 1) v = x;  // unreadable entries stay 0
    }
    std::printf("%lld\n", countZeroSumTriples(values));
    return 0;
}