title: Four Loop tags:
Given two sequences $a$ and $b$ of equal length $n$, find the value of
$$ \sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)} $$
给定两个长度为$n$的序列$a$和$b$,求$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$。
数据范围:$1 \le n \le 10^5, \; 1 \le a_i \le 500, \; 1 \le b_i \le 500$。
$a_x+a_y+a_z+a_w$的取值范围为$[4,2000]$,$b_x \oplus b_y \oplus b_z \oplus b_w$的取值范围为$[0,511]$,可以分别计算每种情况出现的次数再求和。
令$f_{1,i,j}$表示有多少个$x$满足$a_x=i \land b_x=j$。对于$k \ge 2$,令
$$ f_{k,i,j}=\sum_{i_1+i_2=i} \sum_{j_1 \oplus j_2=j} f_{k-1,i_1,j_1}f_{1,i_2,j_2} $$
$f_{4,i,j}$即要求的每种情况出现的次数。转移方程在一维上是加法卷积(即多项式乘法),在另一维上是异或卷积,可以分别用 FFT 和 FWT 进行处理。总时间复杂度为$O(a_{\max}b_{\max}\log(a_{\max}b_{\max}))$。
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Size limits and modular-arithmetic constants shared by the whole program.
const int N = 100005;        // capacity of the 1-based input arrays a[] and b[]
const int M = 512;           // extent of the XOR dimension; the sum dimension uses M << 2
const int MOD = 998244353;   // modulus applied to every arithmetic result
const int G = 3;             // base that init() raises to a power to obtain the root of unity
const int INV2 = 499122177;  // multiplicative inverse of 2 modulo MOD (2 * INV2 % MOD == 1)

int a[N];         // first input sequence
int b[N];         // second input sequence
int rev[M << 2];  // bit-reversal permutation, filled by init() and read by fft()
// Computes a^b modulo MOD by binary exponentiation (square-and-multiply).
// The exponent is consumed bit by bit from the least significant end; every
// caller in this file passes non-negative arguments.
int power(int a, int b) {
    long long base = a;
    long long result = 1;
    while (b != 0) {
        if ((b & 1) != 0) {
            result = result * base % MOD;
        }
        base = base * base % MOD;
        b >>= 1;
    }
    return static_cast<int>(result);
}
int wn[M << 2];    // powers of the root of unity, filled by init() and read by fft()
int A[M << 2];     // scratch buffer holding one line of f while it is being transformed
int f[M << 2][M];  // f[i][j]: table indexed by (sum value i, XOR value j)
// Prepares the lookup tables used by fft() for the smallest power-of-two
// length that is >= n:
//   rev[] - the bit-reversal permutation of the indices,
//   wn[]  - successive powers of the root of unity, for the first half of the length.
// rev[0] is never written here; it keeps its zero value from static initialization.
void init(int n) {
    int size = 1;
    int bits = 0;
    while (size < n) {
        size <<= 1;
        ++bits;
    }

    // rev[i] is rev[i / 2] shifted down by one bit, with the lowest bit of i
    // moved to the highest position.
    for (int i = 1; i < size; ++i) {
        const int top = (i & 1) << (bits - 1);
        rev[i] = (rev[i >> 1] >> 1) | top;
    }

    wn[0] = 1;
    wn[1] = power(G, (MOD - 1) / size);
    const int half = size >> 1;
    for (int i = 2; i < half; ++i) {
        wn[i] = 1ll * wn[i - 1] * wn[1] % MOD;
    }
}
// In-place number-theoretic transform of a[0..len) modulo MOD.
//
// Reads the tables rev[] and wn[] as prepared by init(); they only match this
// transform when init() was run for the same length.
//
// a   - data to transform, overwritten with the result
// len - transform length
// inv - when set, produces the inverse transform: the forward result is
//       reversed over indices 1..len-1 and every entry is multiplied by the
//       modular inverse of len
void fft(int *a, int len, bool inv = 0) {
    // Reorder the input into bit-reversed order, swapping each pair once.
    for (int i = 0; i < len; ++i) {
        const int j = rev[i];
        if (i < j) {
            std::swap(a[i], a[j]);
        }
    }

    // Butterfly passes over blocks of doubling width.
    for (int half = 1; half < len; half <<= 1) {
        const int width = half << 1;
        const int stride = len / width;  // step through wn[] for this pass
        for (int start = 0; start < len; start += width) {
            int *lo = a + start;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const int u = lo[k];
                const int v = 1ll * wn[stride * k] * hi[k] % MOD;
                lo[k] = (u + v) % MOD;
                hi[k] = (MOD + u - v) % MOD;
            }
        }
    }

    if (!inv) {
        return;
    }
    std::reverse(a + 1, a + len);
    const int scale = power(len, MOD - 2);
    for (int i = 0; i < len; ++i) {
        a[i] = 1ll * a[i] * scale % MOD;
    }
}
// In-place Walsh-Hadamard (XOR) transform of a[0..len) modulo MOD.
//
// a   - data to transform, overwritten with the result
// len - transform length
// inv - when set, every butterfly output is additionally multiplied by INV2,
//       which yields the inverse transform
void fwt(int *a, int len, bool inv = 0) {
    for (int half = 1; half < len; half <<= 1) {
        const int width = half << 1;
        for (int start = 0; start < len; start += width) {
            for (int k = 0; k < half; ++k) {
                int &lo = a[start + k];
                int &hi = a[start + half + k];
                int sum = (lo + hi) % MOD;
                int diff = (MOD + lo - hi) % MOD;
                if (inv) {
                    sum = 1ll * sum * INV2 % MOD;
                    diff = 1ll * diff * INV2 % MOD;
                }
                lo = sum;
                hi = diff;
            }
        }
    }
}
int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n; ++i) {
scanf("%d", &b[i]);
f[a[i]][b[i]] = f[a[i]][b[i]] + 1;
}
init(M << 2);
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
A[j] = f[i][j];
}
fwt(A, M);
for (int j = 0; j < M; ++j) {
f[i][j] = A[j];
}
}
for (int i = 0; i < M; ++i) {
for (int j = 0; j < M << 2; ++j) {
A[j] = f[j][i];
}
fft(A, M << 2);
for (int j = 0; j < M << 2; ++j) {
f[j][i] = A[j];
}
}
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
f[i][j] = power(f[i][j], 4);
}
}
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
A[j] = f[i][j];
}
fwt(A, M, 1);
for (int j = 0; j < M; ++j) {
f[i][j] = A[j];
}
}
for (int i = 0; i < M; ++i) {
for (int j = 0; j < M << 2; ++j) {
A[j] = f[j][i];
}
fft(A, M << 2, 1);
for (int j = 0; j < M << 2; ++j) {
f[j][i] = A[j];
}
}
int ans = 0;
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
if (f[i][j]) {
ans = (ans + 1ll * f[i][j] * power(i, j)) % MOD;
}
}
}
printf("%d\n", ans);
return 0;
}