import sys
from heapq import heapify, heappop, heappush

def _split_stacks(stacks):
    """Partition *stacks* into ``(normal_stacks, flipped_stacks)``.

    A stack is "normal" when its first pancake is <= its last one and
    "flipped" otherwise. Empty stacks are skipped: they hold nothing to
    take and have no first/last pancake to compare.
    """
    normal_stacks = []
    flipped_stacks = []
    for stack in stacks:
        if not stack:
            continue
        if stack[0] <= stack[-1]:
            normal_stacks.append(stack)
        else:
            flipped_stacks.append(stack)
    return normal_stacks, flipped_stacks


def _push_pancake(heap, stacks, stack_index, pancake_index, sign=1):
    """Push ``(sign * value, stack_index, pancake_index)`` onto *heap*.

    ``sign=-1`` turns the min-heap into a max-heap on pancake value.
    """
    value = stacks[stack_index][pancake_index]
    heappush(heap, (sign * value, stack_index, pancake_index))


def max_total(k, stacks):
    """Greedily choose ``min(k, total pancakes)`` pancakes; return their sum.

    The chosen pancakes always form a prefix (from index 0) of every stack:

    1. Take every pancake of every normal stack.
    2. While more than *k* are taken, give back the cheapest pancake that
       is currently the last taken one of some normal stack.
    3. While fewer than *k* are taken, take the most valuable pancake that
       is the next untaken one of some flipped stack, until none remain.
    4. While the cheapest last-taken normal pancake is worth no more than
       the best next-untaken flipped pancake, swap the two.

    NOTE(review): optimality of this greedy presumably relies on problem
    constraints on stack ordering that are not visible here - confirm
    against the problem statement.

    Args:
        k: Number of pancakes to choose (expected to be >= 0).
        stacks: Iterable of stacks, each a list of pancake values.

    Returns:
        Total value of the chosen pancakes.
    """
    normal_stacks, flipped_stacks = _split_stacks(stacks)

    # Step 1: start out with every pancake from every normal stack.
    ans = sum(sum(stack) for stack in normal_stacks)
    pancakes = sum(len(stack) for stack in normal_stacks)

    # Min-heap over the last taken pancake of each normal stack:
    # (pancake_value, stack_index, pancake_index).
    normal_minheap = [
        (stack[-1], stack_index, len(stack) - 1)
        for stack_index, stack in enumerate(normal_stacks)
    ]
    heapify(normal_minheap)

    # Step 2: too many taken -- give back the cheapest removable pancake.
    while pancakes > k:
        pancake_value, stack_index, pancake_index = heappop(normal_minheap)
        ans -= pancake_value
        pancakes -= 1
        if pancake_index > 0:
            _push_pancake(normal_minheap, normal_stacks, stack_index, pancake_index - 1)

    # Max-heap (values negated) over the next untaken pancake of each
    # flipped stack: (-pancake_value, stack_index, pancake_index).
    flipped_maxheap = [
        (-stack[0], stack_index, 0) for stack_index, stack in enumerate(flipped_stacks)
    ]
    heapify(flipped_maxheap)

    # Step 3: below k -- we might as well eat the best available flipped pancake.
    while pancakes < k and flipped_maxheap:
        neg_value, stack_index, pancake_index = heappop(flipped_maxheap)
        ans += -neg_value
        pancakes += 1
        if pancake_index + 1 < len(flipped_stacks[stack_index]):
            _push_pancake(
                flipped_maxheap, flipped_stacks, stack_index, pancake_index + 1, sign=-1
            )

    # Step 4: exchange a normal pancake for a flipped one while it doesn't
    # lower the total. Ties are swapped on purpose: an equal-value swap leaves
    # the total unchanged but exposes new candidates a later swap may profit from.
    while (
        normal_minheap
        and flipped_maxheap
        and normal_minheap[0][0] <= -flipped_maxheap[0][0]
    ):
        out_value, out_stack, out_index = heappop(normal_minheap)
        neg_in_value, in_stack, in_index = heappop(flipped_maxheap)
        in_value = -neg_in_value
        ans += in_value - out_value

        if out_index > 0:
            _push_pancake(normal_minheap, normal_stacks, out_stack, out_index - 1)
        if in_index + 1 < len(flipped_stacks[in_stack]):
            _push_pancake(flipped_maxheap, flipped_stacks, in_stack, in_index + 1, sign=-1)

    return ans


def main():
    """Read the header and the n stacks from stdin and print the answer."""
    # Header: n stacks, a second value not used by this solution, and k.
    # split() (not split(" ")) tolerates repeated/trailing whitespace.
    n, _m, k = map(int, sys.stdin.readline().split())
    stacks = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
    print(max_total(k, stacks))


if __name__ == "__main__":
    main()