题目大意
给定一个只包括 $a$ 和 $b$ 的字符串,求满足以下条件的子序列的数量:
- 位置和字符都关于某条轴对称。
- 不是连续的一段。
题解
首先,如果某个子序列是连续的一段,那么它其实就是原字符串的一个子回文串。我们可以使用 Manacher 算法首先计算出所有的回文串数量,最后再减去即可。这样我们就去除了第二条限制。
考虑第一条限制:假如我们关于某条轴有 $a$ 对字符对称,那么实际能组成的子序列就有 $2^a-1$ 种。(每对字符可以选可以不选,并且要减去空序列)
那么如何快速计算关于某条轴对称的字符对数呢?发现字符集为 $2$,那么我们可以直接对每个字符单独考虑:
设 $d_i$ 为第 $i$ 个位置是否为 $a$,如果我们假设对称轴为 $k$($k$ 可以为一个分母为 $2$ 的分数,即对称轴在两个字符中间),那么以 $k$ 为对称轴的字符对数就是:
\[\sum_{\frac{i+j}{2}=k}d_id_j\]发现这个东西比较像卷积的形式。这个分数看起来很不爽,但是可以考虑我们在进行 Manacher 的时候,已经对字符串进行了补位,这样我们可以设 $k$ 为补位后的字符串的对称轴,这样我们实际要求的答案就是:
\[\sum_{k=1}^{2n+1}\sum_{i+j=k}d_id_j\]化为卷积的形式:
\[\sum_{k=1}^{2n+1}\sum_{i=1}^kd_id_{k-i}\]使用 FFT 或者 NTT 进行优化即可。
注意这个式子会将不在同一位置的两个字母重复计算一次,但是在同一位置的字母只会计算一次,所以我们要除以 $2$ 并向上取整(向上取整的原因是在计算回文串时已经计算了一个字母的数量了,如果这里向下取整就会少计算一个字母的数量,算出来的结果就会多减一些数。)
字符为 $b$ 的情况是一样的,最后将 $a$ 和 $b$ 的数量都加起来一起除以 $2$ 向上取整就是关于这个对称轴对称的字符对数。
最后减去回文串的数量就可以了。
复杂度 $O(n\log n)$。
代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400105;
int n;
char ch[MAXN];
char ch2[MAXN];
int r[MAXN];
void manacher() {
for (int i = 1, m = 0, c; i <= 2 * n + 1; i++) {
r[i] = i > m ? 1 : min(m - i, r[2 * c - i]);
while (ch2[i + r[i]] == ch2[i - r[i]]) r[i]++;
if (i + r[i] > m) {
m = i + r[i];
c = i;
}
}
}
const int P = 998244353, G = 3, P2 = 1000000007;
int qpow(int a, int b, int P = ::P) {
int ans = 1;
while (b) {
if (b & 1) ans = 1ll * ans * a % P;
a = 1ll * a * a % P;
b >>= 1;
}
return ans;
}
const int GI = qpow(G, P - 2);
int R[MAXN];
struct Polynomial {
vector<int> a;
int len;
int& operator[](int b) { return a[b]; }
Polynomial(int len = 0) : len(len) { a.resize(len + 1); }
void set(int len) { this->len = len, a.resize(len + 1); }
void ntt(int limit, bool rev) {
set(limit);
for (int i = 0; i < limit; i++) if (i < R[i]) swap(a[i], a[R[i]]);
for (int mid = 1; mid < limit; mid <<= 1) {
int step = qpow(rev ? GI : G, (P - 1) / (mid << 1));
for (int l = 0; l < limit; l += (mid << 1)) {
int w = 1;
for (int i = 0; i < mid; i++, w = 1ll * w * step % P) {
int x = a[l + i], y = 1ll * w * a[l + i + mid] % P;
a[l + i] = (x + y) % P, a[l + i + mid] = (x - y + P) % P;
}
}
}
if (rev) {
int nrev = qpow(limit, P - 2);
for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * nrev % P;
}
}
Polynomial operator*(Polynomial b) {
Polynomial a = *this, c; int n = a.len + b.len;
int limit = 1; while (limit <= n) limit <<= 1; c.set(limit);
for (int i = 0; i < limit; i++)
R[i] = (R[i >> 1] >> 1) | ((i & 1) * limit >> 1);
a.ntt(limit, false), b.ntt(limit, false);
for (int i = 0; i < limit; i++) c[i] = 1ll * a[i] * b[i] % P;
c.ntt(limit, true);
c.set(n);
return c;
}
void print() {
for (int i : a) printf("%d ", i);
printf("\n");
}
}a, b;
long long ans = 0;
int main() {
scanf("%s", ch + 1);
n = strlen(ch + 1);
ch2[0] = '?';
for (int i = 1; i <= 2 * n + 1; i++) ch2[i] = (i & 1) ? '#' : ch[i >> 1];
manacher();
a.set(n), b.set(n);
for (int i = 1; i <= n; i++) {
if (ch[i] == 'a') {
a[i]++;
} else {
b[i]++;
}
}
a = a * a, b = b * b;
for (int i = 1; i <= 2 * n + 1; i++)
ans = (1ll * ans + qpow(2, ((a[i] + b[i] + 1) / 2), P2) - r[i] / 2 - 1 + P2) % P2;
printf("%lld\n", ans);
return 0;
}