「清华集训 2017」小 Y 和恐怖的奴隶主

题目描述

“A fight? Count me in!” 要打架了,算我一个。
“Everyone, get in here!” 所有人,都过来!

小Y是一个喜欢玩游戏的 OIer。一天,她正在玩一款游戏,要打一个 Boss。

虽然这个 Boss 有 $10^{100}$ 点生命值,但它只带了一个随从——一个只有 $m$ 点生命值的“恐怖的奴隶主”。

这个“恐怖的奴隶主”有一个特殊的技能:每当它被扣减生命值但没有死亡(死亡即生命值 $\leq 0$),且 Boss 的随从数量小于上限 $k$,便会召唤一个新的具有 $m$ 点生命值的“恐怖的奴隶主”。

现在小Y可以进行 $n$ 次攻击,每次攻击时,会从 Boss 以及 Boss 的所有随从中的等概率随机选择一个,并扣减 $1$ 点生命值,她想知道进行 $n$ 次攻击后扣减 Boss 的生命值点数的期望。为了避免精度误差,你的答案需要对 $998244353$ 取模。

输入格式

输入第一行包含三个正整数 $T,m,k$ 表示询问组数,$m,k$ 的含义见题目描述。

接下来 $T$ 行,每行包含一个正整数 $n$,表示询问进行 $n$ 次攻击后扣减 Boss 的生命值点数的期望。

输出格式

输出共 $T$ 行,对于每个询问输出一行一个非负整数,表示该询问的答案对 $998244353$ 取模的结果。

可以证明,所求期望一定是一个有理数,设其为$p/q (\mathrm{gcd}(p,q) = 1)$ 那么你输出的数 $x$ 要满足 $p \equiv qx \mod 998244353$。

样例

输入
3 2 6
1
2
3
输出
499122177
415935148
471393168

数据范围与提示

在所有测试点中, $1\leq T \leq 1000, 1 \leq n \leq 10^{18}, 1 \leq m \leq 3, 1 \leq k \leq 8$

各个测试点的分值和数据范围如下:

测试点编号分值$T=$$n \leq $$m=$$k=$
13101011
2828
37$10^{18}$3
4127
5302035
6105006
7102007
851000
9101008
105500

题解

20pts

可以直接$DP$
[BZOJ 4832] [Lydsy2017年4月月赛]抵制克苏恩

100pts

发现每次转移都是一样的
然后就可以矩阵快速幂了。
然而并不能过 2333

时间复杂度为$O(Tcnt^3log(n))$, $cnt$不会超过200。
我们还需要优化

我们可以将转移矩阵的 $k$ 次方全部预处理出来, 然后用一个 $1 \times cnt$ 的矩阵不断的去乘, 这样时间复杂度会被降到 $O(cnt^3log(n) + Tcnt^2log(n))$
就可以过了。

好像是道水题

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
inline long long read()
{
    long long x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
const int MOD = 998244353;
struct Matrix
{
    int a[200][200];
    int n, m;
    Matrix(int _n = 0, int _m = 0)
    {
        n = _n, m = _m;
        memset (a, 0, sizeof (a));
    }
    Matrix operator * (const Matrix &b) const 
    {
        Matrix ans(n, b.m);
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= m; j++)
                for (int k = 1; k <= b.m; k++)
                    ans.a[i][k] = (ans.a[i][k] + 1ll * a[i][j] * b.a[j][k]) % MOD;
        return ans;
    }
}A[64];
long long pow_mod(long long a, int b)
{
    long long ans = 1;
    while (b)
    {
        if (b & 1) ans = ans * a % MOD;
        b >>= 1;
        a = a * a % MOD;
    }
    return ans;
}
int cnt[10][10][10], Index, IDans;
int main()
{
    int T = read(), m = read(), k = read();
    for (int i = 0; i <= k; i++)
        for (int j = 0; j <= (m == 1 ? 0 : k); j++)
            for (int l = 0; l <= ((m == 1 || m == 2) ? 0 : k); l++) if (i + j + l <= k)
                cnt[i][j][l] = ++Index;
    IDans = ++Index;
    if (m == 1)
    {
        for (int i = 0; i <= k; i++)
        {
            if (i > 0) (A[0].a[cnt[i][0][0]][cnt[i - 1][0][0]] += i * pow_mod(i + 1, MOD - 2) % MOD) %= MOD;
            (A[0].a[cnt[i][0][0]][cnt[i][0][0]] += pow_mod(i + 1, MOD - 2)) %= MOD;
            (A[0].a[cnt[i][0][0]][IDans] += pow_mod(i + 1, MOD - 2)) %= MOD;
        }
        A[0].a[IDans][IDans] = 1;
    }
    else if (m == 2)
    {
        for (int j = 0; j <= k; j++)
            for (int l = 0; l <= k; l++) if (l + j <= k)
            {
                if (l >= 1 && j + l + 1 <= k) 
                    (A[0].a[cnt[j][l][0]][cnt[j + 1][l][0]] += pow_mod(j + l + 1, MOD - 2) * l % MOD) %= MOD;
                if (l >= 1 && j + l + 1 > k)
                    (A[0].a[cnt[j][l][0]][cnt[j + 1][l - 1][0]] += pow_mod(j + l + 1, MOD - 2) * l % MOD) %= MOD;
                if (j >= 1)
                    (A[0].a[cnt[j][l][0]][cnt[j - 1][l][0]] += pow_mod(j + l + 1, MOD - 2) * j % MOD) %= MOD;
                (A[0].a[cnt[j][l][0]][cnt[j][l][0]] += pow_mod(j + l + 1, MOD - 2)) %= MOD;
                (A[0].a[cnt[j][l][0]][IDans] += pow_mod(j + l + 1, MOD - 2)) %= MOD;
            }
        A[0].a[IDans][IDans] = 1;
    }
    else
    {
        for (int i = 0; i <= k; i++)
            for (int j = 0; j <= k; j++)
                for (int l = 0; l <= k; l++) if (i + j + l <= k)
                {
                    if (i > 0)
                        (A[0].a[cnt[i][j][l]][cnt[i - 1][j][l]] += i * pow_mod(i + j + l + 1, MOD - 2) % MOD) %= MOD;
                    if (j > 0)
                    {
                        if (i + j + l + 1 <= k) 
                            (A[0].a[cnt[i][j][l]][cnt[i + 1][j - 1][l + 1]] += j * pow_mod(i + j + l + 1, MOD - 2) % MOD) %= MOD;
                        else 
                            (A[0].a[cnt[i][j][l]][cnt[i + 1][j - 1][l]] += j * pow_mod(i + j + l + 1, MOD - 2) % MOD) %= MOD;
                    }
                    if (l > 0)
                    {
                        if (i + j + l + 1 <= k) 
                            (A[0].a[cnt[i][j][l]][cnt[i][j + 1][l]] += l * pow_mod(i + j + l + 1, MOD - 2) % MOD) %= MOD;
                        else 
                            (A[0].a[cnt[i][j][l]][cnt[i][j + 1][l - 1]] += l * pow_mod(i + j + l + 1, MOD - 2) % MOD) %= MOD;
                    }
                    (A[0].a[cnt[i][j][l]][cnt[i][j][l]] += pow_mod(i + j + l + 1, MOD - 2)) %= MOD;
                    (A[0].a[cnt[i][j][l]][IDans] += pow_mod(i + j + l + 1, MOD - 2)) %= MOD;
                }
        A[0].a[IDans][IDans] = 1;
    }
    A[0].n = A[0].m = Index;
    for (int i = 1; i <= 63; i++)
        A[i] = A[i - 1] * A[i - 1];
    while (T--)
    {
        Matrix B(1, Index);
        B.a[1][cnt[m == 1][m == 2][m == 3]] = 1;
        long long n = read();
        for (int i = 0; i <= 63; i++)
            if (n & (1ll << i))
                B = B * A[i];
        printf ("%d\n", B.a[1][IDans]);
    }
}
本文作者 : NekoMio
知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。
本文链接 : https://www.nekomio.com/2018/04/12/144/
上一篇
下一篇