1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
// Szymon Rusiecki (07.12.2021)
#include <bits/stdc++.h>

// Modulus for all arithmetic in this file (10^9 + 7). It is never reassigned,
// so it is a compile-time constant rather than a mutable global.
constexpr long long mod = 1000000007;

// Truncated power series whose coefficients are kept reduced modulo `mod`.
// `size` is the number of coefficients in use; main() shrinks it in place
// (--core.size) to truncate the series without reallocating `tab`.
class poly {
public:
	int size = 0;               // number of coefficients currently in use
	long long *tab = nullptr;   // heap buffer of coefficients, owned by this object

	// Builds 1 + k x + k^2 x^2 + ... + k^n x^n (n + 1 coefficients).
	poly(long long n, long long k);

	// Deep copy. The implicitly generated copy constructor would share `tab`,
	// and both destructors would then delete[] the same buffer.
	poly(const poly &other) : size(other.size), tab(new long long[other.size]) {
		std::copy(other.tab, other.tab + other.size, tab);
	}

	~poly();
	poly &operator=(poly &other);
	poly &operator*(poly &other);
};

// Modular exponentiation (base raised to exponent) using the global `mod`.
// Forward declaration only; the definition follows main().
template<typename T>
T power(T base, int exponent);

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    long long n, k;
    std::cin >> n >> k;

    if (n == 1) {
        std::cout << 0;
        return 0;
    }
    else if (k == 1) {
        std::cout << 1;
        return 0;
    }

    long long output = power(k, n);

	poly core(n, 1);

	long long pol_silnia = (k * (k - 1)) % mod;
	long long mnoznik = k - 2;

	for (int i = 1; i < std::min(n, k); ++i) {
		poly temp(n - i, i);
		--core.size;

		if (i != 1)
			core = core * temp;

		temp = temp * core;
		output -= (pol_silnia * temp.tab[n - i - 1] + mod) % mod;
		output = (output + mod) % mod;
		pol_silnia *= mnoznik--;
		pol_silnia %= mod;
	}

	std::cout << output;
	
    return 0;
}

// Computes a^b modulo the global `mod` by iterative binary exponentiation.
//   a - base. It is reduced into [0, mod) first, so negative or very large
//       values cannot overflow the intermediate products.
//   b - exponent. b == 0 yields 1; negative exponents are not supported and
//       also yield 1.
// Returns the result in [0, mod). Products are formed in long long, so a
// narrower T (e.g. int) does not overflow either.
template<typename T>
T power(T a, int b) {
	long long base = static_cast<long long>(a) % mod;
	if (base < 0)
		base += mod;

	long long result = 1;
	for (; b > 0; b >>= 1) {
		if (b & 1)
			result = (result * base) % mod;
		base = (base * base) % mod;
	}
	return static_cast<T>(result);
}

// Builds the truncated geometric series 1 + k x + k^2 x^2 + ... + k^n x^n,
// i.e. n + 1 coefficients with tab[i] = k^i reduced modulo `mod`.
// The ratio k is normalised into [0, mod) first, so a negative or very large k
// cannot produce negative coefficients or overflow tab[i - 1] * k.
// A negative n yields an empty polynomial instead of writing past the buffer.
poly::poly(long long n, long long k)
	: size(n < 0 ? 0 : static_cast<int>(n + 1)),
	  // Members are initialised in declaration order, so `size` is already set.
	  // Giving `tab` a mem-initializer means its default member initializer is
	  // not evaluated, so there is no earlier buffer to delete[] here.
	  tab(new long long[size]) {
	long long ratio = k % mod;
	if (ratio < 0)
		ratio += mod;

	if (size > 0)
		tab[0] = 1;
	for (int i = 1; i < size; ++i)
		tab[i] = (tab[i - 1] * ratio) % mod;
}
poly::~poly() {
	// Release the coefficient buffer this polynomial allocated.
	delete[] this->tab;
	this->tab = nullptr;
}
// Copies the coefficients of `other` into this polynomial, keeping this->size
// unchanged: extra coefficients of a longer `other` are truncated, and
// coefficients a shorter `other` does not have are set to zero (instead of
// being read past the end of other.tab).
poly &poly::operator=(poly &other) {
	if (this == &other)
		return *this;

	const int shared = std::min(size, other.size);
	std::copy(other.tab, other.tab + shared, tab);
	std::fill(tab + shared, tab + size, 0LL);

	return *this;
}
// Multiplies this polynomial by `other`, truncated to this->size coefficients,
// with every coefficient reduced modulo `mod`.
// Ownership: the product is allocated with `new` and returned by reference;
// this function never frees it, so the caller must delete it.
poly &poly::operator*(poly &other) {
	poly *product = new poly(size - 1, 0);
	std::fill(product->tab, product->tab + product->size, 0LL);

	for (int i = 0; i < size; ++i)
		// Bound j by other.size too, so a shorter `other` is never read out of range.
		for (int j = 0; j < other.size && i + j < size; ++j)
			product->tab[i + j] = (product->tab[i + j] + tab[i] * other.tab[j]) % mod;

	return *product;
}