首页 > 题解 > bzoj 4566 [Haoi2016]找相同字符

bzoj 4566 [Haoi2016]找相同字符

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。

Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

Output

输出一个整数表示答案

Sample Input

aabb

bbaa

Sample Output

10

题解

对一个串A建自动机,另一个串B在上面匹配

考虑怎么统计答案

对于当前匹配到的点,那么它parent树中的祖先代表的串现在肯定也出现了

对于每个出现的点,它代表了maxs[x]-mins[x]+1个串(mins[x]=maxs[fa[x]]+1)

这些串出现了|right(x)|次,贡献就是|right(x)|*(maxs[x]-maxs[fa[x]])

因为要统计祖先的,所以把祖先的也累加到这个节点即可

#include <cstdio>
#include <cstring>
#define N 200010
using namespace std;
int dis[N<<1],fa[N<<1],ch[N<<1][26],sum[N<<1],siz[N<<1],sz=1,root=1,last=root,len;
char s[N],s1[N];
void insert(int x)
{
    int now=++sz,pre=last;last=now;
    dis[now]=dis[pre]+1;siz[now]=1;
    for (;pre && !ch[pre][x];pre=fa[pre]) ch[pre][x]=now;
    if (!pre) fa[now]=root;
    else if (dis[ch[pre][x]]==dis[pre]+1) fa[now]=ch[pre][x];
    else
    {
        int q=ch[pre][x],nows=++sz;dis[nows]=dis[pre]+1;
        memcpy(ch[nows],ch[q],sizeof ch[nows]);
        fa[nows]=fa[q],fa[q]=fa[now]=nows;
        for (;pre && ch[pre][x]==q;pre=fa[pre]) ch[pre][x]=nows;
    }
}
int ton[N<<1],ti[N<<1];
void work()
{
    for (int i=1;i<=sz;i++) ton[dis[i]]++;
    for (int i=1;i<=len;i++) ton[i]+=ton[i-1];
    for (int i=1;i<=sz;i++) ti[ton[dis[i]]--]=i;
    for (int i=sz;i;i--) siz[fa[ti[i]]]+=siz[ti[i]];
    for (int i=1;i<=sz;i++)
        sum[ti[i]]=sum[fa[ti[i]]]+siz[ti[i]]*(dis[ti[i]]-dis[fa[ti[i]]]);
}
void ask(char s[])
{
    long long ans=0,l=0,le=strlen(s+1);
    for (int i=1,now=root;i<=le;i++)
    {
        if (ch[now][s[i]-'a'])
            now=ch[now][s[i]-'a'],l++;
        else
        {
            for (;now && !ch[now][s[i]-'a'];now=fa[now]);
            if (!now) now=root,l=0;
            else l=dis[now]+1,now=ch[now][s[i]-'a'];
        }
        if (now!=root) ans+=sum[fa[now]]+siz[now]*(l-dis[fa[now]])*1ll;
    }
    printf("%lld\n",ans);
}
main()
{
    scanf("%s",s+1),len=strlen(s+1);
    for (int i=1;i<=len;i++) insert(s[i]-'a');
    scanf("%s",s1+1);
    work();
    ask(s1);
}