"""Count integer solutions of a^2 + b^2 + c^2 == d^2 with all terms in 1..n.

Reads a single integer n from stdin and prints the count.
"""

from collections import Counter


def count_square_sums(n: int) -> int:
    """Return the number of tuples (a, b, c, d) of integers in 1..n with
    a <= b, c < d and a^2 + b^2 + c^2 == d^2.

    The a <= b ordering applies only to the pair (a, b); c is not ordered
    relative to a or b. For example, n == 3 yields 2, because
    1^2 + 2^2 + 2^2 == 3^2 is counted once as (a, b) = (1, 2), c = 2 and
    once as (a, b) = (2, 2), c = 1.
    NOTE(review): this convention counts neither ordered triples (3 for
    n == 3) nor unordered multisets (1 for n == 3); confirm it matches the
    intended problem statement.

    Returns 0 when n < 1.
    """
    limit = n * n

    # pair_sums[v] = number of pairs 1 <= a <= b <= n with a^2 + b^2 == v,
    # only for v <= n^2 (larger sums can never complete a d^2 <= n^2).
    pair_sums: Counter = Counter()
    for a in range(1, n + 1):
        a_sq = a * a
        for b in range(a, n + 1):
            total = a_sq + b * b
            if total > limit:
                break  # b only grows, so every later sum is larger too
            pair_sums[total] += 1

    count = 0
    for d in range(1, n + 1):
        d_sq = d * d
        for c in range(1, d):  # c^2 < d^2
            # Counter returns 0 for a missing key, so no membership test is needed.
            count += pair_sums[d_sq - c * c]
    return count


def main() -> None:
    """Read n from stdin and print count_square_sums(n)."""
    print(count_square_sums(int(input())))


if __name__ == "__main__":
    main()