首页 > 笔记 > 树剖进阶总结

树剖进阶总结

再来一篇

树链剖分

我们先搞棵树

如果只要子树操作?可以发现子树是DFS序上一段连续区间。。

而且区间的大小就是子树的大小(很明显啊

那么怎么处理路径问题?

我们可以发现dfs序是一些小的路径拼起来的

实际上,每个点第一个被访问的子树就会和它形成一条路径,我们就把这棵树剖成了很多条路径,询问的一条路径就可以分成这些小路径。

我们就是要找一种比较好的方法,让每条询问的路径都分成尽量少的短。

怎么?随机?感觉非常不靠谱。

那就按子树大小加权随机?用不到。直接剖较大的就可以了

一个点的子树最大的儿子叫重儿子,这条边叫做重边。一条由重边组成的链叫做重链。

显然一条路径只会有$O(\log n)$条重边。

然后我们dfs一次求出重儿子,再dfs一遍求出dfs序。这样每个重链在dfs序上就是连续的区间。用数据结构维护下dfs序就可以了。

一般的题用线段树就可以维护了。

例题

luogu3384

就是模版啊。。

#include <cstdio>
#include <algorithm>
#include <queue>
#define M 200010
#define N 200010
#define min(x,y) ((x<y)?(x):(y))
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
using namespace std;
typedef long long LL;
LL inline read()
{
    LL x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
struct node
{
    LL to,next;
}e[M];
LL tot,st[M],deep[N],siz[N],fa[N],son[N],val[N],top[N],v[N],id[N],pos[N];
LL n,m,x,y,z,root,mod,com,ans,ql,qr;
void add(LL x,LL y)
{
    e[++tot].to=y;
    e[tot].next=st[x];
    st[x]=tot;
}
LL dfs_id=0;
void dfs1(LL now,LL pre,LL depth)
{
    LL maxs=-0x3f3f3f3f;
    deep[now]=depth;
    siz[now]=1;
    fa[now]=pre;
    for (LL i=st[now];i;i=e[i].next)
        if (e[i].to!=pre)
        {
            dfs1(e[i].to,now,depth+1);
            siz[now]+=siz[e[i].to];
            if (maxs<siz[e[i].to])
                son[now]=e[i].to,maxs=siz[e[i].to];
        }
}
void dfs2(LL now,LL tops)
{
    top[now]=tops;
    id[++dfs_id]=now;
    pos[now]=dfs_id;
    val[dfs_id]=v[now];
    if (son[now]!=0)
        dfs2(son[now],tops);
    for (LL i=st[now];i;i=e[i].next)
        if (e[i].to!=fa[now]&&e[i].to!=son[now])
            dfs2(e[i].to,e[i].to);
}
LL seg[4*N];
LL mark[4*N];
void pushup(LL rt)
{
    seg[rt]=(seg[rt<<1]+seg[rt<<1|1]);
}
void pushdown(LL rt,LL x)
{
    if (mark[rt])
    {
        mark[rt<<1]=(mark[rt<<1]+mark[rt]);
        mark[rt<<1|1]=(mark[rt<<1|1]+mark[rt]);
        seg[rt<<1]=(seg[rt<<1]+mark[rt]*(x-x/2));
        seg[rt<<1|1]=(seg[rt<<1|1]+mark[rt]*(x/2));
        mark[rt]=0;
    }
}
void build(LL rt,LL l,LL r)
{
    if (l==r)
        seg[rt]=val[l];
    else
    {
        LL mid=(l+r)/2;
        build(lson);
        build(rson);
        pushup(rt);
    }
}
LL query(LL rt,LL l,LL r,LL L,LL R)
{
    if (l>=L && r<=R)
        return seg[rt];
    else
    {
        pushdown(rt,r-l+1);
        LL mid=(r+l)/2,x=0;
        if (mid>=L)
            x+=query(lson,L,R);
        if (mid<R)
            x+=query(rson,L,R);
        return x;
    }
}
void update(LL rt,LL l,LL r,LL L,LL R,LL x)
{
    if (l>=L && r<=R)
        seg[rt]=(seg[rt]+(r-l+1)*x),mark[rt]=(mark[rt]+x);
    else
    {
        pushdown(rt,r-l+1);
        LL mid=(r+l)/2;
        if (mid>=L)
            update(lson,L,R,x);
        if (mid<R)
            update(rson,L,R,x);
        pushup(rt);
    }
}
void add_path(LL x,LL y,LL z)
{
    while(top[x]!=top[y])
    {
        if (deep[top[x]]<deep[top[y]]) swap(x,y);
        update(1,1,n,pos[top[x]],pos[x],z);
        x=fa[top[x]];
    }
    if (deep[x]<deep[y]) swap(x,y);
    update(1,1,n,pos[y],pos[x],z);
}
LL sum_path(LL x,LL y)
{
    LL ans=0;
    while(top[x]!=top[y])
    {
        if (deep[top[x]]<deep[top[y]]) swap(x,y);
        ans=(ans+query(1,1,n,pos[top[x]],pos[x]))%mod;
        x=fa[top[x]];
    }
    if (deep[x]<deep[y]) swap(x,y);
    ans=(ans+query(1,1,n,pos[y],pos[x]))%mod;
    return ans;
}
void add_sub(LL x,LL y)
{
    update(1,1,n,pos[x],pos[x]+siz[x]-1,y);
}
LL sum_sub(LL x)
{
    return query(1,1,n,pos[x],pos[x]+siz[x]-1);
}
main()
{
    scanf("%d%d%d%d",&n,&m,&root,&mod);
    for (LL i=1;i<=n;i++)
        scanf("%d",&v[i]);
    for (LL i=1;i<=n-1;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    deep[0]=-1;
    dfs1(root,0,1);
    dfs2(root,root);
    build(1,1,n);
    for (LL i=1;i<=m;i++)
    {
        ans=0;
        scanf("%d",&com);
        if (com==1)
            x=read(),y=read(),z=read(),add_path(x,y,z);
        else if (com==2)
            x=read(),y=read(),printf("%d\n",sum_path(x,y)%mod);
        else if (com==3)
            x=read(),y=read(),add_sub(x,y);
        else if (com==4)
            x=read(),printf("%d\n",sum_sub(x)%mod);
    }
}

luogu3250

查询的是不包含x的链的权值最大值。
树链剖分后,每条链变成O(logn)个区间,那么未被这个链包含的也有logn个区间,在这logn个区间上做修改即可。
因为有删除操作,每个线段树节点开一个堆维护。

#include <cstdio>
#include <queue>
#include <algorithm>
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define N 200010
#define M 500010
using namespace std;
struct node
{
    int to,next;
}e[M];
struct qst
{
    int l,r;
}q[N];
int tot,st[M],deep[N],siz[N],son[N],val[N],top[N],v[N],id[N],pos[N],ed[N];
int n,m,x,y,z,root,mod,com,ans,cnt;
struct cmp{bool operator () (int a,int b) {return val[a]<val[b];}};
priority_queue<int,vector<int>,cmp> pri[4*N];
bool comp(qst a,qst b){return a.l<b.l;}
void swap(int &a,int &b){int t=a;a=b;b=t;}
void add(int x,int y)
{
    e[++tot].to=y;
    e[tot].next=st[x];
    st[x]=tot;
}
int fa[N];
void dfs1(int now)
{
    int maxs=-1;
    siz[now]=1;
    for (int i=st[now];i;i=e[i].next)
        if (e[i].to!=fa[now])
        {
            deep[e[i].to]=deep[now]+1;
            fa[e[i].to]=now;
            dfs1(e[i].to);
            siz[now]+=siz[e[i].to];
            if (maxs<siz[e[i].to])
                maxs=siz[e[i].to],son[now]=e[i].to;
        }
}
void dfs2(int now,int tops)
{
    id[++cnt]=now;
    pos[now]=cnt;
    // val[cnt]=v[now];
    top[now]=tops;
    if (son[now])
        dfs2(son[now],tops);
    for (int i=st[now];i;i=e[i].next)
        if (e[i].to!=fa[now] && e[i].to!=son[now])
            dfs2(e[i].to,e[i].to);
}
void insert(int rt,int l,int r,int L,int R,int x)
{
    if (L<=l && r<=R)
    {
        pri[rt].push(x);
        return;
    }
    int mid=(l+r)/2;
    if (L<=mid)
        insert(lson,L,R,x);
    if (R>mid)
        insert(rson,L,R,x);
}
int Ans;
void query(int rt,int l,int r,int x)
{
    while(!pri[rt].empty())
    {
        int now=pri[rt].top();
        if (ed[now])
            pri[rt].pop();
        else
        {
            Ans=max(Ans,val[now]);
            break;
        }
    }
    if (l==r) return;
    int mid=(l+r)/2;
    if (x<=mid)
        query(lson,x);
    else
        query(rson,x);
}
main()
{
    scanf("%d%d",&n,&m);
    root=1;
    for (int i=1;i<n;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    dfs1(root);
    dfs2(root,root);
    for (int i=1;i<=m;i++)
    {
        int com;
        scanf("%d",&com);
        if (com==0)
        {
            scanf("%d%d%d",&x,&y,&val[i]);
            int w=0;
            while(top[x]!=top[y])
            {
                if (deep[top[x]]<deep[top[y]]) swap(x,y);
                q[++w]=(qst){pos[top[x]],pos[x]};
                x=fa[top[x]];
            }
            if (deep[x]<deep[y]) swap(x,y);
            q[++w]=(qst){pos[y],pos[x]};
            q[++w]=(qst){n+1,n+1};
            sort(q+1,q+w+1,comp);
            for (int j=1;j<=w;j++)
                if (q[j-1].r<q[j].l-1)
                    insert(1,1,n,q[j-1].r+1,q[j].l-1,i);
        }
        else if (com==1)
        {
            scanf("%d",&x);
            ed[x]=1;
        }
        else
        {
            scanf("%d",&x);
            Ans=-1;
            query(1,1,n,pos[x]);
            printf("%d\n",Ans);
        }
    }
}

luogu3313

这题是非常有趣的一道题。我们对于每一种宗教开一棵线段树。但是不能开全,就是动态开点的线段树。每次要改的时候就加入新点。(有种主席树即视感

#include <cstdio>
#include <algorithm>
#include <queue>
#define M 200010
#define N 100010
#define min(x,y) ((x<y)?(x):(y))
#define lson ls[rt],l,mid
#define rson rs[rt],mid+1,r
using namespace std;
int inline read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
struct node
{
    int to,next;
}e[M];
int tot,st[M],deep[N],siz[N],fa[N],son[N],val[N],top[N],v[N],id[N],pos[N];
int r[N],rel[N],rt[N];
int n,m,x,y,z,root,mod,com,ans,ql,qr;
void add(int x,int y)
{
    e[++tot].to=y;
    e[tot].next=st[x];
    st[x]=tot;
}
int dfs_id=0;
void dfs1(int now,int pre,int depth)
{
    int maxs=-0x3f3f3f3f;
    deep[now]=depth;
    siz[now]=1;
    fa[now]=pre;
    for (int i=st[now];i;i=e[i].next)
        if (e[i].to!=pre)
        {
            dfs1(e[i].to,now,depth+1);
            siz[now]+=siz[e[i].to];
            if (maxs<siz[e[i].to])
                son[now]=e[i].to,maxs=siz[e[i].to];
        }
}
void dfs2(int now,int tops)
{
    top[now]=tops;
    id[++dfs_id]=now;
    pos[now]=dfs_id;
    val[dfs_id]=v[now];
    rel[dfs_id]=r[now];
    if (son[now]!=0)
        dfs2(son[now],tops);
    for (int i=st[now];i;i=e[i].next)
        if (e[i].to!=fa[now]&&e[i].to!=son[now])
            dfs2(e[i].to,e[i].to);
}
int sum[40*N],maxs[40*N],ls[40*N],rs[40*N],cnt;
void add_tree(int &now,int pre,int l,int r,int pos,int v)
{
    now=++cnt;
    if (l==r)
        sum[now]=v,maxs[now]=v;
    else
    {
        int mid=l+r>>1;
        if (pos<=mid)
            add_tree(ls[now],ls[pre],l,mid,pos,v),rs[now]=rs[pre];
        else
            add_tree(rs[now],rs[pre],mid+1,r,pos,v),ls[now]=ls[pre];
        sum[now]=sum[ls[now]]+sum[rs[now]];
        maxs[now]=max(maxs[ls[now]],maxs[rs[now]]);
    }
}
int query_s(int rt,int l,int r,int L,int R)
{
    if (!rt) return 0;
    if (l>=L && r<=R)
        return sum[rt];
    else
    {
        int mid=(r+l)/2,x=0;
        if (mid>=L)
            x+=query_s(lson,L,R);
        if (mid<R)
            x+=query_s(rson,L,R);
        return x;
    }
}
int query_m(int rt,int l,int r,int L,int R)
{
    if (!rt) return 0;
    if (l>=L && r<=R)
        return maxs[rt];
    else
    {
        int mid=(r+l)/2,x=0;
        if (mid>=L)
            x=max(x,query_m(lson,L,R));
        if (mid<R)
            x=max(x,query_m(rson,L,R));;
        return x;
    }
}
int sum_path(int x,int y,int re)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if (deep[top[x]]<deep[top[y]]) swap(x,y);
        ans=(ans+query_s(rt[re],1,n,pos[top[x]],pos[x]));
        x=fa[top[x]];
    }
    if (deep[x]<deep[y]) swap(x,y);
    ans=(ans+query_s(rt[re],1,n,pos[y],pos[x]));
    return ans;
}
int max_path(int x,int y,int re)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if (deep[top[x]]<deep[top[y]]) swap(x,y);
        ans=max(ans,query_m(rt[re],1,n,pos[top[x]],pos[x]));
        x=fa[top[x]];
    }
    if (deep[x]<deep[y]) swap(x,y);
    ans=max(ans,query_m(rt[re],1,n,pos[y],pos[x]));
    return ans;
}
main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++)
        scanf("%d%d",&v[i],&r[i]);
    for (int i=1;i<=n-1;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    deep[0]=-1;
    root=1;
    dfs1(root,0,1);
    dfs2(root,root);
    for (int i=1;i<=n;i++)
        add_tree(rt[rel[i]],rt[rel[i]],1,n,i,val[i]);
    char c1,c2;int a,b;
    for (int i=1;i<=m;i++)
    {
        getchar();
        ans=0;
        scanf("%c%c%d%d",&c1,&c2,&a,&b);
        if (c1=='Q'&&c2=='S')
            printf("%d\n",sum_path(a,b,rel[pos[a]]));
        else
        if (c1=='Q'&&c2=='M')
            printf("%d\n",max_path(a,b,rel[pos[a]]));
        else
        if (c1=='C'&&c2=='C')
        {
            add_tree(rt[rel[pos[a]]],rt[rel[pos[a]]],1,n,pos[a],0);
            rel[pos[a]]=b;
            add_tree(rt[b],rt[b],1,n,pos[a],val[pos[a]]);
        }
        else
        if (c1=='C'&&c2=='W')
        {
            val[pos[a]]=b;
            add_tree(rt[rel[pos[a]]],rt[rel[pos[a]]],1,n,pos[a],b);
        }
    }
}