#include <bits/stdc++.h>
#pragma GCC optimize("Ofast,unroll-loops")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,fma,abm,mmx,avx,avx2,tune=native")
#pragma GCC target("sse,sse2,abm,mmx,sse3,tune=native")
using namespace std;

// Short type aliases used throughout this file.
using lld = long long;          // 64-bit signed integer
using pii = pair<int, int>;     // pair of ints
using pll = pair<lld, lld>;     // pair of 64-bit integers

// Shorthands for pair members and pair construction.
#define ff first
#define dd second
#define mp make_pair

// Container shorthands: v.pb(x) appends, v.sz expands to v.size().
#define pb push_back
#define sz size()
#define all(x) (x).begin(), (x).end()
// Removes adjacent duplicates from x (sort first to remove all duplicates).
#define make_unique(x) (x).erase(unique(all(x)), (x).end())

// Half-open counting loop: i runs over [s, a).
#define For(i, s, a) for(int i = s; i < a; ++i)
// Memo table for solve(): maps the full argument triple (a, b, c) to its result.
std::map<std::tuple<int, int, std::vector<int>>, long long> m;

// Greedily counts the square tiles needed to cover an a x b rectangle.
//
// c lists the available tile side lengths. Its last element is used first,
// as the largest tile (NOTE(review): this assumes c is sorted ascending;
// confirm against the input format). With side t, the function places
// (a / t) * (b / t) tiles, then recurses on the two leftover strips
// (a % t) x ((b / t) * t) and a x (b % t), using the remaining tiles.
//
// Returns 0 for an empty rectangle (a == 0 or b == 0), -1 when a non-empty
// part cannot be covered with the tiles left, and otherwise the tile count.
// Tile sizes <= 0 are skipped, because they would otherwise divide by zero.
long long solve(int a, int b, std::vector<int> c) {
    if (!a || !b)
        return 0;
    if (c.empty())
        return -1;

    // Build the memo key from the arguments *before* c is shortened below.
    // Caching under the shortened c would poison later queries that really
    // have only the smaller tile set.
    auto key = std::make_tuple(a, b, c);
    auto it = m.find(key);
    if (it != m.end())
        return it->second;

    const int t = c.back();
    c.pop_back();

    long long result;
    if (t <= 0) {
        // A tile with side <= 0 covers nothing; ignore it.
        result = solve(a, b, c);
    } else {
        const long long placed = static_cast<long long>(a / t) * (b / t);
        const long long strip_a = solve(a % t, (b / t) * t, c);
        // Once one strip is impossible, the whole rectangle is impossible.
        const long long strip_b = (strip_a == -1) ? -1 : solve(a, b % t, c);
        result = (strip_a == -1 || strip_b == -1) ? -1
                                                  : placed + strip_a + strip_b;
    }

    // Cache impossible (-1) results too, so they are not recomputed.
    m.emplace(std::move(key), result);
    return result;
}
// Reads the rectangle size (a b), the tile count n and the n tile sizes from
// stdin, then prints solve(a, b, c) followed by a newline.
// Exits with status 1 if any read fails or n is negative; otherwise the
// uninitialized values or a huge vector size would cause UB or an exception.
int main(void) {
    int a, b;
    if (scanf("%d%d", &a, &b) != 2)
        return 1;
    int n;
    if (scanf("%d", &n) != 1 || n < 0)
        return 1;
    vector<int> c(n, 0);
    For (i, 0, n) {
        if (scanf("%d", &c[i]) != 1)
            return 1;
    }
    printf("%lld\n", solve(a, b, c));
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | #include <bits/stdc++.h> #pragma GCC optimize("Ofast,unroll-loops") //#pragma GCC target("sse,sse2,sse3,ssse3,sse4,fma,abm,mmx,avx,avx2,tune=native") #pragma GCC target("sse,sse2,abm,mmx,sse3,tune=native") using namespace std; typedef long long lld; typedef pair<int, int> pii; typedef pair<lld, lld> pll; #define ff first #define dd second #define mp make_pair #define pb push_back #define sz size() #define For(i, s, a) for(int i = s; i < a; ++i) #define all(x) (x).begin(), (x).end() #define make_unique(x) (x).erase(unique(all(x)), (x).end()) map<tuple<int, int, vector<int>>, lld> m; lld solve(int a, int b, vector<int> c) { if (!a || !b) return 0; if (c.empty()) return -1; auto xd = m.find({a, b, c}); if (xd != m.end()) { return (*xd).second; } int t = c.back(); c.pop_back(); lld wyn = (lld)(a / t) * (lld)(b / t); lld wa = solve(a % t, (b / t) * t, c); lld wb = solve(a, b % t, c); if (wa == -1 || wb == -1) return -1; m[{a, b, c}] = wyn + wa + wb; return wyn + wa + wb; } int main(void) { int a, b; scanf("%d%d", &a, &b); int n; scanf("%d", &n); vector<int> c(n, 0); For (i, 0, n) scanf("%d", &c[i]); printf("%lld", solve(a, b, c)); } |
English