#include <bits/stdc++.h>
using namespace std;
const int MX=300300;
int n,m,cnt,numa,numv,idx[MX];
long long a[MX],sz[MX],ans;
vector<long long> all[MX];
multiset<long long> ins[MX],oth[MX];
bool cmp(int i, int j) {
return all[i][m-1]>all[j][m-1];
}
int main() {
scanf("%d%d%d",&n,&m,&cnt);
vector<long long> v(m);
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) scanf("%lld",&v[j]);
if (v[0]<v[m-1]) {
for (int j=1; j<m; j++) v[j]+=v[j-1];
for (int j=0; j<m; j++) oth[j].insert(-v[j]);
idx[numv]=numv;
all[numv++]=v;
} else for (int j=0; j<m; j++) a[++numa]=v[j];
}
sort(idx,idx+numv,cmp);
sort(a+1,a+numa+1);
reverse(a+1,a+numa+1);
for (int i=1; i<=numa; i++) a[i]+=a[i-1];
if (cnt<=numa) ans=a[cnt];
long long sk=0;
for (int ii=0, ck=0; ii<numv; ii++) {
int i=idx[ii];
for (int j=0; j<m && ck+j<cnt; j++) if (ck+j+1+numa>=cnt) {
ans=max(ans,sk-*oth[j].begin()+a[cnt-ck-j-1]);
if (!ins[j].empty()) ans=max(ans,sk+all[i][m-1]-*ins[j].begin()+a[cnt-ck-j-1]);
}
ck+=m;
if (ck>cnt) break;
sk+=all[i][m-1];
if (ck<=cnt && ck+numa>=cnt) ans=max(ans,sk+a[cnt-ck]);
for (int j=0; j<m; j++) {
oth[j].erase(oth[j].find(-all[i][j]));
ins[j].insert(all[i][m-1]-all[i][j]);
}
}
printf("%lld\n",ans);
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | #include <bits/stdc++.h> using namespace std; const int MX=300300; int n,m,cnt,numa,numv,idx[MX]; long long a[MX],sz[MX],ans; vector<long long> all[MX]; multiset<long long> ins[MX],oth[MX]; bool cmp(int i, int j) { return all[i][m-1]>all[j][m-1]; } int main() { scanf("%d%d%d",&n,&m,&cnt); vector<long long> v(m); for (int i=0; i<n; i++) { for (int j=0; j<m; j++) scanf("%lld",&v[j]); if (v[0]<v[m-1]) { for (int j=1; j<m; j++) v[j]+=v[j-1]; for (int j=0; j<m; j++) oth[j].insert(-v[j]); idx[numv]=numv; all[numv++]=v; } else for (int j=0; j<m; j++) a[++numa]=v[j]; } sort(idx,idx+numv,cmp); sort(a+1,a+numa+1); reverse(a+1,a+numa+1); for (int i=1; i<=numa; i++) a[i]+=a[i-1]; if (cnt<=numa) ans=a[cnt]; long long sk=0; for (int ii=0, ck=0; ii<numv; ii++) { int i=idx[ii]; for (int j=0; j<m && ck+j<cnt; j++) if (ck+j+1+numa>=cnt) { ans=max(ans,sk-*oth[j].begin()+a[cnt-ck-j-1]); if (!ins[j].empty()) ans=max(ans,sk+all[i][m-1]-*ins[j].begin()+a[cnt-ck-j-1]); } ck+=m; if (ck>cnt) break; sk+=all[i][m-1]; if (ck<=cnt && ck+numa>=cnt) ans=max(ans,sk+a[cnt-ck]); for (int j=0; j<m; j++) { oth[j].erase(oth[j].find(-all[i][j])); ins[j].insert(all[i][m-1]-all[i][j]); } } printf("%lld\n",ans); return 0; } |
English