// Reads n and then n integers from stdin, and prints one number: the value
// returned by count_splits() for that sequence (see below).

#include <iostream>
#include <memory>
#include <vector>

using lld = long long;

// Modulus applied to every count and to the running prefix sum.
// It is odd, which is what makes a wrap-around of the prefix sum flip parity.
constexpr lld mod = 1000000007;

// Largest key a segment tree can hold; keys are prefix sums in [0, mod - 1].
constexpr int max_key = static_cast<int>(mod - 1);

// Sparse segment tree over the key range [0, max_key].
//
// Supports "add val at key x" and "sum of the values at keys in [f, l]", both
// modulo `mod`. Nodes are allocated lazily, only along the root-to-leaf paths
// that insert() visits. Every node is owned through std::unique_ptr, so the
// whole tree (root included) is released when the ST is destroyed, and an ST
// cannot be copied by accident.
struct ST {
    // Adds val to the value stored at key x (modulo `mod`).
    // val is expected to be in [0, mod). Keys outside [0, max_key] are ignored
    // and allocate nothing.
    void insert(int x, int val) {
        if (x < 0 || x > max_key) return;
        point_add(x, val, root, 0, max_key);
    }

    // Returns the sum, modulo `mod`, of the values at keys in [f, l].
    // An empty range (l < f) yields 0.
    int query(int f, int l) const {
        if (l < f) return 0;
        return range_sum(f, l, root.get(), 0, max_key);
    }

private:
    struct Node {
        std::unique_ptr<Node> left;
        std::unique_ptr<Node> right;
        int value = 0;  // sum of all leaves below this node, modulo `mod`

        // Recomputes value from the children; a missing child counts as 0.
        void update() {
            const lld from_left = left ? left->value : 0;
            const lld from_right = right ? right->value : 0;
            value = static_cast<int>((from_left + from_right) % mod);
        }
    };

    std::unique_ptr<Node> root;

    // Adds val at key x inside the subtree `v`, which covers keys [b, e].
    // Precondition: b <= x <= e. Creates `v` if it does not exist yet.
    static void point_add(int x, int val, std::unique_ptr<Node>& v, int b, int e) {
        if (!v) v = std::make_unique<Node>();
        if (b == e) {
            v->value = static_cast<int>((static_cast<lld>(v->value) + val) % mod);
            return;
        }
        const int m = b + (e - b) / 2;
        // Only the half that contains x can change, so descend into that one only.
        if (x <= m) {
            point_add(x, val, v->left, b, m);
        } else {
            point_add(x, val, v->right, m + 1, e);
        }
        v->update();
    }

    // Sum, modulo `mod`, of the keys in [f, l] that lie inside the subtree `v`
    // covering [b, e]. A null subtree holds no values and contributes 0.
    static int range_sum(int f, int l, const Node* v, int b, int e) {
        if (v == nullptr || l < b || e < f) return 0;
        if (f <= b && e <= l) return v->value;
        const int m = b + (e - b) / 2;
        const lld total = static_cast<lld>(range_sum(f, l, v->left.get(), b, m)) +
                          range_sum(f, l, v->right.get(), m + 1, e);
        return static_cast<int>(total % mod);
    }
};

// Walks t left to right while maintaining the prefix sum reduced into [0, mod).
//
// For each position it computes a count as the sum of the counts of all earlier
// prefixes p (the empty prefix is seeded with count 1) for which the difference
// (current - p), reduced into [0, mod), is even:
//   * p <= current: the difference is current - p, even iff the parities match;
//   * p >  current: the difference is current - p + mod, and since mod is odd
//     it is even iff the parities differ.
// Prefix counts are kept in two trees keyed by prefix sum, one per parity.
// Equivalently, the result is the number of ways to cut t into consecutive
// segments whose sums are all even after reduction modulo `mod`.
//
// Returns the count for the last position modulo `mod`, or 0 when t is empty.
int count_splits(const std::vector<int>& t) {
    ST odd;
    ST even;
    even.insert(0, 1);  // the empty prefix has sum 0 and exactly one way to reach it

    int last = 0;
    int sum = 0;
    for (const int x : t) {
        // Add in 64 bits so a large element cannot overflow int, then bring the
        // result back into [0, mod) so it is always a valid tree key.
        lld next = (static_cast<lld>(sum) + x) % mod;
        if (next < 0) next += mod;
        sum = static_cast<int>(next);

        const bool sum_is_even = (sum % 2 == 0);
        ST& same_parity = sum_is_even ? even : odd;
        ST& other_parity = sum_is_even ? odd : even;

        const lld ways = static_cast<lld>(same_parity.query(0, sum)) +
                         other_parity.query(sum + 1, max_key);
        last = static_cast<int>(ways % mod);

        same_parity.insert(sum, last);
    }
    return last;
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    if (n < 0) n = 0;  // a negative count must not become a huge vector size

    std::vector<int> t(n);
    for (int& x : t) std::cin >> x;

    std::cout << count_splits(t) << "\n";
    return 0;
}