1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
use std::{io, iter::repeat_n};

/// Parsed input as built by `main`: one inner `Vec` per input row (stack), in input order.
type Data = Vec<Vec<i64>>;

/*
 NOTE:
 Let's divide the set into increasing and decreasing stacks I, D.
 * It is never optimal to take more than one non-full stack from I (it's always
   more beneficial to take more moves from one)
 * When taking from D, since they are decreasing we can treat each element on its
   own – since we are being greedy, we can just always take the max one and it works.
   Let's call D' the set of all elements of D.
 * When resolving ties between stacks in I, we always want to take the lexicographically
   smaller one when considering their prefix sums.

 NOTE:
 As such the overall idea is as follows:
 * Find the optimal solution Opt which does not take into account any partials from I.
 * Find the best full stack S in I that Opt did not take (we will include it if we have space).
 * Consider all partials P in I as such:
   - If it's not in Opt, drop elements from D' to make space.
   - If it is in Opt, then we need to either:
     - Fill up the remainder with elements from D'
     - Fill up the remainder with stack S
 * The preprocessing is all done to make the transitions fast
 * Overall complexity is n*m partials * (1~2 cases for each) so linear in the bounded n*m.
*/

// Debug tracing is compiled out: these local macros shadow std's `eprintln!`
// and `eprint!` and expand to an empty block. Their arguments are only parsed
// as expressions, never evaluated. Delete both definitions to get the trace back.
macro_rules! eprintln {
    ( $( $arg:expr ),+ ) => {};
}
macro_rules! eprint {
    ( $( $arg:expr ),+ ) => {};
}

/// Best total for one partially taken increasing stack whose freed slots are
/// filled from `fill`.
///
/// * `stack` — prefix sums of an increasing stack (`stack[j]` is the sum of its
///   first `j + 1` items).
/// * `fill` — individual (not prefix-summed) values, in decreasing order.
///
/// With `len = min(stack.len(), fill.len())`, only `stack[..len]` and the last
/// `len` values of `fill` (the "window") are used. For each `t` in `0..len` the
/// candidate is `stack[len - 1 - t]` plus the first `t` window values. One more
/// candidate takes the whole window and nothing from the stack. The result is
/// the largest candidate, and never less than 0.
fn calc_partial(stack: &[i64], fill: &[i64]) -> i64 {
    eprint!("calc_partial({:?}, {:?}) = max(", stack, fill);
    let len = stack.len().min(fill.len());
    let window = &fill[fill.len() - len..];

    // `d_partial` / `res` keep their names because the trace below captures them.
    let mut d_partial = 0;
    let mut res = 0;
    for (t, &prefix) in stack[..len].iter().rev().enumerate() {
        let candidate = prefix + d_partial;
        eprint!("{}, ", candidate);
        res = res.max(candidate);
        // One fewer stack item next round, so one more window value is used.
        d_partial += window[t];
    }
    res = res.max(d_partial);
    eprintln!("{d_partial}) = {res}");
    res
}

/// Returns the largest total obtainable by taking `k` items from the rows of
/// `data`. Each row is a stack whose top is its first element (prefix sums are
/// built from the front).
///
/// Rows whose first element is their maximum go into `D` and are flattened
/// into individual values. Every other row goes into `I` as prefix sums. The
/// greedy phase builds `Opt` from whole `I` stacks and single `D` values. Cases
/// 1-4 below then try the one partially taken `I` stack that the approach
/// allows.
///
/// NOTE(review): this assumes each row is monotone and all values are
/// non-negative (the zero padding of `D` relies on it). Confirm against the
/// problem constraints.
///
/// # Panics
/// Panics if `data` is empty.
fn solve(k: i64, data: Data) -> i64 {
    let n = data.len();
    let m = data.first().unwrap().len();
    // Pad D with n * m zeros so the sliding window and the pops below always
    // have elements to work with.
    let mut d: Vec<i64> = repeat_n(0, n * m).collect();
    let mut i: Vec<Vec<i64>> = Vec::new();

    for mut line in data {
        if line.iter().max() == line.first() {
            // Top element is the maximum: each element can be taken on its own.
            d.append(&mut line)
        } else {
            // Keep the stack whole, stored as prefix sums.
            i.push(
                line.iter()
                    .scan(0, |state, val| {
                        *state += val;
                        Some(*state)
                    })
                    .collect(),
            );
        }
    }
    // Ascending order: the best candidates sit at the back (`last` / `pop`).
    d.sort();
    i.sort_by_key(|x| x.last().copied().unwrap());

    // Sorted I before any stack is consumed. The greedy phase takes exactly
    // its last `i_optcnt` stacks.
    let i_snap = i.clone();

    let mut i_optcnt = 0;
    let mut res_opt = 0;

    let mut d_optvec = Vec::new();
    // Sum of the m largest values still in D: what spending m moves on D would
    // yield instead of one full I stack. The greedy loop slides this window
    // down as it pops from D, so it must start as exactly that sum.
    let mut d_sum: i64 = d.iter().rev().take(m).sum();

    // NOTE: We take everything until k_left is 0.
    // Then all partials are done by deleting old elements from d.

    let mut k_left = k as usize;
    while let Some(i_sum) = i.last().and_then(|x| x.last().cloned()) {
        if k_left < m {
            // A full stack no longer fits. This belongs in the loop condition
            // as a let-chain:
            // https://github.com/rust-lang/rust/issues/53667... :(
            break;
        }
        if i_sum >= d_sum {
            // The best remaining full I stack beats the next m values of D.
            i.pop();
            res_opt += i_sum;
            i_optcnt += 1;
            k_left -= m;
        } else {
            // Take the single largest D value and slide the window by one.
            let val = d.pop().unwrap();
            res_opt += val;
            d_sum -= val;
            d_sum += d[d.len() - m];
            d_optvec.push(val);
            k_left -= 1;
        }
    }
    eprintln!(
        "d.len(): {}; k_left: {}; k_left+m: {};",
        d.len(),
        k_left,
        k_left + m
    );
    // Fewer than m moves left (or I exhausted): fill the rest from D.
    while k_left > 0 {
        let val = d.pop().unwrap();
        res_opt += val;
        d_optvec.push(val);
        k_left -= 1;
    }

    eprintln!("i: {:?}", i);
    eprintln!("d: {:?}", d);
    eprintln!("i_optcnt: {i_optcnt}");
    eprintln!("d_optvec: {:?}", d_optvec);

    let mut res = res_opt;
    // We now have an optimal solution w/o partials.
    // The next m values D would still give beyond Opt, largest first.
    let d_optrev: Vec<i64> = d.iter().rev().take(m).copied().collect();
    // The last (smallest) up-to-m D values Opt took, largest first. These are
    // the ones dropped to make room for a partial stack.
    let d_optvec: Vec<i64> = d_optvec.iter().rev().take(m).rev().copied().collect();
    let opt_del: i64 = d_optvec.iter().sum();

    // Case 1: take partial of an already taken stack S and fill with elements from D
    for stack in i_snap.iter().rev().take(i_optcnt) {
        let del = stack.last().unwrap();
        let par = calc_partial(stack, &d_optrev);
        res = res.max(res_opt + par - del);
        eprintln!("l1 {stack:?}: {res:?} ({res_opt} - {del} + {par})");
    }

    // Case 2: drop the worst taken stack S, take partial of a not-taken stack S' and drop from D
    // Case 3: take partial of a not-taken stack and drop elements from D
    for stack in i_snap.iter().rev().skip(i_optcnt) {
        if i_optcnt > 0 {
            // The worst taken stack is the first of the taken suffix of i_snap.
            let del = i_snap[i_snap.len() - i_optcnt].last().unwrap();
            let par = calc_partial(stack, &d_optrev);
            res = res.max(res_opt + par - del);
            eprintln!("l2 {stack:?}: {res:?} ({res_opt} - {del} + {par})");
        }
        let del = opt_del;
        let par = calc_partial(stack, &d_optvec);
        res = res.max(res_opt + par - del);
        eprintln!("l3 {stack:?}: {res:?} ({res_opt} - {del} + {par})");
    }
    if let Some(i_last) = i.last() {
        // Best full stack that Opt did not take.
        let last_del = i_last.last().unwrap();
        // Case 4: replace optimal S with S', and take partial of S dropping from D
        for stack in i_snap.iter().rev().take(i_optcnt) {
            let del = opt_del;
            let this_del = stack.last().unwrap();
            let par = calc_partial(stack, &d_optvec);

            res = res.max(res_opt + par - del + last_del - this_del);
            eprintln!(
                "l4 {stack:?}: {res:?} ({res_opt} - {del} + {last_del} - {this_del} + {par})"
            );
        }
    }

    res
}

/// Entry point: reads `n m k` followed by `n` rows of `m` integers from stdin
/// and prints the result of `solve`.
///
/// The whole input is read at once and split on whitespace, so the numbers may
/// be laid out across lines arbitrarily.
///
/// # Errors
/// Returns any I/O error from reading stdin.
///
/// # Panics
/// Panics if the input is truncated, contains a token that is not an `i64`,
/// or gives a negative `n` or `m`.
pub fn main() -> io::Result<()> {
    let input = io::read_to_string(io::stdin())?;
    let mut tokens = input.split_whitespace().map(|tok| {
        tok.parse::<i64>()
            .unwrap_or_else(|err| panic!("invalid integer {tok:?}: {err}"))
    });
    let mut next = || tokens.next().expect("unexpected end of input");

    let n = usize::try_from(next()).expect("n must be non-negative");
    let m = usize::try_from(next()).expect("m must be non-negative");
    let k = next();

    let mut data: Data = Vec::with_capacity(n);
    for _ in 0..n {
        data.push((0..m).map(|_| next()).collect());
    }

    println!("{}", solve(k, data));

    Ok(())
}