CF932E Team Work


CF932E Team Work

讲道理不难,但是我推不出来(指在模数为998244353的情况下的 \(O(k)\) 做法)我只能搞出 \(O (K \log N)\) 的,但是需要模数为998244353(你要用FFT也行,不保证精度)。

那么考虑 \(O(K^2)\)。显然,你只需要 这题 就行了。 (推导过程容易发现 \(j \gt k\)\(j=0\) 时,值为0。

\[\sum_{i=1}^n \binom n i i^k \\ = \sum_{i=1}^n \binom n i \sum_{j=0}^i {k\brace j}j! \binom i j \\ = \sum_{j=1}^k {k \brace j} j! \sum_{i=j}^n \binom n i \binom i j\\ = \sum_{j=1}^k {k \brace j} j! \sum_{i=j}^n \binom n j \binom {n-j} {i-j} \\ = \sum_{j=1}^k {k \brace j} j! \binom n j \sum_{i=j}^n \binom {n-j} {i-j} \\ = \sum_{j=1}^k {k \brace j} j! \binom n j 2^{n-j} \]

这题就可以 \(O(K^2)\) 完成了。但是我们怎么可以放弃(大雾),我们已知

\[{n \brace k}\\ = \frac{1}{k!} \sum_{i=0}^k (-1)^i \binom ki (k-i)^n \\ = \sum_{i=0}^k \frac{(-1)^i}{i!} \frac {(k-i)^n} {(k-i)!} \]

那么我们可以直接卷一下就变成 \(O (K \log N)\) 了(模数998244353)。

\(O(K)\) 做法请去看min_25博客。

/*
    Name:
    Author: Gensokyo_Alice
    Date:
    Description:
*/
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;

typedef long long ll;
const ll MAXN = (1LL << 20) + 10, MOD = 1e9+7, INF = 0x3f3f3f3f3f3f3f3f;

ll N, K, ans, fac[MAXN], inv[MAXN];

ll ks(ll, ll); 

namespace subtask2 {
    ll str[MAXN];
    void work() {
        str[1] = 1;
        for (ll i = 2; i <= K; i++)
            for (ll j = i; j >= 1; j--)
                str[j] = (str[j-1] + (j * str[j] % MOD)) % MOD;
        ll now = 1;
        for (ll j = 1; j <= K; j++) {
            ll invnow = inv[j] * fac[j-1] % MOD;
            now = now * (N-j+1) % MOD * invnow % MOD;
            (ans += str[j] * fac[j] % MOD * now % MOD * ks(2, N-j) % MOD) %= MOD;
        }
        printf("%lld\n", ans);
    }
}

int main() {
    scanf("%lld%lld", &N, &K); inv[0] = inv[1] = fac[1] = fac[0] = 1;
    for (ll i = 2; i <= K; i++) fac[i] = fac[i-1] * i % MOD, inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;
    for (ll i = 1; i <= K; i++) inv[i] = inv[i-1] * inv[i] % MOD;
    if (K <= 5000) {
        subtask2::work();
        return 0;
    }
    return 0;
}

ll ks(ll n, ll tim) {
	bool flag = 0;
	if (tim < 0) tim = -tim, flag = 1;
    ll ret = 1;
    while (tim) {
        if (tim & 1) ret = ret * n % MOD;
        n = n * n % MOD;
        tim >>= 1;
    }
    if (flag) return ks(ret, MOD - 2);
    return ret;
}