"""Count solutions of a^2 + b^2 + c^2 == d^2 over integers in 1..n.

Run as a script: reads n from standard input and prints the count.
"""

from collections import Counter


def count_quadruples(n):
    """Return the number of (a, b, c, d) with a^2 + b^2 + c^2 == d^2.

    All four values range over 1..n. The pair (a, b) is counted unordered
    (only a <= b), while c is unrestricted. Yields 0 when n < 3.
    """
    # i*i for i = 1..n is already strictly increasing, so no set/sort needed.
    squares = [i * i for i in range(1, n + 1)]
    limit = n * n

    # Multiplicity of each value a^2 + b^2 (a <= b) that does not exceed
    # the largest square; larger sums can never equal d^2 - c^2.
    pair_sums = Counter()
    for i, a2 in enumerate(squares):
        for b2 in squares[i:]:
            total = a2 + b2
            if total > limit:
                break  # squares ascend, so every later b2 overshoots too
            pair_sums[total] += 1

    result = 0
    for d2 in squares:
        for c2 in squares:
            if c2 >= d2:
                break  # d^2 - c^2 must be positive
            # Counter returns 0 for missing keys without inserting them.
            result += pair_sums[d2 - c2]
    return result


def main():
    """Read n from stdin and print count_quadruples(n)."""
    n = int(input())
    print(count_quadruples(n))


if __name__ == "__main__":
    main()