#include <bits/stdc++.h> using namespace std; int n, k, a, b, p, pr; int odw[25010], nr; vector <vector <int> > v; vector <int> h; vector <pair <int, int> > g[25010]; long long wyn, mod=1e9+7; void dfs(int w) { odw[w]=nr; if(w>pr) { v.push_back(h); } for(int i=0; i<(int)g[w].size(); ++i) { if(odw[g[w][i].first]!=nr) { h.push_back(g[w][i].second); dfs(g[w][i].first); h.pop_back(); } } } long long pot(long long u, long long v) { u%=mod; long long z=1; while(v) { if(v&1) { z=(z*u)%mod; } u=(u*u)%mod; v>>=1; } return z; } int main() { scanf("%d%d", &n, &k); for(int i=1; i<n; ++i) { scanf("%d%d%d", &a, &b, &p); g[a].push_back({b, p}); g[b].push_back({a, p}); } for(int i=1; i<=n; ++i) { ++nr; pr=i; dfs(i); } for(int i=0; i<(int)v.size(); ++i) { sort(v[i].begin(), v[i].end()); reverse(v[i].begin(), v[i].end()); } sort(v.begin(), v.end()); --k; for(int i=0; i<(int)v[k].size(); ++i) { wyn+=pot(n, v[k][i]); if(wyn>=mod) { wyn%=mod; } } printf("%lld\n", wyn); return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | #include <bits/stdc++.h> using namespace std; int n, k, a, b, p, pr; int odw[25010], nr; vector <vector <int> > v; vector <int> h; vector <pair <int, int> > g[25010]; long long wyn, mod=1e9+7; void dfs(int w) { odw[w]=nr; if(w>pr) { v.push_back(h); } for(int i=0; i<(int)g[w].size(); ++i) { if(odw[g[w][i].first]!=nr) { h.push_back(g[w][i].second); dfs(g[w][i].first); h.pop_back(); } } } long long pot(long long u, long long v) { u%=mod; long long z=1; while(v) { if(v&1) { z=(z*u)%mod; } u=(u*u)%mod; v>>=1; } return z; } int main() { scanf("%d%d", &n, &k); for(int i=1; i<n; ++i) { scanf("%d%d%d", &a, &b, &p); g[a].push_back({b, p}); g[b].push_back({a, p}); } for(int i=1; i<=n; ++i) { ++nr; pr=i; dfs(i); } for(int i=0; i<(int)v.size(); ++i) { sort(v[i].begin(), v[i].end()); reverse(v[i].begin(), v[i].end()); } sort(v.begin(), v.end()); --k; for(int i=0; i<(int)v[k].size(); ++i) { wyn+=pot(n, v[k][i]); if(wyn>=mod) { wyn%=mod; } } printf("%lld\n", wyn); return 0; } |