# V11
# Input: a header line "n penalty", then n lines of "value color" pairs.
pr = []
header = input().split()
n = int(header[0])
penalty = int(header[1])

# Best pyramid cost found so far (updated by build_piramid). Initialised
# once, before reading, so it is defined even when n == 0.
max_piramid_cost = 0

for _ in range(n):
    fields = input().split()
    pr.append((int(fields[0]), int(fields[1])))

# Largest values first; equal values are ordered by the larger second field.
pr.sort(reverse=True)


def build_piramid(piramid_l, next_element, next_i, prev_color, cost, marker):
    """Recursively explore pyramids built from the global, descending ``pr``.

    Each call optionally places ``next_element`` on top of the pyramid built
    so far, records the resulting cost in the global ``max_piramid_cost`` if
    it is the best seen, and then branches over which element of ``pr`` may
    be placed next.

    Module globals read: ``pr`` (list of ``(value, color)`` tuples, sorted in
    descending order), ``n`` (number of elements of ``pr`` in play) and
    ``penalty`` (subtracted whenever consecutive elements differ in color).
    Module global written: ``max_piramid_cost``.

    NOTE(review): from a single call the search can skip only ``pr[next_i]``
    (or ``pr[next_i]`` plus a following run of equal values) before placing
    the next element; longer gaps between placed elements are not explored
    within one search. Recursion depth grows with ``n``. Confirm both are
    acceptable for the intended inputs.

    Args:
        piramid_l: Elements placed so far. Mutated in place (``next_element``
            is appended); the alternative branches receive copies.
        next_element: ``(value, color)`` to place now, or ``None`` for the
            initial call (nothing placed, running cost starts at 0).
        next_i: Index in ``pr`` of the first element not yet considered.
        prev_color: Color of the previously placed element; 0 means "none",
            so no penalty is charged. NOTE(review): assumes input colors are
            never 0.
        cost: Running cost before placing ``next_element``.
        marker: Debug tag naming the branch that made this call; unused.

    Returns:
        ``piramid_l`` when index ``n`` (end of ``pr``) is reached, otherwise
        ``None``.
    """
    global max_piramid_cost

    current_cost = 0
    current_color = 0

    if next_element is not None:
        piramid_l.append(next_element)
        current_color = next_element[1]
        # A color change between consecutive elements costs `penalty`;
        # prev_color == 0 marks the first element, which is never penalized.
        color_changed = prev_color != 0 and prev_color != current_color
        current_cost = cost + next_element[0] - (penalty if color_changed else 0)
        if current_cost > max_piramid_cost:
            max_piramid_cost = current_cost

    if next_i == n:
        return piramid_l

    if next_i < n - 1 and pr[next_i][0] == pr[next_i + 1][0]:
        # Split case: pr[next_i:] starts with a run of equal values. Only one
        # member of the run may be used, so branch on each member and resume
        # after the run. The scan is bounded by n so that a run reaching the
        # end of pr does not index past it (previously an IndexError).
        dup_value = pr[next_i][0]
        run_end = next_i + 2  # pr[next_i + 1] is already known to match
        while run_end < n and pr[run_end][0] == dup_value:
            run_end += 1
        for cnt in range(run_end - next_i):
            build_piramid(piramid_l[:], pr[next_i + cnt], run_end,
                          current_color, current_cost, f"b{cnt + 1}")
        return None

    # Gap case: pr[next_i] differs from its successor (or is the last one).
    if next_i < n - 1:
        # Alternative branch that skips pr[next_i]. Measure the run of values
        # equal to pr[next_i + 1]: if it has two or more members, jump past
        # the whole run; otherwise place pr[next_i + 1].
        # NOTE(review): this scan stops at index n - 2, so a run reaching the
        # last index is measured one element short. In particular, a
        # two-element run at the very end is treated as a single element,
        # and both of its equal values can then be stacked in one pyramid,
        # unlike in the split case. Confirm against the problem statement.
        dup_value = pr[next_i + 1][0]
        scan = next_i + 2
        while scan < n - 1 and pr[scan][0] == dup_value:
            scan += 1
        if scan - next_i > 2:
            if len(pr) > scan:
                build_piramid(piramid_l[:], pr[scan], scan + 1,
                              current_color, current_cost, "e")
        else:
            build_piramid(piramid_l[:], pr[next_i + 1], next_i + 2,
                          current_color, current_cost, "e")

    # Regular branch: place pr[next_i] itself. No copy is needed because
    # this is the last use of piramid_l in this call.
    build_piramid(piramid_l, pr[next_i], next_i + 1,
                  current_color, current_cost, "d")
    return None

# Restart the search from every suffix of pr. After each full search, drop
# the current first (largest) element and shrink the global n to match, so
# the next search starts at the following element. Stop once pr is empty.
for _ in range(n + 1):
    build_piramid([], None, 0, 0, 0, "a")
    if not pr:
        break
    pr.pop(0)
    n -= 1

print(max_piramid_cost)