title: The Child and Binary Tree tags:
Our child likes computer science very much, especially he likes binary trees.
Consider the sequence of $n$ distinct positive integers: $c_1, c_2, \ldots, c_n$. The child calls a vertex-weighted rooted binary tree good if and only if for every vertex $v$, the weight of $v$ is in the set $\{c_1, c_2, \ldots, c_n\}$. Also our child thinks that the weight of a vertex-weighted tree is the sum of all vertices' weights.
Given an integer $m$, can you for all $s \ (1 \le s \le m)$ calculate the number of good vertex-weighted rooted binary trees with weight $s$? Please, check the samples for better understanding what trees are considered different.
We only want to know the answer modulo $998244353$ ($7 \times 17 \times 2^{23}+1$, a prime number).
有$n$种点,每种点有无限个,第$i$种点的权值为$c_i$。定义一棵二叉树的权值等于它所有点的权值之和。求对于所有$s \in [1, m]$,权值为$s$的二叉树有几棵。两棵二叉树不同当且仅当它们左子树或右子树不同,或者根节点权值不同。
数据范围:$1 \le n, m, c_i \le 10^5$。
令$f(x)$表示权值为$x$的二叉树个数,$F(x)$为其生成函数($F(x)=\sum_{i \ge 0} f(i)x^i$)。
令$C(x)$为给定$c$的集合的生成函数($C(x)=\sum_{i=1}^n x^{c_i}$)。
根据 DP 转移方程,易知
$$ f(x)=\sum_{w \in \{c_1, c_2, \ldots, c_n\}} \sum_{i=0}^{x-w} f(i)f(x-w-i) \quad (x \ge 1), \qquad f(0)=1 $$
即
$$ F(x)=C(x)F(x)^2+1 $$
解得
$$ F(x)={1 \pm \sqrt{1-4C(x)} \over 2C(x)}={2 \over 1 \mp \sqrt{1-4C(x)}} $$
显然,若分母取减号(即$F(x)={2 \over 1-\sqrt{1-4C(x)}}$),则当$x$趋近$0$时分母为$0$,因此分母只能取加号,即$F(x)={2 \over 1+\sqrt{1-4C(x)}}$。接着就是多项式开根和多项式求逆了。
多项式求逆:给定多项式$F$,求$G$使得$GF \equiv 1 \pmod {x^n}$。假设已知$G_0F \equiv 1 \pmod {x^{\lceil n/2 \rceil}}$
$G-G_0 \equiv 0 \pmod {x^{\lceil n/2 \rceil}}$
$G^2-2GG_0+G_0^2 \equiv 0 \pmod {x^n}$
$G-2G_0+G_0^2F \equiv 0 \pmod {x^n}$
$G \equiv 2G_0-G_0^2F \pmod {x^n}$
多项式开根(此处$F$、$G$重新表示被开根的多项式及其平方根):给定多项式$F$,求$G$使得$G^2 \equiv F \pmod {x^n}$。假设已知$G_0^2 \equiv F \pmod {x^{\lceil n/2 \rceil}}$
$(G_0^2-F)^2 \equiv 0 \pmod {x^n}$
$(G_0^2+F)^2 \equiv 4G_0^2F \pmod {x^n}$
$\left({G_0^2+F \over 2G_0}\right)^2 \equiv F \pmod {x^n}$
$G \equiv {G_0+G_0^{-1}F \over 2} \pmod {x^n}$
#include <algorithm>
#include <cstdio>
#include <cstring>
// Capacity of every global work array below; it has to be at least the largest
// power-of-two transform length that init() produces.
static const int N = 500000;
// Prime modulus for all arithmetic: 998244353 = 7 * 17 * 2^23 + 1.
static const int MOD = 998244353;
// Base whose powers supply the roots of unity in ntt().
static const int G = 3;
// Multiplicative inverse of 2 modulo MOD (2 * 499122177 = MOD + 1).
static const int INV2 = 499122177;
// n, m : values read from the input by main().
// c    : input weights; main() later reuses it as the output of get_sqrt().
// C    : coefficient array built in main(); later reused as the output of get_inv().
// rev  : bit-reversal permutation table filled by init().
// wn   : table of root-of-unity powers filled by ntt().
// tmp  : scratch buffer used by get_inv().
// tmp2, tmp3 : scratch buffers used by get_sqrt().
int n, m, c[N], C[N], rev[N], wn[N], tmp[N], tmp2[N], tmp3[N];
// Returns a^b modulo MOD using binary exponentiation.
//
// The base is reduced modulo MOD and the exponent modulo MOD - 1 before the
// loop starts (the exponent reduction relies on MOD being prime and `a` not
// being a multiple of MOD). The exponent is expected to be non-negative.
int power(int a, int b) {
  long long base = a % MOD;
  int exponent = b % (MOD - 1);
  long long result = 1;
  while (exponent) {
    if (exponent & 1) {
      result = result * base % MOD;
    }
    base = base * base % MOD;
    exponent >>= 1;
  }
  return static_cast<int>(result);
}
// Prepares a transform of sufficient size for multiplying polynomials of
// length n.
//
// On return, n has been replaced by the smallest power of two that is at least
// twice its incoming value, and rev[1..n) holds the bit-reversal permutation
// for that length. rev[0] is never written here; it relies on the global array
// being zero-initialised.
void init(int &n) {
  const int required = n << 1;
  int bits = 0;
  n = 1;
  while (n < required) {
    n <<= 1;
    ++bits;
  }
  for (int i = 1; i < n; ++i) {
    // Reverse of i = reverse of (i / 2) shifted down, plus the lowest bit of i
    // moved to the top position.
    const int top_bit = (i & 1) << (bits - 1);
    rev[i] = (rev[i >> 1] >> 1) | top_bit;
  }
}
// In-place number-theoretic transform of a[0..n) modulo MOD.
//
// a   : coefficients, each expected to lie in [0, MOD).
// n   : transform length; the rev table must have been filled by init() for
//       this same length.
// inv : when true, the result is turned into the inverse transform by
//       mirroring a[1..n) and scaling every entry by n^-1.
//
// Overwrites wn[0 .. max(2, n / 2)) with powers of the n-th root of unity.
void ntt(int *a, int n, bool inv) {
  // Bit-reversal permutation.
  for (int i = 0; i < n; ++i) {
    const int j = rev[i];
    if (i < j) {
      std::swap(a[i], a[j]);
    }
  }

  // wn[k] = root^k for k < n / 2.
  const int root = power(G, (MOD - 1) / n);
  const int half = n >> 1;
  wn[0] = 1;
  wn[1] = root;
  for (int k = 2; k < half; ++k) {
    wn[k] = 1ll * wn[k - 1] * root % MOD;
  }

  // Iterative butterflies; `len` is half the size of the blocks being merged.
  for (int len = 1; len < n; len <<= 1) {
    const int stride = n / (len << 1);
    for (int base = 0; base < n; base += len << 1) {
      for (int k = 0; k < len; ++k) {
        int &lo = a[base + k];
        int &hi = a[base + k + len];
        const int u = lo;
        const int v = 1ll * wn[stride * k] * hi % MOD;
        lo = (u + v) % MOD;
        hi = (MOD + u - v) % MOD;
      }
    }
  }

  if (!inv) {
    return;
  }

  // Mirroring a[1..n) is the same as swapping a[i] with a[n - i] for every
  // 1 <= i < n / 2.
  std::reverse(a + 1, a + n);
  const int inv_n = power(n, MOD - 2);
  for (int i = 0; i < n; ++i) {
    a[i] = 1ll * a[i] * inv_n % MOD;
  }
}
// Computes g = f^-1 modulo x^n (and modulo MOD) by Newton iteration.
//
// f : input polynomial; coefficients f[0..n) are read.
// g : output buffer; it must have room for the padded transform length chosen
//     by init(). Entries at index >= n are left as zero.
// n : number of coefficients wanted (n >= 1).
//
// Given g0 with g0 * f == 1 (mod x^ceil(n/2)), each level applies
//   g = 2 * g0 - g0^2 * f = g0 * (2 - g0 * f)   (mod x^n)
// pointwise in the transformed domain. Uses the global buffer `tmp` as scratch.
void get_inv(int *f, int *g, int n) {
  if (n == 1) {
    g[0] = power(f[0], MOD - 2);
    return;
  }

  const int len = n;
  const int half = (len + 1) >> 1;
  get_inv(f, g, half);
  init(n);  // From here on n is the padded transform length.

  // Clear everything above the half-length result of the recursion.
  std::fill(g + half, g + n, 0);

  // tmp := f truncated to `len` coefficients and zero-padded to n.
  for (int i = 0; i < n; ++i) {
    tmp[i] = i < len ? f[i] : 0;
  }

  ntt(g, n, false);
  ntt(tmp, n, false);
  for (int i = 0; i < n; ++i) {
    const long long gf = 1ll * g[i] * tmp[i] % MOD;
    g[i] = static_cast<int>(g[i] * (MOD + 2 - gf) % MOD);
  }
  ntt(g, n, true);

  // Drop the coefficients beyond x^(len - 1).
  std::fill(g + len, g + n, 0);
}
// Computes g with g^2 == f modulo x^n (and modulo MOD) by Newton iteration.
//
// f : input polynomial; coefficients f[0..n) are read. The base case sets
//     g[0] = 1, so the constant term of f is expected to be 1.
// g : output buffer; it must have room for the padded transform length chosen
//     by init(). Entries at index >= n are left as zero.
// n : number of coefficients wanted (n >= 1).
//
// Given g0 with g0^2 == f (mod x^ceil(n/2)), each level applies
//   g = (g0 + g0^-1 * f) / 2   (mod x^n).
// Uses the global buffers `tmp2` and `tmp3` as scratch (and `tmp` via get_inv).
void get_sqrt(int *f, int *g, int n) {
  if (n == 1) {
    g[0] = 1;
    return;
  }

  const int len = n;
  const int half = (len + 1) >> 1;
  get_sqrt(f, g, half);

  // tmp2 := g0 zero-padded to `len` coefficients; tmp3 := g0^-1 mod x^len.
  for (int i = 0; i < len; ++i) {
    tmp2[i] = i < half ? g[i] : 0;
  }
  get_inv(tmp2, tmp3, len);
  init(n);  // From here on n is the padded transform length.

  std::fill(g + half, g + n, 0);

  // tmp2 := f truncated to `len` coefficients and zero-padded to n.
  for (int i = 0; i < n; ++i) {
    tmp2[i] = i < len ? f[i] : 0;
  }

  // tmp3 := g0^-1 * f
  ntt(tmp2, n, false);
  ntt(tmp3, n, false);
  for (int i = 0; i < n; ++i) {
    tmp3[i] = 1ll * tmp3[i] * tmp2[i] % MOD;
  }
  ntt(tmp3, n, true);

  // g := (g0 + g0^-1 * f) / 2, keeping only the first `len` coefficients.
  for (int i = 0; i < len; ++i) {
    g[i] = 1ll * (g[i] + tmp3[i]) * INV2 % MOD;
  }
  std::fill(g + len, g + n, 0);
}
// Reads n, m and the n weights, then prints, for every s in [1, m], the
// coefficient of x^s in F(x) = 2 / (1 + sqrt(1 - 4 C(x))) modulo MOD, one per
// line.
int main() {
  scanf("%d%d", &n, &m);

  // Count each weight that can fit into a tree of weight <= m.
  for (int i = 0; i < n; ++i) {
    scanf("%d", &c[i]);
    if (c[i] <= m) {
      ++C[c[i]];
    }
  }

  // Turn the counts into the coefficients of 1 - 4 C(x).
  for (int i = 1; i <= m; ++i) {
    C[i] = (MOD - 4 * C[i]) % MOD;
  }
  C[0] = 1;

  // c := sqrt(1 - 4 C(x)); the input weights in c are no longer needed.
  get_sqrt(C, c, m + 1);
  // c := 1 + sqrt(1 - 4 C(x))
  ++c[0];
  // C := 1 / (1 + sqrt(1 - 4 C(x)))
  get_inv(c, C, m + 1);

  for (int s = 1; s <= m; ++s) {
    printf("%d\n", (2 * C[s]) % MOD);
  }
  return 0;
}