use std::{io, iter::repeat_n};
/// Input grid: one inner `Vec` per input row; `solve` treats each row as one stack.
type Data = Vec<Vec<i64>>;
/*
NOTE:
Let's divide the set into increasing and decreasing stacks I, D.
* It is never optimal to take more than one non-full stack from I (it is always
more beneficial to spend the extra moves on just one of them).
* When taking from D, since those stacks are decreasing we can treat each element on its
own: the largest remaining element of D is always on top of its stack, so greedily taking
the max one is always a legal move and it works. Let's call D' the set of all elements of D.
* When resolving ties in I (stacks with equal totals), we always want to take the
lexicographically smaller one when comparing their prefix sums.
NOTE:
As such the overall idea is as follows:
* Find the optimal solution Opt which does not take into account any partials from I.
* Find the best full stack S in I that we did not take in Opt (we will include it if we have space).
* Consider all partials P in I as such:
- If it's not in Opt, drop elements from D' to make space.
- If it is in Opt, then we need to either:
- Fill up the remainder with elements from D'
- Fill up the remainder with stack S
* The preprocessing is all done to make the transitions fast
* Overall complexity is n*m partials * (1~2 cases for each) so linear in the bounded n*m.
*/
// Debug tracing is compiled out. These local `macro_rules!` definitions shadow
// the standard `eprint!`/`eprintln!` for the rest of the file. Each call must
// still pass at least one argument that parses as an expression, but the call
// expands to nothing, so the arguments are never evaluated.
macro_rules! eprint {
    ($($arg:expr),+) => {};
}
macro_rules! eprintln {
    ($($arg:expr),+) => {};
}
/// Best total from splitting `L = min(stack.len(), fill.len())` slots between a
/// prefix of one stack and single values taken from `fill`.
///
/// * `stack` — prefix sums of a stack, so `stack[j]` is the value of taking its
///   first `j + 1` items.
/// * `fill` — individual values; the callers in this file pass them in
///   non-increasing order, so the front of the window below holds the largest.
///
/// For every `t` in `0..L` it evaluates `L - t` stack items (`stack[L - 1 - t]`)
/// plus the first `t` entries of `fill[fill.len() - L..]`. It also evaluates all
/// `L` fill entries with no stack items. It returns the maximum of these,
/// floored at `0`.
fn calc_partial(stack: &[i64], fill: &[i64]) -> i64 {
    let slots = stack.len().min(fill.len());
    // Only the last `slots` fill values can ever be combined with the stack.
    let fill_window = &fill[fill.len() - slots..];
    // Longest usable stack prefix first, shrinking by one item per step.
    let heads = stack[..slots].iter().rev();

    let mut d_partial = 0;
    let mut res = 0;
    eprint!("calc_partial({:?}, {:?}) = max(", stack, fill);
    for (&head, &extra) in heads.zip(fill_window) {
        let candidate = head + d_partial;
        eprint!("{}, ", candidate);
        res = res.max(candidate);
        d_partial += extra;
    }
    // Final option: every slot filled from `fill`, no stack items at all.
    res = res.max(d_partial);
    eprintln!("{d_partial}) = {res}");
    res
}
/// Best total found for a budget of `k` moves over the stacks in `data`.
///
/// Every row of `data` is one stack. All rows are assumed to have the length
/// `m` of the first row.
///
/// A row whose first element is its maximum is split into individual values
/// and added to the pool `d` (D' in the NOTE at the top of the file). Every
/// other row is stored in `i` as its prefix sums.
///
/// A greedy pass then builds a solution from full `i` stacks and single `d`
/// values only. Four local cases then try to improve it with exactly one
/// partial `i` stack.
///
/// # Panics
/// Panics if `data` is empty. `k` is converted with `as usize`, so it is
/// expected to be non-negative (TODO confirm against the problem bounds).
fn solve(k: i64, data: Data) -> i64 {
    let n = data.len();
    let m = data.first().unwrap().len();
    // Pad `d` with n*m zero-valued entries, presumably so spare moves can be
    // spent on "nothing" and there are always filler values to pop / index.
    let mut d: Vec<i64> = repeat_n(0, n * m).collect();
    let mut i: Vec<Vec<i64>> = Vec::new();
    for mut line in data {
        if line.iter().max() == line.first() {
            d.append(&mut line)
        } else {
            i.push(
                line.iter()
                    .scan(0, |state, val| {
                        *state += val;
                        Some(*state)
                    })
                    .collect(),
            );
        }
    }
    // Both ascending, so `pop()` always yields the currently best candidate.
    d.sort();
    i.sort_by_key(|x| x.last().copied().unwrap());
    let i_snap = i.clone();
    let mut i_optcnt = 0;
    let mut res_opt = 0;
    let mut d_optvec = Vec::new();
    // Sum of the `m` largest values still in `d`: what a full `i` stack has to
    // beat. The loop below only slides this window, so it must start as the
    // true top-`m` sum; any initial error would bias every later comparison.
    let mut d_sum: i64 = d.iter().rev().take(m).sum();
    // NOTE: We take everything until k_left is 0.
    // Then all partials are done by deleting old elements from d.
    let mut k_left = k as usize;
    while let Some(i_sum) = i.last().and_then(|x| x.last().cloned()) {
        if k_left < m {
            // This would be a `&& k_left >= m` guard on the `while let`, but that
            // needs let-chains: https://github.com/rust-lang/rust/issues/53667
            break;
        }
        if i_sum >= d_sum {
            i.pop();
            res_opt += i_sum;
            i_optcnt += 1;
            k_left -= m;
        } else {
            let val = d.pop().unwrap();
            res_opt += val;
            // Slide the top-m window down by one element.
            d_sum -= val;
            d_sum += d[d.len() - m];
            d_optvec.push(val);
            k_left -= 1;
        }
    }
    eprintln!(
        "d.len(): {}; k_left: {}; k_left+m: {};",
        d.len(),
        k_left,
        k_left + m
    );
    // Spend whatever moves remain on the largest values left in `d`.
    while k_left > 0 {
        let val = d.pop().unwrap();
        res_opt += val;
        d_optvec.push(val);
        k_left -= 1;
    }
    eprintln!("i: {:?}", i);
    eprintln!("d: {:?}", d);
    eprintln!("i_optcnt: {i_optcnt}");
    eprintln!("d_optvec: {:?}", d_optvec);
    let mut res = res_opt;
    // We now have an optimal solution w/o partials.
    // `d_optrev`: up to m largest values NOT taken (non-increasing order).
    // `d_optvec`: up to m smallest values taken (non-increasing order); their
    // total is `opt_del`.
    let d_optrev: Vec<i64> = d.iter().rev().take(m).copied().collect();
    let d_optvec: Vec<i64> = d_optvec.iter().rev().take(m).rev().copied().collect();
    let opt_del: i64 = d_optvec.iter().sum();
    // Case 1: take partial of an already taken stack S and fill with elements from D
    for stack in i_snap.iter().rev().take(i_optcnt) {
        let del = stack.last().unwrap();
        let par = calc_partial(stack, &d_optrev);
        res = res.max(res_opt + par - del);
        eprintln!("l1 {stack:?}: {res:?} ({res_opt} - {del} + {par})");
    }
    // Total of the smallest full stack in the greedy solution, if any was taken.
    // `i_snap` is ascending and the greedy took its last `i_optcnt` entries.
    let worst_taken =
        (i_optcnt > 0).then(|| *i_snap[i_snap.len() - i_optcnt].last().unwrap());
    // Case 2: drop the worst taken stack S, take partial of a not-taken stack S' and drop from D
    // Case 3: take partial of a not-taken stack and drop elements from D
    for stack in i_snap.iter().rev().skip(i_optcnt) {
        if let Some(del) = worst_taken {
            let par = calc_partial(stack, &d_optrev);
            res = res.max(res_opt + par - del);
            eprintln!("l2 {stack:?}: {res:?} ({res_opt} - {del} + {par})");
        }
        let del = opt_del;
        let par = calc_partial(stack, &d_optvec);
        res = res.max(res_opt + par - del);
        eprintln!("l3 {stack:?}: {res:?} ({res_opt} - {del} + {par})");
    }
    if let Some(i_last) = i.last() {
        let last_del = i_last.last().unwrap();
        // Case 4: replace optimal S with S', and take partial of S dropping from D
        for stack in i_snap.iter().rev().take(i_optcnt) {
            let del = opt_del;
            let this_del = stack.last().unwrap();
            let par = calc_partial(stack, &d_optvec);
            res = res.max(res_opt + par - del + last_del - this_del);
            eprintln!(
                "l4 {stack:?}: {res:?} ({res_opt} - {del} + {last_del} - {this_del} + {par})"
            );
        }
    }
    res
}
/// Reads `n m k` followed by `n` rows of `m` integers from stdin, solves the
/// instance and prints the answer.
///
/// All of stdin is read at once and split on whitespace, so rows do not have
/// to be laid out one per line.
///
/// # Errors
/// Returns any I/O error from reading stdin.
///
/// # Panics
/// Panics on malformed input: a non-integer token, fewer than three header
/// values, or a negative `n`/`m`.
pub fn main() -> io::Result<()> {
    let input = io::read_to_string(io::stdin())?;
    let mut tokens = input
        .split_whitespace()
        .map(|x| x.parse::<i64>().unwrap());
    let (Some(n), Some(m), Some(k)) = (tokens.next(), tokens.next(), tokens.next()) else {
        panic!("Expected 3 integers!");
    };
    let n = usize::try_from(n).expect("n must be non-negative");
    let m = usize::try_from(m).expect("m must be non-negative");
    let data: Data = (0..n)
        .map(|_| tokens.by_ref().take(m).collect::<Vec<i64>>())
        .collect();
    println!("{}", solve(k, data));
    Ok(())
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | use std::{io, iter::repeat_n}; type Data = Vec<Vec<i64>>; /* NOTE: Let's divide the set into increasing and decreasing stacks I, D. * It is never optimal to take more then one non-full stack from I (it's always more beneficial to take more moves from one) * When taking from D, since they are decreasing we can treat each element on it's own – since we are being greedy, we can just always take the max one and it works. Let's call D' the set of all elements of D. * When resolving draws from I, we always want to take the lexicographically smaller one when considering their prefix sums. NOTE: As such the overall idea is as follows: * Find the optimal solution Opt which does not take into account any partials from I. * Find the best full stack S in I we did not take the Opt (we will include it if we have space). * Consider all partials P in I as such: - If it's not in Opt, drop elements from D' to make space. - If it is in Opt, then we need to either: - Fill up the remainder with elements from D' - Fill up the remainer with stack S * The preprocessing is all done to make the transitions fast * Overall complexity is n*m partials * (1~2 cases for each) so linear in the bounded n*m. */ macro_rules! eprintln { ($($x:expr),+) => {}; } macro_rules! 
eprint { ($($x:expr),+) => {}; } fn calc_partial(stack: &[i64], fill: &[i64]) -> i64 { // Stack -> partial sums, increasing // Fill -> not partial sums, decreasing let mut d_partial = 0; let mut res = 0; eprint!("calc_partial({:?}, {:?}) = max(", stack, fill); for (s, d) in stack .iter() .take(fill.len()) .rev() .zip(fill.iter().rev().take(stack.len()).rev()) { eprint!("{}, ", s + d_partial); res = res.max(s + d_partial); d_partial += d; } res = res.max(d_partial); eprintln!("{d_partial}) = {res}"); res } fn solve(k: i64, data: Data) -> i64 { let n = data.len(); let m = data.first().unwrap().len(); let mut d: Vec<i64> = repeat_n(0, n * m).collect(); let mut i: Vec<Vec<i64>> = Vec::new(); for mut line in data { if line.iter().max() == line.first() { d.append(&mut line) } else { i.push( line.iter() .scan(0, |state, val| { *state += val; Some(*state) }) .collect(), ); } } d.sort(); i.sort_by_key(|x| x.last().copied().unwrap()); let i_snap = i.clone(); let mut i_optcnt = 0; let mut res_opt = 0; let mut d_optvec = Vec::new(); let mut d_sum = 0; for _ in 0..m { d_sum += d.last().unwrap(); } // NOTE: We take everything until k_left is 0. // Then all partials are done by deleting old elements from d. let mut k_left = k as usize; while let Some(i_sum) = i.last().and_then(|x| x.last().cloned()) { if k_left < m { // https://github.com/rust-lang/rust/issues/53667... 
:( break; } if i_sum >= d_sum { i.pop(); res_opt += i_sum; i_optcnt += 1; k_left -= m; } else { let val = d.pop().unwrap(); res_opt += val; d_sum -= val; d_sum += d[d.len() - m]; d_optvec.push(val); k_left -= 1; } } eprintln!( "d.len(): {}; k_left: {}; k_left+m: {};", d.len(), k_left, k_left + m ); // assert!(d.len() >= k_left + m); while k_left > 0 { let val = d.pop().unwrap(); res_opt += val; d_optvec.push(val); k_left -= 1; } eprintln!("i: {:?}", i); eprintln!("d: {:?}", d); eprintln!("i_optcnt: {i_optcnt}"); eprintln!("d_optvec: {:?}", d_optvec); let mut res = res_opt; // We now have an optimal solution w/o partials. // Case 1: take partial of an already taken stack S and fill with elements from D let d_optrev: Vec<i64> = d.iter().rev().take(m).copied().collect(); let d_optvec: Vec<i64> = d_optvec.iter().rev().take(m).rev().copied().collect(); let opt_del: i64 = d_optvec.iter().sum(); for stack in i_snap.iter().rev().take(i_optcnt) { let del = stack.last().unwrap(); let par = calc_partial(stack, &d_optrev); res = res.max(res_opt + par - del); eprintln!("l1 {stack:?}: {res:?} ({res_opt} - {del} + {par})"); } // Case 2: drop the worst taken stack S, take partial of a not-taken stack S' and drop from D // Case 3: take partial of a not-taken stack and drop elements from D for stack in i_snap.iter().rev().skip(i_optcnt) { if i_optcnt > 0 { let del = i_snap[i_snap.len() - i_optcnt].last().unwrap(); let par = calc_partial(stack, &d_optrev); res = res.max(res_opt + par - del); eprintln!("l2 {stack:?}: {res:?} ({res_opt} - {del} + {par})"); } let del = opt_del; let par = calc_partial(stack, &d_optvec); res = res.max(res_opt + par - del); eprintln!("l3 {stack:?}: {res:?} ({res_opt} - {del} + {par})"); } if let Some(i_last) = i.last() { let last_del = i_last.last().unwrap(); // Case 4: replace optimal S with S', and take partial of S dropping from D for stack in i_snap.iter().rev().take(i_optcnt) { let del = opt_del; let this_del = stack.last().unwrap(); let par = 
calc_partial(stack, &d_optvec); res = res.max(res_opt + par - del + last_del - this_del); eprintln!( "l4 {stack:?}: {res:?} ({res_opt} - {del} + {last_del} - {this_del} + {par})" ); } } res } pub fn main() -> io::Result<()> { let mut buf = String::new(); io::stdin().read_line(&mut buf)?; let [n, m, k] = buf .split_whitespace() .map(|x| x.parse().unwrap()) .collect::<Vec<i64>>()[..] else { panic!("Expected 3 integers!"); }; let mut data: Data = Vec::new(); for _ in 0..n { buf.clear(); io::stdin().read_line(&mut buf)?; data.push(buf.split_whitespace().map(|x| x.parse().unwrap()).collect()); } #[allow(clippy::overly_complex_bool_expr)] let res = if false && k == n * m { // Edgecase when we take literally everything data.iter().map(|v| v.iter().sum::<i64>()).sum() } else { solve(k, data) }; println!("{res}"); Ok(()) } |
English