1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <cstdio>
#include "message.h"
#include "maklib.h"

// 64-bit signed integer shorthand; used for element sums and as the payload
// type of the GetLL/PutLL message helpers.
typedef long long int ll;

// Reduces localVal over all nodes with assocFun along a binomial tree rooted
// at node 0. In the round for `step`, a node at position `step` within its
// group of 2*step ships its partial result to id ^ step (= id - step), and
// the group leader (position 0) folds it in. The value from the
// lower-numbered node is always the left operand, so assocFun only needs to
// be associative.
//
// Returns the full reduction over every node on node 0. Other nodes return
// the partial result they held when they sent it (or their own value).
//
// getFun / putFun read / stage a T for the given peer around
// Receive / Send (main passes GetLL / PutLL).
template <typename T, T (*assocFun)(T, T), T (*getFun)(int), void (*putFun)(int, T)>
T computeInZero(T localVal) {
    T acc = localVal;

    for (int step = 1; step < NumberOfNodes(); step *= 2) {
        const int me = MyNodeId();
        const int peer = me ^ step;
        const int pos = me % (2 * step);

        if (pos == step) {
            // Upper half of the group: hand our partial result down and
            // stay idle for the remaining rounds.
            putFun(peer, acc);
            Send(peer);
        } else if (pos == 0 && peer < NumberOfNodes()) {
            // Group leader with an existing partner: absorb its result.
            Receive(peer);
            acc = assocFun(acc, getFun(peer));
        }
    }

    return acc;
}

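// Unfinished draft that appears to broadcast a value from node 0 back down the
// reduction tree used by computeInZero. Kept for reference only; it would not
// compile as written: `mask` and `val` are never declared, nothing is
// returned, the template parameters `T *getFun(int)` / `void *putFun(int, T)`
// declare functions returning pointers rather than function pointers, and
// `MyNodeId() ^ mask < NumberOfNodes()` parses as
// `MyNodeId() ^ (mask < NumberOfNodes())` because `<` binds tighter than `^`.
/*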
template <typename T, T *getFun(int), void *putFun(int, T)>
T propagate(T zeroVal) {
    do {
        mask >>= 1;
    
        if (MyNodeId() % (2 * mask) == mask) {
            int source = MyNodeId() ^ mask;
            Receive(source);
            val = getFun(source);
        } else if (MyNodeId() % (2 * mask) == 0 && MyNodeId() ^ mask < NumberOfNodes()) {
            int target = MyNodeId() ^ mask;
            putFun(target, val);
            Send(target);
        }
    } while (mask > 1);
}
*/

// Inclusive prefix scan across all nodes. Despite the name, this is not an
// all-reduce. On return, node i holds
//     localVal(0) op localVal(1) op ... op localVal(i)
// where op is assocFun and localVal(k) is the argument passed on node k.
// Operands are combined in increasing node order: the value coming from the
// lower-numbered node is always the left argument. So assocFun must be
// associative but need not be commutative. Only the last node,
// NumberOfNodes() - 1, ends up with the reduction over every node.
//
// Template parameters:
//   T        - value type exchanged between nodes.
//   assocFun - associative combine function.
//   getFun   - called right after Receive(source) to extract the T sent by
//              source (main passes GetLL).
//   putFun   - called right before Send(target) to stage a T for target
//              (main passes PutLL).
//
// Every Receive here is matched by a Send issued from the partner node's own
// call, so all nodes are expected to call this collectively, in the same
// order relative to their other Send/Receive calls.
// Structure: an up-sweep followed by a down-sweep over a binary tree. That
// gives O(log NumberOfNodes()) rounds, with at most one send or one receive
// per node per round.
template <typename T, T (*assocFun)(T, T), T (*getFun)(int), void (*putFun)(int, T)>
T computeInAll(T localVal) {
    int mask = 1;
    
    T val = localVal;
    
    // Up-sweep: in the round for `mask`, a node with id % (2*mask) == mask-1
    // sends its partial result to id + mask, if that node exists. The
    // receiver (id % (2*mask) == 2*mask-1) folds it in as the LEFT operand.
    // The receiver needs no bounds check: its sender id - mask always exists
    // and satisfies the send condition. Afterwards each node holds the
    // combination of a contiguous run of nodes ending at itself.
    while (mask < NumberOfNodes()) {
        if (MyNodeId() % (2 * mask) == mask - 1 && MyNodeId() + mask < NumberOfNodes()) {
            int target = MyNodeId() + mask;
            putFun(target, val);
            Send(target);
        } else if (MyNodeId() % (2 * mask) == 2 * mask - 1) {
            int source = MyNodeId() - mask;
            Receive(source);
            T extVal = getFun(source);
            val = assocFun(extVal, val);
        }
        
        mask *= 2;
    }
        
    // Here mask is the smallest power of two >= NumberOfNodes() (1 when
    // there is a single node, in which case no communication happens).
    //
    // Down-sweep: the offset mask/2 shrinks round by round down to 1. A node
    // with id % mask == mask-1 already holds its complete prefix. It forwards
    // that prefix to id + mask/2, which folds it in on the left and thereby
    // completes its own prefix. The `MyNodeId() >= mask / 2` guard excludes
    // nodes that have no left partner at this offset.
    while (mask > 2) {
        mask /= 2;
    
        if (MyNodeId() % mask == mask - 1 && MyNodeId() + mask / 2 < NumberOfNodes()) {
            int target = MyNodeId() + mask / 2;
            putFun(target, val);
            Send(target);
        } else if (MyNodeId() % mask == mask / 2 - 1 && MyNodeId() >= mask / 2) {
            int source = MyNodeId() - mask / 2;
            Receive(source);
            T extVal = getFun(source);
            val = assocFun(extVal, val);
        }
    }
    
    return val;
}

// Sum of two 64-bit values; the combine step for the prefix-sum scan in main.
ll add(ll a, ll b) {
    const ll total = a + b;
    return total;
}

// Smaller of a and b (returns b when they are equal).
ll minimum(ll a, ll b) {
    if (a < b) {
        return a;
    }
    return b;
}

// Larger of a and b (returns a when they are equal).
ll maximum(ll a, ll b) {
    if (a < b) {
        return b;
    }
    return a;
}

int main() {
    ll localSum = 0;
    ll minLocalSum = 0;
    ll begin = (ll)Size() * MyNodeId() / NumberOfNodes() + 1;
    ll end = (ll)Size() * (MyNodeId() + 1) / NumberOfNodes() + 1;
    
    for (ll i = begin; i < end; ++i) {
        localSum += ElementAt(i);
        if (localSum < minLocalSum) {
            minLocalSum = localSum;
        }
    }
    
    ll totalSum = computeInAll<ll, add, GetLL, PutLL>(localSum);
    ll prevSum = totalSum - localSum;
    minLocalSum += prevSum;
    
    ll prevMinLocalSum = 0;
    if (MyNodeId() != NumberOfNodes() - 1) {
        PutLL(MyNodeId() + 1, minLocalSum);
        Send(MyNodeId() + 1);
    }
    if (MyNodeId() != 0) {
        Receive(MyNodeId() - 1);
        prevMinLocalSum = GetLL(MyNodeId() - 1);
    }
    
    ll prevMinSum = computeInAll<ll, minimum, GetLL, PutLL>(prevMinLocalSum);
    
    ll sum = prevSum;
    ll minSum = prevMinSum;
    ll maxResult = 0;
    for (ll i = begin; i < end; ++i) {
        sum += ElementAt(i);
        if (sum < minSum) {
            minSum = sum;
        } else if (sum - minSum > maxResult) {
            maxResult = sum - minSum;
        }
    }
    
    ll totalMaxResult = computeInZero<ll, maximum, GetLL, PutLL>(maxResult);
    
    if (MyNodeId() == 0) {
        printf("%lld\n", totalMaxResult);
    }
    
    return 0;
}