import itertools
import sys
def _prefix_sums(values):
    """Return ``[0, v0, v0 + v1, ...]``, i.e. ``len(values) + 1`` running sums."""
    return [0, *itertools.accumulate(values)]


def _best_nondecreasing(rows, m, limit):
    """Return best totals obtainable from non-decreasing rows.

    Args:
        rows: rows of length ``m``, each assumed non-decreasing from its first
            item to its last.
        m: common row length (must be positive).
        limit: largest item count of interest.

    Returns:
        A list ``best`` of length ``min(limit, len(rows) * m) + 1`` where
        ``best[j]`` is the largest sum of exactly ``j`` items when every row
        contributes a prefix.

    Why this works: if two non-decreasing rows were both cut strictly inside,
    shifting items from one cut to the other never lowers the sum, so some
    optimum takes every row fully or not at all except for at most one row cut
    at depth ``r = j % m``, next to ``q = j // m`` full rows.  The partial row
    ``p`` is paired either with the ``q`` largest totals (``p`` not among
    them) or with the ``q + 1`` largest totals minus ``p`` itself.
    """
    count = len(rows)
    top = min(limit, count * m)
    prefs = [_prefix_sums(row) for row in rows]
    # Rank rows by their full total, largest first.
    prefs.sort(key=lambda p: p[m], reverse=True)
    totals = [p[m] for p in prefs]
    full = _prefix_sums(totals)  # full[q] = sum of the q largest totals

    best = [0] * (top + 1)
    for q in range(top // m + 1):
        best[q * m] = full[q]  # j divisible by m: whole rows only

    for r in range(1, min(m, top + 1)):
        part = [p[r] for p in prefs]
        # suffix_best[q] = best depth-r prefix among rows ranked q, q+1, ...
        suffix_best = list(itertools.accumulate(reversed(part), max))[::-1]
        # gap = max over rows ranked < q of (depth-r prefix - full total)
        gap = None
        for q in range(count):
            j = q * m + r
            if j > top:
                break
            value = full[q] + suffix_best[q]
            if gap is not None:
                value = max(value, full[q + 1] + gap)
            best[j] = value
            current = part[q] - totals[q]
            if gap is None or current > gap:
                gap = current
    return best


def _max_prefix_sum(stacks, m, k):
    """Return the largest total of exactly ``k`` items taken as row prefixes.

    Args:
        stacks: rows of length ``m``.  Each row is assumed monotone; it is
            classified by comparing only its first and last items (a row with
            ``m == 1`` or ``row[0] <= row[-1]`` counts as non-decreasing).
        m: common row length.
        k: number of items to take.

    Returns:
        The maximum achievable sum.

    Raises:
        ValueError: if ``k`` is outside ``[0, len(stacks) * m]``.
    """
    if not 0 <= k <= len(stacks) * m:
        raise ValueError(f"k={k} must lie in [0, {len(stacks) * m}]")
    if k == 0:
        return 0

    dec_items = []
    inc_rows = []
    for row in stacks:
        if m == 1 or row[0] <= row[-1]:
            inc_rows.append(row)
        else:
            dec_items.extend(row)

    # Any multiset of the globally largest values from non-increasing rows can
    # be realised with prefixes, so the best t of them are the t largest.
    dec_items.sort(reverse=True)
    dec_best = _prefix_sums(dec_items)
    inc_best = _best_nondecreasing(inc_rows, m, k)

    # t items come from non-increasing rows, k - t from non-decreasing ones.
    lo = max(0, k - (len(inc_best) - 1))
    hi = min(k, len(dec_items))
    return max(dec_best[t] + inc_best[k - t] for t in range(lo, hi + 1))


def main():
    """Solve one instance read from stdin and print the answer.

    Input is whitespace-separated integers: ``n m k`` followed by ``n`` rows
    of ``m`` values.  Each row contributes a prefix (presumably the top of a
    stack -- NOTE(review): confirm orientation against the task statement),
    and the printed value is the largest sum of exactly ``k`` chosen items.
    """
    data = sys.stdin.buffer.read().split()
    it = iter(data)
    n, m, k = int(next(it)), int(next(it)), int(next(it))
    stacks = [[int(next(it)) for _ in range(m)] for _ in range(n)]
    print(_max_prefix_sum(stacks, m, k))
if __name__ == "__main__":
    # Script entry point: only solve when executed directly, not on import.
    main()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | import sys def main(): data = sys.stdin.buffer.read().split() it = iter(data) n, m, k = int(next(it)), int(next(it)), int(next(it)) stacks = [] for _ in range(n): stacks.append([int(next(it)) for _ in range(m)]) prefs = [] is_nondec = [] for row in stacks: p, s = [0], 0 for x in row: s += x p.append(s) prefs.append(p) is_nondec.append(m == 1 or row[0] <= row[-1]) def count_and_value(lam): total_cnt, total_val = 0, 0 for i in range(n): if not is_nondec[i]: # stos Nierosnący row = stacks[i] lo2, hi2 = 0, m while lo2 < hi2: mid2 = (lo2 + hi2) // 2 if row[mid2] >= lam: lo2 = mid2 + 1 else: hi2 = mid2 total_cnt += lo2 total_val += prefs[i][lo2] else: # Stos niemalejący if prefs[i][m] >= m * lam: total_cnt += m total_val += prefs[i][m] return total_cnt, total_val lo, hi = 0, 10**12 + 1 while lo < hi: mid = (lo + hi) // 2 if count_and_value(mid)[0] >= k: lo = mid + 1 else: hi = mid lam = lo - 1 cnt, val = count_and_value(lam) print(val - (cnt - k) * lam) if __name__ == "__main__": main() |
English