// Counts inversions (pairs i < j with a[i] > a[j]) in the sequence exposed by
// teatr.h (GetN / GetElement). Small inputs are solved on node 0 alone; large
// inputs are split into per-node blocks and combined through message.h.
#include "message.h"
#include "teatr.h"
#include "bits/stdc++.h"

using namespace std;

// Size of the value-indexed count arrays. Element values are assumed to lie in
// [0, R - 2] so that `value + 1` is always a valid index -- NOTE(review):
// confirm against the problem's value bound. R is also reused as the input
// size below which a single node solves everything.
const int R = 1000 * 1000 + 5;

// Merge scratch buffer shared by all mergeSort calls; grown on demand.
vector<int> tmp(R);

// Sorts a[l..r] (inclusive bounds) ascending in place and returns the number of
// inversions inside that range. An empty or single-element range contributes 0.
long long mergeSort(int l, int r, vector<int>& a) {
    if (l >= r) {
        return 0;
    }
    // The merge below writes r - l + 1 values into tmp. A fixed size of R
    // overflowed whenever a block was longer than R (e.g. few nodes on a large
    // input), so grow the buffer instead. Only the outermost call can grow it.
    if ((long long) tmp.size() < (long long) r - l + 1) {
        tmp.resize(r - l + 1);
    }
    int m = l + (r - l) / 2;
    long long rs = mergeSort(l, m, a) + mergeSort(m + 1, r, a);
    int iter = 0;
    int i = l, j = m + 1;
    while (i <= m && j <= r) {
        if (a[i] > a[j]) {
            // The left half is sorted, so a[j] is smaller than every remaining
            // left element a[i..m]; each of those pairs is an inversion.
            rs += m - i + 1;
            tmp[iter++] = a[j++];
        } else {
            // Ties take the left element first, so equal values never count.
            tmp[iter++] = a[i++];
        }
    }
    while (i <= m) tmp[iter++] = a[i++];
    while (j <= r) tmp[iter++] = a[j++];
    copy(tmp.begin(), tmp.begin() + iter, a.begin() + l);
    return rs;
}

// Returns the number of inversions in `a`. Takes `a` by value because the
// merge sort reorders it; callers that no longer need their vector can move it in.
long long calcInversions(vector<int> a) {
    return mergeSort(0, (int) a.size() - 1, a);
}

int id, nodes;  // this node's index (MyNodeId) and the node count (NumberOfNodes)
int n;          // number of elements in the input sequence (GetN)

// Small inputs: node 0 reads the whole sequence and counts inversions alone;
// every other node does nothing.
void solveOneNode() {
    if (id != 0) {
        return;
    }
    vector<int> a(n);
    for (int i = 0; i < n; ++i) {
        a[i] = GetElement(i);
    }
    cout << calcInversions(std::move(a)) << '\n';
}

// First index of the block owned by `node` when the input is cut into
// consecutive blocks of `blockLen` elements, clamped to n so that nodes past
// the end of the input get the empty range [n, n). Node i owns
// [blockBegin(i), blockBegin(i + 1)). 64-bit math avoids int overflow.
int blockBegin(int node, long long blockLen) {
    return (int) min<long long>(node * blockLen, n);
}

// Large inputs. Every node with a non-empty block counts the inversions inside
// its block and sends that count to node 0. Node 0 additionally counts the
// inversions between different blocks: it tracks, for every value v, how many
// elements of the blocks processed so far are >= v, and streams the remaining
// blocks through GetElement in order. Node 0 prints the total.
void solveDistributed() {
    const long long blockLen = ((long long) n + nodes - 1) / nodes;
    // Nodes with index >= activeNodes own no elements and stay idle; this is
    // the same condition as "block start >= n".
    const int activeNodes = (int) (((long long) n + blockLen - 1) / blockLen);
    if (id >= activeNodes) {
        return;
    }

    const int from = blockBegin(id, blockLen);
    const int to = blockBegin(id + 1, blockLen);
    vector<int> a;
    a.reserve(to - from);
    for (int i = from; i < to; ++i) {
        a.push_back(GetElement(i));
    }
    long long rs = calcInversions(a);  // copy on purpose: node 0 still needs `a`
    if (id != 0) {
        PutLL(0, rs);
        Send(0);
        return;
    }

    // geq[v] = number of elements in the blocks processed so far with value >= v.
    vector<int> geq(R), blockHist(R);
    for (int e : a) {
        ++geq[e];
    }
    for (int v = R - 2; v >= 0; --v) {
        geq[v] += geq[v + 1];
    }

    for (int b = 1; b < activeNodes; ++b) {
        const int lo = blockBegin(b, blockLen);
        const int hi = blockBegin(b + 1, blockLen);
        for (int j = lo; j < hi; ++j) {
            int e = GetElement(j);
            rs += geq[e + 1];  // earlier elements strictly greater than e
            ++blockHist[e];
        }
        // Fold this block into geq by suffix-summing its histogram, clearing
        // blockHist on the way so it is ready for the next block.
        int suffix = 0;
        for (int v = R - 1; v >= 0; --v) {
            suffix += blockHist[v];
            blockHist[v] = 0;
            geq[v] += suffix;
        }
    }

    // Exactly one message arrives from each active worker, in any order.
    for (int k = 1; k < activeNodes; ++k) {
        int sender = Receive(-1);
        rs += GetLL(sender);
    }
    cout << rs << '\n';
}

// Debugging aid (normally not called; see the commented-out call in main):
// node 0 prints the sum of all elements, then every node terminates the process.
void test() {
    if (id == 0) {
        long long sum = 0;
        for (int i = 0; i < n; ++i) {
            sum += GetElement(i);
        }
        cout << sum << endl;
    }
    exit(0);
}

int main() {
    id = MyNodeId();
    nodes = NumberOfNodes();
    n = GetN();
    //test();
    if (n < R) {
        solveOneNode();
    } else {
        solveDistributed();
    }
}