1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#include <iostream>
#include <string>
#include <cassert>
#include <algorithm>
#include <vector>
#include <cmath>

// Build-configuration switch for the ASSERT/DBG helpers below.
// Set the #if condition to 1 for Debug: ASSERT maps to assert(), and DBG emits
// its argument as a statement.
// Set it to 0 for Release: both expand to an empty statement and their
// arguments are never evaluated.
#if 0
#	define ASSERT(expr) assert(expr)
#	define DBG(expr) expr
#else
#	define ASSERT(expr) do { } while (false)
#	define DBG(expr) do { } while (false)
#endif

/*
 * Returns true iff num is the square of an integer (0 == 0*0 counts).
 *
 * Negative inputs are never squares. They are rejected before calling
 * std::sqrt, because std::sqrt of a negative value yields NaN, and converting
 * NaN to int is undefined behavior.
 */
static bool is_square(int num)
{
	if (num < 0)
		return false;
	// std::sqrt is correctly rounded. For a perfect square the exact root is
	// representable in double, so truncation recovers it exactly.
	int const maybe_root = static_cast<int>(std::sqrt(static_cast<double>(num)));
	return maybe_root * maybe_root == num;
}

/*
 * Self-check for is_square().
 *
 * For every root r in [1, max_n], it verifies two things:
 *   - r*r is reported as a square;
 *   - every integer strictly between r*r and (r+1)*(r+1) is not.
 * Together these cover the whole range [1; (max_n+1)**2-1].
 * Each mismatch is printed to std::cout. Nothing is returned.
 */
static void test_is_square(int max_n)
{
	int const upper = (max_n + 1) * (max_n + 1);
	std::cout << "testing is_square() up to " << upper - 1 << "...\n";
	for (int root = 1; root <= max_n; ++root)
	{
		int const lo = root * root;
		int const hi = (root + 1) * (root + 1);
		if (!is_square(lo))
			std::cout << "Error: expected is_square(" << lo << ") to be true\n";
		for (int value = lo + 1; value < hi; ++value)
		{
			if (is_square(value))
				std::cout << "Error: expected is_square(" << value << ") to be false\n";
		}
	}
	std::cout << "done testing is_square()\n";
}

/*
 * Counts the number of ({a, b}, h) for positive integers a, b, h such that
 * a^2 + b^2 + h^2 <= n^2
 * and the left side is a square of integer.
 *
 * Assumption: 1 <= n <= 5000.
 *
 * Complexity: O(n^3)
 * Calls to std::sqrt: O(n^3)
 */
static int compute_answer_brute_force(int n)
{
	DBG(std::cout << "called compute_answer_brute_force(" << n << ")\n");
	int const n_square = n * n;

	int result = 0;
	for (int a = 1; a < n; ++a)
	{
		int const a_square = a * a;
		for (int b = a; a_square + b * b < n_square; ++b)
		{
			int const b_square = b * b;
			for (int h = 1; a_square + b_square + h * h <= n_square; ++h)
			{
				int const h_square = h * h;
				result += is_square(a_square + b_square + h_square);
			}
		}
	}
	return result;
}

/*
 * Same count as compute_answer_brute_force(). It avoids a std::sqrt call per
 * candidate by tracking, for each (a, b), the smallest integer `root` with
 * root^2 >= a^2 + b^2 + h^2. As h grows, that root only ever moves upward.
 *
 * Complexity: O(n^3)
 * Calls to std::sqrt: O(n^2)
 */
static int compute_answer_reduced_float_ops(int n)
{
	DBG(std::cout << "called compute_answer_reduced_float_ops(" << n << ")\n");
	int const limit = n * n;

	int count = 0;
	for (int a = 1; a < n; ++a)
	{
		int const a2 = a * a;
		for (int b = a; a2 + b * b < limit; ++b)
		{
			int const base = a2 + b * b;
			// Truncated root of the h == 1 candidate. The loop below raises it
			// to the ceiling root before first use.
			int root = static_cast<int>(std::sqrt(static_cast<double>(base + 1)));
			for (int h = 1; ; ++h)
			{
				int const sum = base + h * h;
				while (root * root < sum)
					++root;
				if (sum > limit)
					break;
				// sum is a perfect square exactly when it hits the ceiling root.
				if (sum == root * root)
					++count;
			}
		}
	}
	return count;
}

/*
 * Same count as compute_answer_brute_force(). Each sorted triple a <= b <= h
 * whose square sum is a perfect square <= n^2 is enumerated exactly once.
 * The triple then contributes the number of distinct ({a, b}, h) assignments
 * it yields, i.e. the number of ways to choose which element plays h:
 *   3 - [a == b] - [b == h]
 * This is 3 when all values differ, 2 when exactly two are equal, and 1 when
 * all three are equal.
 *
 * Pruning: because a <= b <= h, a pair (a, b) can only contribute when
 * a^2 + 2*b^2 <= n^2. In particular a needs 3*a^2 <= n^2. The outer loops
 * stop at these bounds instead of at a^2 + b^2 < n^2.
 *
 * Assumption: 1 <= n <= 5000 (n*n and (n+1)*(n+1) must fit in int).
 * Complexity: still O(n^3), but with a lower constant factor.
 */
static int compute_answer_faster(int n)
{
	DBG(std::cout << "called compute_answer_faster(" << n << ")\n");
	// No positive triple fits for n < 1. Returning early also keeps the n*n
	// based loop bounds below from accepting a negative n.
	if (n < 1)
		return 0;
	int const n_square = n * n;

	int result = 0;
	for (int a = 1; 3 * a * a <= n_square; ++a)
	{
		int const a_square = a * a;
		for (int b = a; a_square + 2 * b * b <= n_square; ++b)
		{
			int const ab_square = a_square + b * b;
			// Smallest diagonal d with d^2 >= a^2 + 2*b^2, which guarantees h >= b.
			int const start_point = ab_square + b * b;
			int diagonal = static_cast<int>(std::sqrt(static_cast<double>(start_point)));
			if (diagonal * diagonal < start_point)
				++diagonal;
			// Try every perfect square d^2 in [start_point, n^2] as the total.
			for (; diagonal * diagonal <= n_square; ++diagonal)
			{
				int const maybe_h_square = diagonal * diagonal - ab_square;
				int const h = static_cast<int>(std::sqrt(static_cast<double>(maybe_h_square)));
				if (h * h == maybe_h_square)
					result += 3 - (a == b) - (b == h);
			}
		}
	}
	return result;
}

int main()
{
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(NULL);

	//test_is_square(n);

	int n;
	std::cin >> n;
	std::cout << compute_answer_faster(n) << '\n';
}