#include <bits/stdc++.h>
using namespace std;
int n, k, a, b, p, pr;
int odw[25010], nr;
vector <vector <int> > v;
vector <int> h;
vector <pair <int, int> > g[25010];
long long wyn, mod=1e9+7;
void dfs(int w)
{
odw[w]=nr;
if(w>pr)
{
v.push_back(h);
}
for(int i=0; i<(int)g[w].size(); ++i)
{
if(odw[g[w][i].first]!=nr)
{
h.push_back(g[w][i].second);
dfs(g[w][i].first);
h.pop_back();
}
}
}
long long pot(long long u, long long v)
{
u%=mod;
long long z=1;
while(v)
{
if(v&1)
{
z=(z*u)%mod;
}
u=(u*u)%mod;
v>>=1;
}
return z;
}
int main()
{
scanf("%d%d", &n, &k);
for(int i=1; i<n; ++i)
{
scanf("%d%d%d", &a, &b, &p);
g[a].push_back({b, p});
g[b].push_back({a, p});
}
for(int i=1; i<=n; ++i)
{
++nr;
pr=i;
dfs(i);
}
for(int i=0; i<(int)v.size(); ++i)
{
sort(v[i].begin(), v[i].end());
reverse(v[i].begin(), v[i].end());
}
sort(v.begin(), v.end());
--k;
for(int i=0; i<(int)v[k].size(); ++i)
{
wyn+=pot(n, v[k][i]);
if(wyn>=mod)
{
wyn%=mod;
}
}
printf("%lld\n", wyn);
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | #include <bits/stdc++.h> using namespace std; int n, k, a, b, p, pr; int odw[25010], nr; vector <vector <int> > v; vector <int> h; vector <pair <int, int> > g[25010]; long long wyn, mod=1e9+7; void dfs(int w) { odw[w]=nr; if(w>pr) { v.push_back(h); } for(int i=0; i<(int)g[w].size(); ++i) { if(odw[g[w][i].first]!=nr) { h.push_back(g[w][i].second); dfs(g[w][i].first); h.pop_back(); } } } long long pot(long long u, long long v) { u%=mod; long long z=1; while(v) { if(v&1) { z=(z*u)%mod; } u=(u*u)%mod; v>>=1; } return z; } int main() { scanf("%d%d", &n, &k); for(int i=1; i<n; ++i) { scanf("%d%d%d", &a, &b, &p); g[a].push_back({b, p}); g[b].push_back({a, p}); } for(int i=1; i<=n; ++i) { ++nr; pr=i; dfs(i); } for(int i=0; i<(int)v.size(); ++i) { sort(v[i].begin(), v[i].end()); reverse(v[i].begin(), v[i].end()); } sort(v.begin(), v.end()); --k; for(int i=0; i<(int)v[k].size(); ++i) { wyn+=pot(n, v[k][i]); if(wyn>=mod) { wyn%=mod; } } printf("%lld\n", wyn); return 0; } |
English