#include <bits/stdc++.h>
using namespace std;

// Array capacities: N sizes the per-element / per-query arrays (indexed 1..n and
// 1..m), M sizes the per-parameter arrays (indexed 0..maxQuery).
constexpr int N = 200005;
constexpr int M = 1000005;

// ---- Input -------------------------------------------------------------
int n, m, maxQuery;      // element count, query count, largest query value seen
long long t[N];          // t[1..n] read from input; t[0] is never read and stays 0
int query[N];            // query[1..m]

// ---- Sweep bookkeeping -------------------------------------------------
bool done[N];            // done[p] is set once p has been merged with p - 1
long long totalCost[M];  // totalCost[k]: running cost recorded after step k
vector<int> process[M];  // process[k]: positions scheduled to merge at step k

// ---- Disjoint-set forest over positions 0..n ---------------------------
// NOTE: with `using namespace std;` under C++17 an unqualified `size` is
// ambiguous with std::size; refer to this array as `::size` at use sites.
int rep[N];
int size[N];
long long sumDist[N];
long long costSoFar[N];
int mergeTime[N];
int last[N];

// ---- Running totals ----------------------------------------------------
int components;
long long cost;
long long totalPairs;
// Integer ceiling of a / b, exact for every sign combination (b must be non-zero).
// The former `(a + b - 1) / b` form is only valid for a >= 0 and b > 0: for a
// negative numerator it rounded toward zero (ceil(-4, 3) gave 0 instead of -1).
long long ceil(long long a, long long b) {
    long long quotient = a / b;   // C++ integer division truncates toward zero
    long long remainder = a % b;  // non-zero remainder carries the sign of a
    // Truncation toward zero is already the ceiling when the exact quotient is
    // negative; when it is positive (a and b share a sign) and inexact, round up.
    if (remainder != 0 && ((remainder > 0) == (b > 0))) {
        ++quotient;
    }
    return quotient;
}
// Returns the root of w's set in the disjoint-set forest stored in rep[], and
// re-points every node on the walked path directly at that root.
// Written iteratively: Union() always hangs the right root under the left root
// with no rank/size balancing, so parent chains can grow to length O(n); the
// recursive form would need one stack frame per chain link.
int find(int w) {
    int root = w;
    while (rep[root] != root) {
        root = rep[root];
    }
    // Second pass: full path compression.
    while (rep[w] != root) {
        int parent = rep[w];
        rep[w] = root;
        w = parent;
    }
    return root;
}
void Union(int a, int b) {
a = find(a);
b = find(b);
if (a == b) {
return;
}
rep[b] = a;
size[a] += size[b];
sumDist[a] += sumDist[b] + size[b] * (t[b] - t[a]);
last[a] = last[b];
}
void subtractCurrentCost(int w, int joinTime) {
cost -= costSoFar[w] + (long long)(joinTime - mergeTime[w]) * size[w] * (size[w] - 1) / 2;
}
void join(int pos, int joinTime) {
int prev = pos - 1;
int a = find(prev);
int b = find(pos);
subtractCurrentCost(a, joinTime);
subtractCurrentCost(b, joinTime);
totalPairs += (long long)size[a] * size[b];
Union(a, b);
mergeTime[a] = joinTime;
costSoFar[a] = (long long)joinTime * size[a] * (size[a] - 1) / 2 - sumDist[a];
cost += costSoFar[a];
components--;
done[pos] = true;
pos = last[a];
int next = pos + 1;
if (next <= n && !done[next]) {
long long dist = t[next] - joinTime - ((long long)joinTime * (size[a] - 1) + t[a]);
long long timeOfMerge = joinTime + max(0LL, ceil(dist, size[a]));
// cerr << timeOfMerge << endl;
if (timeOfMerge <= maxQuery) {
process[timeOfMerge].push_back(next);
}
// cerr << "Next merge prediction = " << pos << " " << next << " " << dist << " " << timeOfMerge << endl;
}
}
void simulate() {
size[0] = 1;
for (int i = 1; i <= n; i++) {
rep[i] = i;
last[i] = i;
size[i] = 1;
if (t[i] - t[i - 1] <= maxQuery) {
process[t[i] - t[i - 1]].push_back(i);
}
costSoFar[i] = 0;
mergeTime[i] = 0;
}
components = n;
for (int i = 0; i <= maxQuery; i++) {
cost += totalPairs;
for (int j = 0; j < process[i].size(); j++) {
if (done[process[i][j]]) continue;
// cerr << i << " " << process[i][j] << endl;
join(process[i][j], i);
}
totalCost[i] = cost;
}
}
// Reads n values into t[1..n] and m queries into query[1..m], tracking the
// largest query. It runs the sweep up to that value, then prints the recorded
// total for each query in input order.
int main() {
    scanf("%d %d", &n, &m);
    for (long long *cur = t + 1; cur <= t + n; ++cur) {
        scanf("%lld", cur);
    }
    for (int q = 1; q <= m; ++q) {
        scanf("%d", &query[q]);
        if (query[q] > maxQuery) {
            maxQuery = query[q];
        }
    }
    simulate();
    for (int q = 1; q <= m; ++q) {
        printf("%lld\n", totalCost[query[q]]);
    }
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | #include <bits/stdc++.h> using namespace std; const int N = 200005, M = 1000005; int n, m, maxQuery; long long t[N]; int query[N]; bool done[N]; long long totalCost[M]; vector<int> process[M]; int rep[N]; int size[N]; long long sumDist[N]; long long costSoFar[N]; int mergeTime[N]; int last[N]; int components; long long cost; long long totalPairs; long long ceil(long long a, long long b) { return (a + b - 1) / b; } int find(int w) { return rep[w] = (w == rep[w] ? rep[w] : find(rep[w])); } void Union(int a, int b) { a = find(a); b = find(b); if (a == b) { return; } rep[b] = a; size[a] += size[b]; sumDist[a] += sumDist[b] + size[b] * (t[b] - t[a]); last[a] = last[b]; } void subtractCurrentCost(int w, int joinTime) { cost -= costSoFar[w] + (long long)(joinTime - mergeTime[w]) * size[w] * (size[w] - 1) / 2; } void join(int pos, int joinTime) { int prev = pos - 1; int a = find(prev); int b = find(pos); subtractCurrentCost(a, joinTime); subtractCurrentCost(b, joinTime); totalPairs += (long long)size[a] * size[b]; Union(a, b); mergeTime[a] = joinTime; costSoFar[a] = (long long)joinTime * size[a] * (size[a] - 1) / 2 - sumDist[a]; cost += costSoFar[a]; components--; done[pos] = true; pos = last[a]; int next = pos + 1; if (next <= n && !done[next]) { long long dist = t[next] - joinTime - ((long long)joinTime * (size[a] - 1) + t[a]); long long timeOfMerge = joinTime + max(0LL, ceil(dist, size[a])); // cerr << timeOfMerge << endl; if (timeOfMerge <= maxQuery) { process[timeOfMerge].push_back(next); } // cerr << "Next merge prediction = " << pos << " " << next << " " << dist << " " << timeOfMerge << 
endl; } } void simulate() { size[0] = 1; for (int i = 1; i <= n; i++) { rep[i] = i; last[i] = i; size[i] = 1; if (t[i] - t[i - 1] <= maxQuery) { process[t[i] - t[i - 1]].push_back(i); } costSoFar[i] = 0; mergeTime[i] = 0; } components = n; for (int i = 0; i <= maxQuery; i++) { cost += totalPairs; for (int j = 0; j < process[i].size(); j++) { if (done[process[i][j]]) continue; // cerr << i << " " << process[i][j] << endl; join(process[i][j], i); } totalCost[i] = cost; } } int main() { scanf("%d %d", &n, &m); for (int i = 1; i <= n; i++) { scanf("%lld", &t[i]); } for (int i = 1; i <= m; i++) { scanf("%d", &query[i]); maxQuery = max(maxQuery, query[i]); } simulate(); for (int i = 1; i <= m; i++) { printf("%lld\n", totalCost[query[i]]); } return 0; } |
English