// Offline simulation of contiguous groups that grow and merge over time,
// answering "accumulated cost at time q" for a batch of query times.
//
// Input  (stdin):  n m, then t[1..n], then m query times.
// Output (stdout): one line per query with the accumulated cost at that time.
//
// Model, as implemented below (the problem statement is not part of this file):
//   * Index 0 is an extra element with t[0] = 0; indices 1..n start as
//     singleton groups. Groups are always contiguous index ranges, and the
//     leftmost index of a group is its union-find root.
//   * A group rooted at r with s members absorbs the next index `nx` at the
//     first integer time T, no earlier than the group's formation, such that
//     t[r] + T * s >= t[nx].
//   * Every time step first adds `totalPairs` (the number of index pairs that
//     currently share a group) to the running cost. A merge at time T removes
//     what the two merged groups contributed so far and adds
//     T * C(s, 2) - sumDist for the merged group, where sumDist is the sum of
//     (t[x] - t[root]) over its members.
// NOTE(review): t[] is presumably non-decreasing and queries non-negative —
// confirm against the problem statement.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

namespace {

// ceil(a / b) for b > 0. Unlike (a + b - 1) / b this is also exact for
// negative a, where C++ integer division truncates towards zero.
long long ceilDiv(long long a, long long b) {
    return a >= 0 ? (a + b - 1) / b : -((-a) / b);
}

// Number of unordered pairs among s elements, C(s, 2).
long long pairsIn(long long s) { return s * (s - 1) / 2; }

class MergeSimulation {
public:
    // t[0] is the extra origin element (0); t[1..n] are the input values.
    // maxTime is the last time step to simulate (the largest query), >= 0.
    MergeSimulation(std::vector<long long> t, int maxTime)
        : t_(std::move(t)),
          n_(static_cast<int>(t_.size()) - 1),
          maxTime_(maxTime),
          rep_(t_.size()),
          last_(t_.size()),
          groupSize_(t_.size(), 1),
          sumDist_(t_.size(), 0),
          costSoFar_(t_.size(), 0),
          mergeTime_(t_.size(), 0),
          done_(t_.size(), 0),
          pending_(static_cast<std::size_t>(maxTime) + 1) {
        std::iota(rep_.begin(), rep_.end(), 0);
        std::iota(last_.begin(), last_.end(), 0);
    }

    // Simulates time steps 0..maxTime and returns costAt, where costAt[T] is
    // the running cost after all merges due at time T. Call once.
    std::vector<long long> run() {
        // Two neighbouring singletons meet once t[i-1] + T >= t[i]. Clamped at
        // 0 so unsorted input cannot produce a negative bucket index.
        for (int i = 1; i <= n_; ++i) {
            schedule(i, std::max(0LL, t_[i] - t_[i - 1]));
        }

        std::vector<long long> costAt(static_cast<std::size_t>(maxTime_) + 1, 0);
        for (int time = 0; time <= maxTime_; ++time) {
            cost_ += totalPairs_;
            // Index-based on purpose: join() may append to pending_[time]
            // itself (a merge that is due immediately), which would
            // invalidate range-for iterators.
            for (std::size_t j = 0; j < pending_[time].size(); ++j) {
                const int pos = pending_[time][j];
                if (done_[pos]) {
                    continue;  // stale entry: pos was already merged earlier
                }
                join(pos, time);
            }
            costAt[time] = cost_;
        }
        return costAt;
    }

private:
    // Queues `pos` to be joined with its left neighbour at time `when`
    // (when >= 0); merges beyond the simulated horizon are dropped.
    void schedule(int pos, long long when) {
        if (when <= maxTime_) {
            pending_[static_cast<std::size_t>(when)].push_back(pos);
        }
    }

    // Union-find root lookup with full path compression. Iterative, because
    // the parent chain can be O(n) long and recursion could overflow the stack.
    int find(int w) {
        int root = w;
        while (rep_[root] != root) {
            root = rep_[root];
        }
        while (rep_[w] != root) {
            const int parent = rep_[w];
            rep_[w] = root;
            w = parent;
        }
        return root;
    }

    // Attaches b's group under a's root. Called with a left of b, this keeps
    // the leftmost index as root and extends the group's right end.
    void unite(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) {
            return;
        }
        rep_[b] = a;
        groupSize_[a] += groupSize_[b];
        sumDist_[a] += sumDist_[b] + static_cast<long long>(groupSize_[b]) * (t_[b] - t_[a]);
        last_[a] = last_[b];
    }

    // Removes from the running cost everything the group rooted at `root`
    // contributed: its value at formation plus C(s, 2) per elapsed step.
    void removeGroupCost(int root, int time) {
        cost_ -= costSoFar_[root] + (time - mergeTime_[root]) * pairsIn(groupSize_[root]);
    }

    // Merges the group starting at `pos` into the group ending at pos - 1 at
    // time `time`, updates the cost bookkeeping, and schedules the merged
    // group's meeting with the next group to its right.
    void join(int pos, int time) {
        const int left = find(pos - 1);
        const int right = find(pos);
        removeGroupCost(left, time);
        removeGroupCost(right, time);
        totalPairs_ += static_cast<long long>(groupSize_[left]) * groupSize_[right];
        unite(left, right);  // `left` remains the root

        const long long s = groupSize_[left];
        mergeTime_[left] = time;
        costSoFar_[left] = time * pairsIn(s) - sumDist_[left];
        cost_ += costSoFar_[left];
        done_[pos] = 1;

        const int next = last_[left] + 1;
        if (next <= n_ && !done_[next]) {
            // The merged group reaches t[next] once t[left] + T * s >= t[next].
            const long long gap = t_[next] - t_[left] - time * s;
            schedule(next, time + std::max(0LL, ceilDiv(gap, s)));
        }
    }

    std::vector<long long> t_;
    int n_;
    int maxTime_;
    std::vector<int> rep_;                   // union-find parent
    std::vector<int> last_;                  // rightmost index of the group (valid at roots)
    std::vector<int> groupSize_;             // member count (valid at roots)
    std::vector<long long> sumDist_;         // sum of t[x] - t[root] over members (roots)
    std::vector<long long> costSoFar_;       // cost contribution at formation (roots)
    std::vector<int> mergeTime_;             // time the group was formed (roots)
    std::vector<char> done_;                 // done_[i]: i already joined with i - 1
    std::vector<std::vector<int>> pending_;  // pending_[T]: indices due to join at T
    long long cost_ = 0;
    long long totalPairs_ = 0;               // index pairs currently in the same group
};

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d %d", &n, &m) != 2 || n < 0 || m < 0) {
        return 1;
    }

    // t[0] = 0 is the extra origin element the simulation merges into.
    std::vector<long long> t(static_cast<std::size_t>(n) + 1, 0);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%lld", &t[i]) != 1) {
            return 1;
        }
    }

    std::vector<int> queries(static_cast<std::size_t>(m));
    int maxQuery = 0;
    for (int& q : queries) {
        // Negative query times would index out of range; reject them.
        if (std::scanf("%d", &q) != 1 || q < 0) {
            return 1;
        }
        maxQuery = std::max(maxQuery, q);
    }

    MergeSimulation simulation(std::move(t), maxQuery);
    const std::vector<long long> costAt = simulation.run();
    for (const int q : queries) {
        std::printf("%lld\n", costAt[static_cast<std::size_t>(q)]);
    }
    return 0;
}