#include<bits/stdc++.h>
using namespace std;
// Short aliases used throughout the solution.
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;

// Terse accessors kept as macros: ST/ND are used as member names (x.ST, x.ND).
#define ST first
#define ND second
#define PB push_back
#define SIZE(a) (static_cast<int>((a).size()))

// Output modulus (1e9+7) and capacity of the adjacency array (node labels must be < maxn).
constexpr int mod = 1000000007;
constexpr int maxn = 30000;

// v[a]: neighbours of node a as (neighbour, 0-based edge weight).
std::vector<pii> v[maxn];
// res: one per-weight edge-count vector for every recorded path.
std::vector<std::vector<int>> res;
// pp: largest 0-based edge weight seen while reading the input.
int pp = 0;
// st: root node of the DFS currently in progress.
int st = 0;
void dfs(int a, int p, vector<int> &u) {
if(a > st) {
res.PB(u);
}
for(pii x : v[a]) {
int b = x.ST;
if(b == p) {
continue;
}
int c = x.ND;
auto w = u;
w[c]++;
dfs(b, a, w);
}
}
int main() {
ios_base::sync_with_stdio(0);
int n, k;
cin >> n >> k;
for(int i=0; i < n-1; i++) {
int a, b, p;
cin >> a >> b >> p;
p--;
v[a].PB({b, p});
v[b].PB({a, p});
pp = max(p, pp);
}
for(int i=1; i <= n; i++) {
auto u = vector<int> (pp+1);
st = i;
dfs(i, -1, u);
}
for(auto &u : res) {
reverse(u.begin(), u.end());
}
sort(res.begin(), res.end());
// for(auto &u : res) {
// for(int b : u) {
// cerr << b << " ";
// }
// cerr << "\n";
// }
ll c=0;
ll m=n;
k--;
for(int i=SIZE(res[k])-1; i >= 0; i--) {
// cerr << res[k][i] << " ";
c = (c+res[k][i]*m)%mod;
m = (m*n)%mod;
}
cout << (c+mod)%mod << "\n";
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | #include<bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int, int> pii; typedef pair<ll, ll> pll; #define ST first #define ND second #define PB push_back #define SIZE(a) (int(a.size())) const int mod = 1e9+7; const int maxn = 30000; vector<pii> v[maxn]; vector<vector<int>> res; int pp; int st; void dfs(int a, int p, vector<int> &u) { if(a > st) { res.PB(u); } for(pii x : v[a]) { int b = x.ST; if(b == p) { continue; } int c = x.ND; auto w = u; w[c]++; dfs(b, a, w); } } int main() { ios_base::sync_with_stdio(0); int n, k; cin >> n >> k; for(int i=0; i < n-1; i++) { int a, b, p; cin >> a >> b >> p; p--; v[a].PB({b, p}); v[b].PB({a, p}); pp = max(p, pp); } for(int i=1; i <= n; i++) { auto u = vector<int> (pp+1); st = i; dfs(i, -1, u); } for(auto &u : res) { reverse(u.begin(), u.end()); } sort(res.begin(), res.end()); // for(auto &u : res) { // for(int b : u) { // cerr << b << " "; // } // cerr << "\n"; // } ll c=0; ll m=n; k--; for(int i=SIZE(res[k])-1; i >= 0; i--) { // cerr << res[k][i] << " "; c = (c+res[k][i]*m)%mod; m = (m*n)%mod; } cout << (c+mod)%mod << "\n"; } |
English