#include <bits/stdc++.h>
using namespace std;
// For the "increasing" rows, computes the best total obtainable for every
// element count c in [0, n*m].
//
// Each candidate solution takes some rows in full (always the rows with the
// largest totals) plus at most one prefix of a single other row.
// NOTE(review): restricting to at most one partial row is only optimal when the
// rows are non-decreasing; confirm the caller only passes such rows.
//
// Side effect: every row of `tab` is rewritten in place into its prefix-sum
// array with a leading 0, so afterwards tab[r][len] is the sum of the first
// `len` original elements and tab[r].back() is the row total.
//
// Returns a vector of size n*m + 1. Entry c is the best such total for exactly
// c elements. Entries start at 0 and are only raised, so none is negative.
vector <long long> solveIncr(int n, int m, vector <vector<long long>> &tab) {
    for (int r = 0; r < n; r++) {
        vector<long long> &row = tab[r];
        row.insert(row.begin(), 0);
        partial_sum(row.begin(), row.begin() + m + 1, row.begin());
    }

    // Row indices ordered by decreasing row total.
    vector<int> byTotal(n);
    iota(byTotal.begin(), byTotal.end(), 0);
    sort(byTotal.begin(), byTotal.end(),
         [&](int a, int b) { return tab[a].back() > tab[b].back(); });

    const int total = n * m;
    vector<long long> best(total + 1, 0);
    auto relax = [&](int cnt, long long value) {
        best[cnt] = max(best[cnt], value);
    };

    // bestFromHere[p]: largest length-`len` prefix among rows byTotal[p..n-1].
    vector<long long> bestFromHere(n);
    for (int len = 0; len <= m; len++) {
        for (int p = n - 1; p >= 0; p--) {
            const long long here = tab[byTotal[p]][len];
            bestFromHere[p] = (p + 1 < n) ? max(here, bestFromHere[p + 1]) : here;
        }

        long long fullSum = 0;      // sum of totals of byTotal[0..p-1]
        long long cheapestCut = 0;  // min over those rows of (total - prefix[len])
        for (int p = 0; p <= n; p++) {
            const int fullCnt = p * m;
            if (p > 0) {
                // Demote one of the p leading full rows to its length-len prefix,
                // choosing the row that loses the least.
                relax(fullCnt - m + len, fullSum - cheapestCut);
            }
            if (p == n) {
                break;
            }
            // Keep the p leading rows whole and add a length-len prefix of a later row.
            relax(fullCnt + len, fullSum + bestFromHere[p]);

            const vector<long long> &row = tab[byTotal[p]];
            const long long cut = row.back() - row[len];
            cheapestCut = (p == 0) ? cut : min(cheapestCut, cut);
            fullSum += row.back();
        }
    }
    return best;
}
// Maximum total of exactly k elements, where from each of the n rows (each of
// length m) only a prefix may be taken.
//
// Rows are split in two groups:
//  * rows with no strict increase (row[j] >= row[j+1] everywhere) are
//    non-increasing, so any prefix of such a row consists of its largest
//    elements; their elements are pooled and taken greedily, largest first;
//  * the remaining rows are handed to solveIncr.
// NOTE(review): solveIncr only considers at most one partially-taken row, which
// is optimal when those rows are non-decreasing, i.e. it assumes every input row
// is monotone. Confirm against the problem statement.
//
// Returns 0 when k < 0, when no split of k between the two groups is feasible
// (k exceeds the number of elements), or when every feasible total is negative.
long long solve(int n, int m, int k, const vector <vector<long long>> &tab) {
    vector<long long> sumDecr;
    vector<vector<long long>> tabIncr;
    for (int i = 0; i < n; i++) {
        const vector<long long> &row = tab[i];
        bool hasStrictIncr = false;
        for (int j = 0; j + 1 < m; j++) {
            if (row[j] < row[j + 1]) {
                hasStrictIncr = true;
                break;
            }
        }
        if (hasStrictIncr) {
            tabIncr.push_back(row);
        } else {
            sumDecr.insert(sumDecr.end(), row.begin(), row.end());
        }
    }

    // sumDecr[c] = sum of the c largest pooled elements (sumDecr[0] = 0).
    sort(sumDecr.begin(), sumDecr.end(), greater<long long>());
    partial_sum(sumDecr.begin(), sumDecr.end(), sumDecr.begin());
    sumDecr.insert(sumDecr.begin(), 0);

    // sumIncr[c] = best total for exactly c elements from the other rows.
    const vector<long long> sumIncr = solveIncr((int) tabIncr.size(), m, tabIncr);

    if (k < 0) {
        return 0;
    }
    const int maxDecr = (int) sumDecr.size() - 1;
    const int maxIncr = (int) sumIncr.size() - 1;
    // Only split points where both halves fit are visited; this avoids the
    // signed/unsigned comparisons of a filter-based loop.
    long long ans = 0;
    const int lo = max(0, k - maxIncr);
    const int hi = min(k, maxDecr);
    for (int nDecr = lo; nDecr <= hi; nDecr++) {
        ans = max(ans, sumDecr[nDecr] + sumIncr[k - nDecr]);
    }
    return ans;
}
// Reads n, m, k and then an n x m matrix (row by row) from stdin, and prints
// solve(n, m, k, matrix) followed by a newline.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0, m = 0, k = 0;
    // Without this check, missing/malformed input would leave n, m, k
    // uninitialized, and negative sizes would make the vector constructor throw.
    if (!(cin >> n >> m >> k) || n < 0 || m < 0) {
        return 1;
    }

    vector<vector<long long>> tab(n, vector<long long>(m));
    for (vector<long long> &row : tab) {
        for (long long &x : row) {
            cin >> x;
        }
    }
    cout << solve(n, m, k, tab) << '\n';
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 | #include <bits/stdc++.h> using namespace std; vector <long long> solveIncr(int n, int m, vector <vector<long long>> &tab) { for (int i = 0; i < n; i++) { tab[i].insert(tab[i].begin(), 0); for (int j = 1; j <= m; j++) { tab[i][j] += tab[i][j - 1]; } } vector <int> order(n); iota(order.begin(), order.end(), 0); sort(order.begin(), order.end(), [&](int i, int j) { return tab[i].back() > tab[j].back(); }); int k = n * m; vector <long long> sum(k + 1, 0); for (int j = 0; j <= m; j++) { vector <long long> bestSuff(n); for (int i = n - 1; i >= 0; i--) { bestSuff[i] = tab[order[i]][j]; if (i < n - 1) { bestSuff[i] = max(bestSuff[i], bestSuff[i + 1]); } } int cntTotals = 0; long long sumTotals = 0, minDiff = 0; for (int i = 0; i <= n; i++) { if (i > 0) { int cntHere = cntTotals - m + j; long long sumHere = sumTotals - minDiff; sum[cntHere] = max(sum[cntHere], sumHere); } if (i < n) { int cntHere = cntTotals + j; long long sumHere = sumTotals + bestSuff[i]; sum[cntHere] = max(sum[cntHere], sumHere); } if (i < n) { cntTotals += m; sumTotals += tab[order[i]].back(); long long diff = tab[order[i]].back() - tab[order[i]][j]; if (i == 0) { minDiff = diff; } else { minDiff = min(minDiff, diff); } } } } return sum; } long long solve(int n, int m, int k, vector <vector<long long>> &tab) { vector <long long> sumDecr; vector <vector<long long>> tabIncr; for (int i = 0; i < n; i++) { bool hasStrictIncr = false; for (int j = 0; j < m - 1; j++) if (tab[i][j] < tab[i][j + 1]) { hasStrictIncr = true; break; } if (hasStrictIncr) { tabIncr.push_back(tab[i]); } else { copy(tab[i].begin(), tab[i].end(), 
back_inserter(sumDecr)); } } sort(sumDecr.begin(), sumDecr.end(), greater <long long> ()); for (int i = 1; i < (int) sumDecr.size(); i++) { sumDecr[i] += sumDecr[i - 1]; } sumDecr.insert(sumDecr.begin(), 0); auto sumIncr = solveIncr(tabIncr.size(), m, tabIncr); long long ans = 0; for (int nDecr = 0; nDecr <= k; nDecr++) { int nIncr = k - nDecr; if (nDecr < sumDecr.size() && nIncr < sumIncr.size()) { ans = max(ans, sumDecr[nDecr] + sumIncr[nIncr]); } } return ans; } int main() { ios_base::sync_with_stdio(false); int n, m, k; cin >> n >> m >> k; vector <vector<long long>> tab(n, vector <long long> (m)); for (int i = 0; i < n; i++) for (long long &x : tab[i]) { cin >> x; } cout << solve(n, m, k, tab); return 0; } |
English