#include <bits/stdc++.h> using namespace std; const int MX=300300,md=1000000007; int n,q,i,j,x,y,z,sw,cur,a[MX],st[MX],cnt[MX],k[MX],d[MX],dst[MX][2],r[2]; vector<int> g[MX]; bool cmp(int i, int j) { return dst[i][0]+dst[i][1]>dst[j][0]+dst[j][1]; } void dfs(int i, int p, int d, int w) { dst[i][w]=d; for (int j: g[i]) if (j!=p) dfs(j,i,d+1,w); } int main() { scanf("%d%d",&n,&q); cnt[0]=st[1]=1; for (i=1; i<n; i++) { scanf("%d",&a[i]); cnt[i]=cnt[i-1]*a[i]; st[i+1]=st[i]+cnt[i]; for (x=st[i-1], y=st[i]; x<st[i]; x++) for (j=0; j<a[i]; j++, y++) { g[x].push_back(y); g[y].push_back(x); k[y]=y; d[y]=i; } } for (i=0; i<st[n]; i++) if (g[i].size()>1) swap(g[i][0],g[i][int(g[i].size())-1]); while (q--) { scanf("%d%d%d",&i,&j,&z); --i; --j; --z; if (i<j) { swap(i,j); sw=1; } else sw=0; for (cur=0; d[cur]<z; cur=g[cur][0]); if (i==z) x=cur; else for (x=g[cur][1]; d[x]<i; x=g[x][0]); for (y=cur; d[y]<j; y=g[y][0]); dfs(x,-1,0,sw); dfs(y,-1,0,(sw^1)); sort(k,k+st[n],cmp); r[0]=r[1]=0; for (i=0; i<st[n]; i++) { r[i&1]+=dst[k[i]][i&1]; if (r[i&1]>=md) r[i&1]-=md; } printf("%d\n",(r[0]+md-r[1])%md); } return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | #include <bits/stdc++.h> using namespace std; const int MX=300300,md=1000000007; int n,q,i,j,x,y,z,sw,cur,a[MX],st[MX],cnt[MX],k[MX],d[MX],dst[MX][2],r[2]; vector<int> g[MX]; bool cmp(int i, int j) { return dst[i][0]+dst[i][1]>dst[j][0]+dst[j][1]; } void dfs(int i, int p, int d, int w) { dst[i][w]=d; for (int j: g[i]) if (j!=p) dfs(j,i,d+1,w); } int main() { scanf("%d%d",&n,&q); cnt[0]=st[1]=1; for (i=1; i<n; i++) { scanf("%d",&a[i]); cnt[i]=cnt[i-1]*a[i]; st[i+1]=st[i]+cnt[i]; for (x=st[i-1], y=st[i]; x<st[i]; x++) for (j=0; j<a[i]; j++, y++) { g[x].push_back(y); g[y].push_back(x); k[y]=y; d[y]=i; } } for (i=0; i<st[n]; i++) if (g[i].size()>1) swap(g[i][0],g[i][int(g[i].size())-1]); while (q--) { scanf("%d%d%d",&i,&j,&z); --i; --j; --z; if (i<j) { swap(i,j); sw=1; } else sw=0; for (cur=0; d[cur]<z; cur=g[cur][0]); if (i==z) x=cur; else for (x=g[cur][1]; d[x]<i; x=g[x][0]); for (y=cur; d[y]<j; y=g[y][0]); dfs(x,-1,0,sw); dfs(y,-1,0,(sw^1)); sort(k,k+st[n],cmp); r[0]=r[1]=0; for (i=0; i<st[n]; i++) { r[i&1]+=dst[k[i]][i&1]; if (r[i&1]>=md) r[i&1]-=md; } printf("%d\n",(r[0]+md-r[1])%md); } return 0; } |