BZOJ2510 弱题


题目大意

\(m\) 个球,每个球一开始有一个初始编号,编号为 \(i\) 的球有 \(a_i\) 个。

每次操作等概率 \(\dfrac{1}{m}\) 取出一个球,若这个球标号为 \(k(k < n)\),则将其变为 \(k + 1\),如果这个球标号为 \(n\),则将其标号为 \(1\),然后放回去。

\(k\) 次操作后,每个标号的球的期望个数。

\(\text{Data Range:} n \leq 1000,m\leq 10^8, k \leq 2147483647\)


发现 \(k\) 很大,朴素的方法过不了,那么考虑矩阵快速幂。

首先写朴素的暴力 dp,这个东西很好写,设 \(a_{t,i}\) 表示时间为 \(t\) 的时候编号为 \(i\) 的球的期望个数。

\[a_{i,t} = \dfrac{1}{M} \times a_{i-1,t-1} + (1 - \dfrac{1}{M}) \times a_{i-1,t} \]

对于 \(i = n\) 的时候需要特判。

然后我们对这个东西做一个矩阵快速幂就好了,复杂度 \(n^3 \log k\),炸了。

学会动脑子,你可以写一个矩阵出来,然后你发现他是形如这样的矩阵。

![](file://C:/Users/stu/Documents/Gridea/post-images/1652866791756.PNG)

这种上下两行之间矩阵进行平移的矩阵,我们称其为,循环矩阵。

循环矩阵因为我们只用知道他的第一行就可以推测出他下面的所有行。

于是我们可以只对最上面一行进行矩阵快速幂,然后由他的系数推测出其他行的系数矩阵。

然后如何写循环矩阵呢,手动模拟。

![](file://C:/Users/stu/Documents/Gridea/post-images/1652874719976.PNG)

此时的 \(a_{1,1} = 1 \times 1 +2 \times 3 + 3 \times 3\)

此时的 \(a_{1,2} = 1\times 2 + 2 \times 1 + 3 \times 2\)

此时的 \(a_{1,3} = 1 \times 3 + 2\times 2 + 3 \times 1\)

然后继续往下计算,发现这又是一个循环矩阵qq_emoji: oh

那么只用计算第一行与另一个矩阵相乘。

但是仔细想想,你发现其实因为相乘的矩阵是一个循环矩阵,于是他内部的各个位置都可以由第一行的某个位置表示出来。

经过合理的打表你发现,\(c_i=\sum_{j=1}^n a_j \times b_{((i - j + n) \bmod n )+ 1}\)

然后对这题仍然适用,总复杂度 \(n^2 \log k\)

// 德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱
// 德丽莎的可爱在于德丽莎很可爱,德丽莎为什么很可爱呢,这是因为德丽莎很可爱!
#include 
#define int long long
using namespace std;
inline int read() {
  int x = 0, f = 1;  char ch = getchar();
  while( !isdigit(ch) ) { if(ch == '-') f = -1;  ch = getchar();  }
  while( isdigit(ch) ) {  x = (x << 1) + (x << 3) + (ch ^ 48);  ch = getchar();  }
  return x * f;
}
const int N = 2005;
int n, m, k;
struct mt {
  double f[N];
  mt() {memset(f, 0, sizeof(f)); }
  mt operator * (mt const &x) {
    mt res;
    for (int i = 1; i <= n; i++) {
      for (int j = 1; j <= n; j++) {
        res.f[i] += f[j] * x.f[(i - j + n) % n + 1];
      }
    }
    return res;
  }
}base, ans;
mt ksm(int k,mt a) {
  while (k) { if (k & 1) ans = ans * a; a = a * a; k >>= 1; } 
  return ans;
}
signed main () {
  n = read(), m = read(), k = read();
  for (int i = 1; i <= n; i++) cin >> ans.f[i];
  double A = (double)1.0 * (m - 1) / m;
  double B = (double) 1.0 / m;
  base.f[1] = A; base.f[2] = B;
  mt ans2 = ksm(k, base);
  for (int i = 1; i <= n; i++) printf("%.3lf\n", ans.f[i]);
  return 0;
}