[Codeforces]662C - Binary Table(FWT)


一些套路的整合题,是一个好题。
题意:
给定一个\(n\times m\)的01矩阵,每次可以选择一行或者一列进行取反,问任意进行操作后,矩阵中剩下的1最少有几个。
\(n\le 20, m\le 10^5\)

先进行一下转化,首先注意到\(n\)是很小的,有一个贪心策略是,确定了行的取反状态后,列的取反方案其实确定了,每一列,假如取反后1比较少就取反。
当每一列和行反转状态用二进制数表达之后,令行的翻转状态为\(x\),答案就变成

\[\sum_{i = 1}^{m}f(a_i \oplus x) \]

其中

\[f(x) = min(popcount(x), n - popcount(x)) \]

怎么继续优化?
利用FWT常用的一个套路,\(a\oplus b=c\)推出\(a\oplus c=b\)

\[a_i\oplus x=j \]

\[a_i\oplus j=x \]

这时我们已经把等式左边的一个变量凑到外面去了。
这个\(j\)其实是一个任意数,\(f(j)\)对答案的贡献其实就跟\(a_i\)的数量有关。
枚举\(a_i\)的值可以得到

\[f(x) = \sum_{i\oplus j = x}f(j)\times cnt(i) \]

这玩意就可以进行异或卷积了。

#include 
#define pt(x) cout << x << endl;
#define Mid ((l + r) / 2)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
int read() {
	char c; int num, f = 1;
	while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
	while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
	return f * num;
}
const int N = (1 << 21) + 1009;
const int M = 2e5 + 1009;
const int mod = 998244353;
int f[N], cnt[N], n, m, g[29][M];
int Pow(int a, int p) {
	int ans = 1;
	for( ; p; p >>= 1, a = 1ll * a * a % mod)
		if(p & 1)
			ans = 1ll * ans * a % mod;
	return ans % mod;
}
void FWT_xor(int *A, int n, int type) {
	int inv_2 = Pow(2, mod - 2);
	for(int m = 1; m < n; m <<= 1) {
		for(int i = 0; i < n; i += 2 * m) {
			for(int j = 0; j < m; j++) {
				int x = A[i + j], y = A[i + j + m];
				A[i + j] = (1ll * x + y) * (type == 1 ? 1 : inv_2) % mod;
				A[i + j + m] = (1ll * x - y + mod) * (type == 1 ? 1 : inv_2) % mod;
			}
		}
	}
}
signed main()
{
	n = read(); m = read();
	for(int i = 1; i <= n; i++) 
		for(int j = 1; j <= m; j++) 
			scanf("%1d", &g[i][j]);
	for(int j = 1; j <= m; j++) {
		int a = 0;
		for(int i = 1; i <= n; i++) 
			a = a * 2 + g[i][j];
		cnt[a]++;
	}
	int lim = 1 << n;
	for(int i = 0; i < lim; i++) {
		f[i] = min(__builtin_popcount(i), n - __builtin_popcount(i));
	}
	FWT_xor(f, lim, 1);
	FWT_xor(cnt, lim, 1);
	for(int i = 0; i < lim; i++) f[i] = 1ll * f[i] * cnt[i] % mod;
	FWT_xor(f, lim, -1);
	int ans = 0x3f3f3f3f;
	for(int i = 0; i < (1 << n); i++) {
		ans = min(ans, f[i]);
	}
	printf("%d\n", ans);
	return 0;
}