#include <algorithm>
#include <deque>
#include <iostream>
#include <vector>

// Reads n and a sequence tab[0..n-1], runs a DP over "logical values" modulo
// MOD stored in a sparse segment tree, and prints a count modulo MOD.
//
// Technique: every DP entry lives at a tree position p and stands for the
// logical value (p - shiftBy) mod MOD.  Adding tab[i] to *all* stored values
// is done in O(1) by decreasing the global offset shiftBy instead of touching
// the tree.  Each leaf keeps its count in firstValue (even p) or secondValue
// (odd p), which lets queries sum entries with an even logical value.
//
// NOTE(review): the DP appears to count the ways to cut the sequence into
// contiguous segments whose sums, taken modulo MOD, are even — confirm
// against the original problem statement.

using ll = long long;

constexpr ll MOD = 1000000007LL;
// Tree positions span [0, MAX); MAX must exceed MOD so every residue has a leaf.
constexpr ll MAX = static_cast<ll>(1) << 30;
static_assert(MAX > MOD, "segment tree must cover every residue modulo MOD");

struct Node {
    ll firstValue = 0;   // sum (mod MOD) of counts at even positions in this subtree
    ll secondValue = 0;  // sum (mod MOD) of counts at odd positions in this subtree
    Node* l = nullptr;   // child covering [lo, mid), or nullptr if still empty
    Node* r = nullptr;   // child covering [mid, hi), or nullptr if still empty
};

// Owns every tree node. std::deque::emplace_back never relocates existing
// elements, so the Node* handed out by newNode() stay valid until exit.
std::deque<Node> nodePool;

// Allocates a zero-initialised node from the pool.
Node* newNode() {
    nodePool.emplace_back();
    return &nodePool.back();
}

Node* root;

// Even-position sum of a (possibly missing) subtree.
ll getFirstValue(Node* node) {
    return node == nullptr ? 0 : node->firstValue;
}

// Odd-position sum of a (possibly missing) subtree.
ll getSecondValue(Node* node) {
    return node == nullptr ? 0 : node->secondValue;
}

// Left child of `node`, created on first use.
Node* leftChild(Node* node) {
    if (node->l == nullptr) node->l = newNode();
    return node->l;
}

// Right child of `node`, created on first use.
Node* rightChild(Node* node) {
    if (node->r == nullptr) node->r = newNode();
    return node->r;
}

// Sum (mod MOD) of the counts stored at positions [a, b) of the subtree that
// covers [lo, hi).  `secondValue` selects odd positions, otherwise even ones.
// A missing subtree is empty, so it contributes 0 and no node is allocated.
ll query(ll a, ll b, Node* curNode, ll lo, ll hi, bool secondValue) {
    if (curNode == nullptr || b <= a) return 0;
    if (a == lo && b == hi) {
        return secondValue ? curNode->secondValue : curNode->firstValue;
    }
    const ll mid = (lo + hi) / 2;
    const ll x = query(a, std::min(mid, b), curNode->l, lo, mid, secondValue);
    const ll y = query(std::max(a, mid), b, curNode->r, mid, hi, secondValue);
    return (x + y) % MOD;
}

// Adds `val` (mod MOD) at position x of the subtree that covers [lo, hi).  The
// leaf's even or odd slot is chosen by x's parity; if x is outside the root's
// range nothing is stored.
void update(ll x, ll val, Node* curNode, ll lo, ll hi) {
    if (lo + 1 == hi) {
        ll& slot = (x % 2 == 1) ? curNode->secondValue : curNode->firstValue;
        slot = (slot + val) % MOD;
        return;
    }
    const ll mid = (lo + hi) / 2;
    if (x >= lo && x < mid) {
        update(x, val, leftChild(curNode), lo, mid);
    } else if (x >= mid && x < hi) {
        update(x, val, rightChild(curNode), mid, hi);
    }
    curNode->firstValue = (getFirstValue(curNode->l) + getFirstValue(curNode->r)) % MOD;
    curNode->secondValue = (getSecondValue(curNode->l) + getSecondValue(curNode->r)) % MOD;
}

// Representative of v in [0, MOD).
ll normalize(ll v) {
    v %= MOD;
    return v < 0 ? v + MOD : v;
}

// Total count (mod MOD) of entries whose logical value (p - shiftBy) mod MOD
// is even.  Because MOD is odd, positions p >= shiftBy keep the parity of
// p - shiftBy, while positions p < shiftBy wrap around by +MOD and flip it.
ll countEvenValues(ll shiftBy) {
    const bool oddShift = shiftBy % 2 != 0;
    const ll noWrap = query(shiftBy, MOD, root, 0, MAX, oddShift);
    const ll wrapped = query(0, shiftBy, root, 0, MAX, !oddShift);
    return (noWrap + wrapped) % MOD;
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    root = newNode();

    int n = 0;
    if (!(std::cin >> n) || n < 0) return 1;  // malformed input
    std::vector<ll> tab(n);
    for (ll& value : tab) std::cin >> value;

    ll shiftBy = 0;  // always kept in [0, MOD)
    for (int i = 0; i < n; ++i) {
        if (i == 0) {
            // Seed a single entry with logical value tab[0], reduced into
            // [0, MOD) like every later insertion so it is visible to queries.
            update(normalize(tab[0]), 1, root, 0, MAX);
            continue;
        }
        // Entries with an even logical value before tab[i] is absorbed.
        const ll ways = countEvenValues(shiftBy);
        // Absorb tab[i]: every stored logical value grows by tab[i].
        const ll start = shiftBy;
        shiftBy = normalize(shiftBy - tab[i]);
        // New entry at the old offset: its logical value is
        // (start - shiftBy) mod MOD == tab[i] mod MOD.
        update(start, ways, root, 0, MAX);
    }

    std::cout << countEvenValues(shiftBy) << "\n";
}
// NOTE(review): SOURCE contains this same program twice; this copy originally
// carried a code-viewer line-number gutter ("1 2 3 ... 107 |"), which is not
// valid C++.  Only one copy should be kept in the real file.

#include <algorithm>
#include <deque>
#include <iostream>
#include <vector>

// Reads n and a sequence tab[0..n-1], runs a DP over "logical values" modulo
// MOD stored in a sparse segment tree, and prints a count modulo MOD.
//
// Technique: every DP entry lives at a tree position p and stands for the
// logical value (p - shiftBy) mod MOD.  Adding tab[i] to *all* stored values
// is done in O(1) by decreasing the global offset shiftBy instead of touching
// the tree.  Each leaf keeps its count in firstValue (even p) or secondValue
// (odd p), which lets queries sum entries with an even logical value.
//
// NOTE(review): the DP appears to count the ways to cut the sequence into
// contiguous segments whose sums, taken modulo MOD, are even — confirm
// against the original problem statement.

using ll = long long;

constexpr ll MOD = 1000000007LL;
// Tree positions span [0, MAX); MAX must exceed MOD so every residue has a leaf.
constexpr ll MAX = static_cast<ll>(1) << 30;
static_assert(MAX > MOD, "segment tree must cover every residue modulo MOD");

struct Node {
    ll firstValue = 0;   // sum (mod MOD) of counts at even positions in this subtree
    ll secondValue = 0;  // sum (mod MOD) of counts at odd positions in this subtree
    Node* l = nullptr;   // child covering [lo, mid), or nullptr if still empty
    Node* r = nullptr;   // child covering [mid, hi), or nullptr if still empty
};

// Owns every tree node. std::deque::emplace_back never relocates existing
// elements, so the Node* handed out by newNode() stay valid until exit.
std::deque<Node> nodePool;

// Allocates a zero-initialised node from the pool.
Node* newNode() {
    nodePool.emplace_back();
    return &nodePool.back();
}

Node* root;

// Even-position sum of a (possibly missing) subtree.
ll getFirstValue(Node* node) {
    return node == nullptr ? 0 : node->firstValue;
}

// Odd-position sum of a (possibly missing) subtree.
ll getSecondValue(Node* node) {
    return node == nullptr ? 0 : node->secondValue;
}

// Left child of `node`, created on first use.
Node* leftChild(Node* node) {
    if (node->l == nullptr) node->l = newNode();
    return node->l;
}

// Right child of `node`, created on first use.
Node* rightChild(Node* node) {
    if (node->r == nullptr) node->r = newNode();
    return node->r;
}

// Sum (mod MOD) of the counts stored at positions [a, b) of the subtree that
// covers [lo, hi).  `secondValue` selects odd positions, otherwise even ones.
// A missing subtree is empty, so it contributes 0 and no node is allocated.
ll query(ll a, ll b, Node* curNode, ll lo, ll hi, bool secondValue) {
    if (curNode == nullptr || b <= a) return 0;
    if (a == lo && b == hi) {
        return secondValue ? curNode->secondValue : curNode->firstValue;
    }
    const ll mid = (lo + hi) / 2;
    const ll x = query(a, std::min(mid, b), curNode->l, lo, mid, secondValue);
    const ll y = query(std::max(a, mid), b, curNode->r, mid, hi, secondValue);
    return (x + y) % MOD;
}

// Adds `val` (mod MOD) at position x of the subtree that covers [lo, hi).  The
// leaf's even or odd slot is chosen by x's parity; if x is outside the root's
// range nothing is stored.
void update(ll x, ll val, Node* curNode, ll lo, ll hi) {
    if (lo + 1 == hi) {
        ll& slot = (x % 2 == 1) ? curNode->secondValue : curNode->firstValue;
        slot = (slot + val) % MOD;
        return;
    }
    const ll mid = (lo + hi) / 2;
    if (x >= lo && x < mid) {
        update(x, val, leftChild(curNode), lo, mid);
    } else if (x >= mid && x < hi) {
        update(x, val, rightChild(curNode), mid, hi);
    }
    curNode->firstValue = (getFirstValue(curNode->l) + getFirstValue(curNode->r)) % MOD;
    curNode->secondValue = (getSecondValue(curNode->l) + getSecondValue(curNode->r)) % MOD;
}

// Representative of v in [0, MOD).
ll normalize(ll v) {
    v %= MOD;
    return v < 0 ? v + MOD : v;
}

// Total count (mod MOD) of entries whose logical value (p - shiftBy) mod MOD
// is even.  Because MOD is odd, positions p >= shiftBy keep the parity of
// p - shiftBy, while positions p < shiftBy wrap around by +MOD and flip it.
ll countEvenValues(ll shiftBy) {
    const bool oddShift = shiftBy % 2 != 0;
    const ll noWrap = query(shiftBy, MOD, root, 0, MAX, oddShift);
    const ll wrapped = query(0, shiftBy, root, 0, MAX, !oddShift);
    return (noWrap + wrapped) % MOD;
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    root = newNode();

    int n = 0;
    if (!(std::cin >> n) || n < 0) return 1;  // malformed input
    std::vector<ll> tab(n);
    for (ll& value : tab) std::cin >> value;

    ll shiftBy = 0;  // always kept in [0, MOD)
    for (int i = 0; i < n; ++i) {
        if (i == 0) {
            // Seed a single entry with logical value tab[0], reduced into
            // [0, MOD) like every later insertion so it is visible to queries.
            update(normalize(tab[0]), 1, root, 0, MAX);
            continue;
        }
        // Entries with an even logical value before tab[i] is absorbed.
        const ll ways = countEvenValues(shiftBy);
        // Absorb tab[i]: every stored logical value grows by tab[i].
        const ll start = shiftBy;
        shiftBy = normalize(shiftBy - tab[i]);
        // New entry at the old offset: its logical value is
        // (start - shiftBy) mod MOD == tab[i] mod MOD.
        update(start, ways, root, 0, MAX);
    }

    std::cout << countEvenValues(shiftBy) << "\n";
}