1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include "message.h"
#include "teatr.h"

#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

/// Order-statistics multiset of ints backed by a GNU pb_ds red-black tree.
///
/// Using `std::less_equal` as the comparator is a deliberate trick: no two keys
/// ever compare "equivalent", so duplicate values are all kept, and
/// order_of_key(x) returns how many stored values are strictly smaller than x.
/// NOTE(review): only insert() and order_of_key() are relied on here; key
/// lookups such as find()/erase() do not behave as on an ordinary set with
/// this comparator, so do not start using them without switching to unique
/// keys (e.g. value/index pairs with std::less).
using interval_tree = __gnu_pbds::tree<int,
                                       __gnu_pbds::null_type,
                                       std::less_equal<int>,
                                       __gnu_pbds::rb_tree_tag,
                                       __gnu_pbds::tree_order_statistics_node_update>;

/// Size of the value-indexed count tables: element values index them directly,
/// so every value must lie in [0, MAX), i.e. at most 1000000.
constexpr int MAX = 1000001;

/// Builds a value-indexed prefix-count table for positions
/// [fromPosition, toPosition] (inclusive).
///
/// On return, smaller[k] (0 <= k < MAX) is the number of positions i in the
/// range with GetElement(i) <= k, so smaller[k - 1] is the number of elements
/// strictly smaller than k. An empty range (toPosition < fromPosition) yields
/// an all-zero table.
///
/// @param smaller       output buffer of at least MAX entries; fully overwritten.
/// @param fromPosition  first position to count (inclusive).
/// @param toPosition    last position to count (inclusive).
/// NOTE(review): assumes every GetElement(i) lies in [0, MAX); any other value
/// would index outside the buffer.
void countSmaller(long long int* smaller, const int fromPosition, const int toPosition)
{
    // Tally occurrences of each value directly in the output buffer; this
    // avoids a separate MAX-sized scratch array (and its manual delete[]).
    std::fill(smaller, smaller + MAX, 0LL);
    for (int i = fromPosition; i <= toPosition; ++i) {
        ++smaller[GetElement(i)];
    }
    // Turn the tally into running totals in place. Unlike the previous version,
    // the count for value 0 is kept, so smaller[k] really means "<= k".
    std::partial_sum(smaller, smaller + MAX, smaller);
}

/// Counts the inversions ("conflicts") whose left element lies in the inclusive
/// position range [range.first, range.second]. These are pairs of positions
/// i < j with GetElement(i) > GetElement(j) and range.first <= i <= range.second.
/// Pairs with j inside the range are counted with an order-statistics tree;
/// pairs with j to the right of the range use a prefix-count table.
///
/// @param range  inclusive [first, last] positions to process; a range that
///               starts at or past GetN() contributes nothing.
/// @return       the number of such pairs.
long long int calculateConflicts(std::pair<int, int> range)
{
    const int fromPosition = range.first;
    const int toPosition = range.second;
    const int n = GetN();

    if (fromPosition >= n) {
        return 0;
    }

    // Prefix counts of the values at positions toPosition+1 .. n-1 (to the
    // right of this range). std::vector owns the large buffer, so it is
    // released on every exit path.
    std::vector<long long int> smaller(MAX);
    countSmaller(smaller.data(), toPosition + 1, n - 1);

    long long int sum = 0;
    // Multiset of the values at positions (i, toPosition] visited so far.
    interval_tree tree;
    for (int i = toPosition; i >= fromPosition; --i) {
        const int element = GetElement(i);
        // j inside the range: values already stored that are strictly smaller.
        sum += tree.order_of_key(element);
        // j to the right of the range: smaller[element - 1] from the prefix
        // table. The guard prevents reading smaller[-1] for a 0 value; nothing
        // lies below 0 (assuming values are non-negative, TODO confirm).
        if (element > 0) {
            sum += smaller[element - 1];
        }
        tree.insert(element);
    }
    return sum;
}

/// Splits positions [0, GetN()) into NumberOfNodes() contiguous chunks of
/// ceil(GetN() / NumberOfNodes()) positions each. Returns the inclusive
/// [first, last] bounds of chunk `id`.
///
/// The last chunk(s) may be shorter. When `id` lies beyond the data, the
/// returned first position is >= GetN() and last < first.
std::pair<int, int> getRange(int id)
{
    const int total = GetN();
    const int nodeCount = NumberOfNodes();
    // Ceiling division, written without (total + nodeCount - 1) to stay clear of overflow.
    const int chunk = total / nodeCount + (total % nodeCount != 0 ? 1 : 0);
    const int first = id * chunk;
    const int lastCandidate = first + chunk - 1;
    const int last = std::min(lastCandidate, total - 1);
    return std::pair<int, int>(first, last);
}

/// Entry point. Every node counts the conflicts whose left element lies in its
/// own chunk. Workers send that partial count to node 0; node 0 adds up all
/// partial counts and prints the total.
int main()
{
    const int myId = MyNodeId();
    const long long int partial = calculateConflicts(getRange(myId));

    if (myId != 0) {
        // Worker: ship the partial count to the master and finish.
        PutLL(0, partial);
        Send(0);
        return 0;
    }

    // Master: combine its own partial count with every worker's.
    long long int total = partial;
    for (int source = 1; source < NumberOfNodes(); ++source) {
        Receive(source);
        total += GetLL(source);
    }
    printf("%lld\n", total);
    return 0;
}