首页 > 题解 > bzoj 2946 [Poi2000]公共串

bzoj 2946 [Poi2000]公共串

Description

给出几个由小写字母构成的单词,求它们最长的公共子串的长度。
任务:
l 读入单词
l 计算最长公共子串的长度
l 输出结果

Input

文件的第一行是整数 n,1<=n<=5,表示单词的数量。接下来n行每行一个单词,只由小写字母组成,单词的长度至少为1,最大为2000。

Output

仅一行,一个整数,最长公共子串的长度。

Sample Input

3

abcb

bca

acbc

Sample Output

题解

首先对其中的一个串建出SAM。拓扑排序一下。

每个点记录每个串最长匹配长度的最小值,最后找到所有点中最长的一个就可以了

注意有个细节。总结答案的时候,要更新一下Parent树上的节点的答案。因为匹配的时候可能是直接跳到这个节点上的,Parent有可能是没有更新的。所以要从拓扑序最大的那个开始,往前走,边走边更新答案。

那么为什么上个题不用这么更新?因为Parent树上的那个答案肯定比这个答案要小,上个题只要求最大的答案,这题因为有很多串,要全面考虑。

为什么只用更新一个Parent?不应该跳Parent树直到根?因为按照拓扑序走的,下次走到这个点的Parent的时候就会更新它的Parent了。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 2010
using namespace std;
int dis[N<<1],fa[N<<1],ch[N<<1][26],ans[N<<1],maxs[N<<1],last=1,sz=1,root=1,all,len,n;
char s[N];
void insert(int x)
{
    int now=++sz,pre=last;last=now;
    dis[now]=dis[pre]+1;
    for (;pre && !ch[pre][x];pre=fa[pre]) ch[pre][x]=now;
    if (!pre) fa[now]=root;
    else
    {
        int q=ch[pre][x];
        if (dis[q]==dis[pre]+1)
            fa[now]=q;
        else
        {
            int nows=++sz;dis[nows]=dis[pre]+1;
            memcpy(ch[nows],ch[q],sizeof ch[q]);
            fa[nows]=fa[q],fa[q]=fa[now]=nows;
            for (;pre && ch[pre][x]==q;pre=fa[pre]) ch[pre][x]=nows;
        }
    }
}
int ton[N<<1],ti[N<<1];
void topsort()
{
    for (int i=1;i<=sz;i++) ton[dis[i]]++;
    for (int i=1;i<=len;i++) ton[i]+=ton[i-1];
    for (int i=1;i<=sz;i++) ti[ton[dis[i]]--]=i;
}
void work(char s[])
{
    memset(maxs,0,sizeof maxs);
    int l=strlen(s+1),now=root,le=0;
    for (int i=1;i<=l;maxs[now]=max(maxs[now],le),i++)
        if (ch[now][s[i]-'a'])
            now=ch[now][s[i]-'a'],le++;
        else
        {
            for (;now && !ch[now][s[i]-'a'];now=fa[now]);
            if (!now) now=root,le=0;
            else le=dis[now]+1,now=ch[now][s[i]-'a'];
        }
    for (int i=sz;i;i--)
    {
        ans[ti[i]]=min(ans[ti[i]],maxs[ti[i]]);
        if (fa[ti[i]] && maxs[ti[i]]) maxs[fa[ti[i]]]=dis[fa[ti[i]]];
    }
}
main()
{
    scanf("%d",&n),n--;
    scanf("%s",s+1),len=strlen(s+1);
    for (int i=1;i<=len;i++) insert(s[i]-'a');
    for (int i=1;i<=sz;i++) ans[i]=dis[i];
    topsort();
    for (int i=1;i<=n;i++) scanf("%s",s+1),work(s);
    for (int i=1;i<=sz;i++) all=max(all,ans[i]);
    printf("%d\n",all);
}