题目大意

给定一个只包括 $a$ 和 $b$ 的字符串,求满足以下条件的子序列的数量:

  1. 位置和字符都关于某条轴对称。
  2. 不是连续的一段。

题解

首先,如果某个子序列是连续的一段,那么它其实就是原字符串的一个子回文串。我们可以使用 Manacher 算法首先计算出所有的回文串数量,最后再减去即可。这样我们就去除了第二条限制。

考虑第一条限制:假如我们关于某条轴有 $a$ 对字符对称,那么实际能组成的子序列就有 $2^a-1$ 种。(每对字符可以选可以不选,并且要减去空序列)

那么如何快速计算关于某条轴对称的字符对数呢?发现字符集为 $2$,那么我们可以直接对每个字符单独考虑:

设 $d_i$ 为第 $i$ 个位置是否为 $a$,如果我们假设对称轴为 $k$($k$ 可以为一个分母为 $2$ 的分数,即对称轴在两个字符中间),那么以 $k$ 为对称轴的字符对数就是:

\[\sum_{\frac{i+j}{2}=k}d_id_j\]

发现这个东西比较像卷积的形式。这个分数看起来很不爽,但是可以考虑我们在进行 Manacher 的时候,已经对字符串进行了补位,这样我们可以设 $k$ 为补位后的字符串的对称轴,这样我们实际要求的答案就是:

\[\sum_{k=1}^{2n+1}\sum_{i+j=k}d_id_j\]

化为卷积的形式:

\[\sum_{k=1}^{2n+1}\sum_{i=1}^kd_id_{k-i}\]

使用 FFT 或者 NTT 进行优化即可。

注意这个式子会将不在同一位置的两个字母重复计算一次,但是在同一位置的字母只会计算一次,所以我们要除以 $2$ 并向上取整(向上取整的原因是在计算回文串时已经计算了一个字母的数量了,如果这里向下取整就会少计算一个字母的数量,算出来的结果就会多减一些数。)

字符为 $b$ 的情况是一样的,最后将 $a$ 和 $b$ 的数量都加起来一起除以 $2$ 向上取整就是关于这个对称轴对称的字符对数。

最后减去回文串的数量就可以了。

复杂度 $O(n\log n)$。

代码

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400105;
int n;
char ch[MAXN];
char ch2[MAXN];
int r[MAXN];
void manacher() {
    for (int i = 1, m = 0, c; i <= 2 * n + 1; i++) {
        r[i] = i > m ? 1 : min(m - i, r[2 * c - i]);
        while (ch2[i + r[i]] == ch2[i - r[i]]) r[i]++;
        if (i + r[i] > m) {
            m = i + r[i];
            c = i;
        }
    }
}
const int P = 998244353, G = 3, P2 = 1000000007;
int qpow(int a, int b, int P = ::P) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
const int GI = qpow(G, P - 2);
int R[MAXN];
struct Polynomial {
    vector<int> a;
    int len;
    int& operator[](int b) { return a[b]; }
    Polynomial(int len = 0) : len(len) { a.resize(len + 1); }
    void set(int len) { this->len = len, a.resize(len + 1); }
    void ntt(int limit, bool rev) {
        set(limit);
        for (int i = 0; i < limit; i++) if (i < R[i]) swap(a[i], a[R[i]]);
        for (int mid = 1; mid < limit; mid <<= 1) {
            int step = qpow(rev ? GI : G, (P - 1) / (mid << 1));
            for (int l = 0; l < limit; l += (mid << 1)) {
                int w = 1;
                for (int i = 0; i < mid; i++, w = 1ll * w * step % P) {
                    int x = a[l + i], y = 1ll * w * a[l + i + mid] % P;
                    a[l + i] = (x + y) % P, a[l + i + mid] = (x - y + P) % P;
                }
            }
        }
        if (rev) {
            int nrev = qpow(limit, P - 2);
            for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * nrev % P;
        }
    }
    Polynomial operator*(Polynomial b) {
        Polynomial a = *this, c; int n = a.len + b.len;
        int limit = 1; while (limit <= n) limit <<= 1; c.set(limit);
        for (int i = 0; i < limit; i++) 
            R[i] = (R[i >> 1] >> 1) | ((i & 1) * limit >> 1);
        a.ntt(limit, false), b.ntt(limit, false);
        for (int i = 0; i < limit; i++) c[i] = 1ll * a[i] * b[i] % P;
        c.ntt(limit, true);
        c.set(n);
        return c;
    }
    void print() {
        for (int i : a) printf("%d ", i);
        printf("\n");
    }
}a, b;
long long ans = 0;
int main() {
    scanf("%s", ch + 1);
    n = strlen(ch + 1);
    ch2[0] = '?';
    for (int i = 1; i <= 2 * n + 1; i++) ch2[i] = (i & 1) ? '#' : ch[i >> 1];
    manacher();
    a.set(n), b.set(n);
    for (int i = 1; i <= n; i++) {
        if (ch[i] == 'a') {
            a[i]++;
        } else {
            b[i]++;
        }
    }
    a = a * a, b = b * b;
    for (int i = 1; i <= 2 * n + 1; i++) 
        ans = (1ll * ans + qpow(2, ((a[i] + b[i] + 1) / 2), P2) - r[i] / 2 - 1 + P2) % P2;
    printf("%lld\n", ans);
    return 0;
}