首页 > 笔记 > 树链剖分算法总结

树链剖分算法总结

树链剖分是一种把树剖分成重链和轻链,并用dfs序储存在线段树中的算法。它可以方便的处理树上路径和子树的问题。把树上数据存在线段树中的思想值得思考。

何为树链剖分?树链,就是树上路径,剖分,就是把树链剖分成轻链和重链。

记siz[v]表示以v为根的子树的节点数,dep[v]表示v的深度,top[v]表示v所在的重链的顶端节点,fa[v]表示v的父亲,son[v]表示重儿子,dfs_id[v]v的dfs序。

先介绍几个概念:

   重儿子:siz[u]为v的子节点中siz值最大的,那么u就是v的重儿子。
   轻儿子:v的其它子节点。
   重边:点v与其重儿子的连边。
   轻边:点v与其轻儿子的连边。
   重链:由重边连成的路径。
   轻链:连接各个重链的边。

这样,很显然的我们就能发现

1.如果(v,u)为轻边,则siz[u] * 2 < siz[v];

2.从根到某一点的路径上轻链、重链的个数都不大于logn。

这两个很好的性质就可以在logn的复杂度下遍历任意一个路径。我们可以两个点同时向上跳,假如是重链就跳到top,不是就跳到父亲。直到跳到两点的top是同一个。跳的同时就可以用线段树维护一下极值、求和啥的。

图片来自网络

images

如何实现呢?

我们可以通过两个dfs实现

第一个

void dfs_1(int x,int f)
{
    siz[x]=1;
    fa[x]=f;
    for (int i=st[x];i;i=e[i].next)
        if (e[i].to!=f)
        {
            dep[e[i].to]=dep[x]+1;
            dfs_1(e[i].to,x);
            siz[x]+=siz[e[i].to];
            if (siz[e[i].to]>siz[son[x]])
                son[x]=e[i].to;
        }
}

在这个dfs中,可以把siz,fa,dep,son求出来

第二个

int tot2=0;
void dfs_2(int now,int tp)
{
    pre[++tot2]=now;
    dfs_id[now]=tot2;
    top[now]=tp;
    if (son[now])
        dfs_2(son[now],tp);
    for (int i=st[now];i;i=e[i].next)
        if (e[i].to!=son[now] && e[i].to!=fa[now])
            dfs_2(e[i].to,e[i].to);
}
//基本思想就是现在是重链上的话,就用原来的top,不是重链上的就传自己作为top

这个dfs可以把top,dfs_id求出来

pre就是dfs_id的反函数。。它是当构造线段树时候用的。

构造线段树

void build(int rt,int l,int r)
{
    if (l==r)
        tree[rt].sum=tree[rt].maxs=tree[rt].mins=val[pre[r]];
    else
    {
        int mid=(l+r)/2;
        build(rt<<1,l,mid);
        build(rt<<1|1,mid+1,r);
        pushup(rt);
    }
}

val数字存的是每个点的权值。因为线段树用的是dfs序,我们直接调用反函数就可以知道权值

例题是这道

bzoj2157

Description

Ray 乐忠于旅游,这次他来到了T 城。T 城是一个水上城市,一共有 N 个景点,有些景点之间会用一座桥连接。为了方便游客到达每个景点但又为了节约成本,T 城的任意两个景点之间有且只有一条路径。换句话说, T 城中只有N − 1 座桥。Ray 发现,有些桥上可以看到美丽的景色,让人心情愉悦,但有些桥狭窄泥泞,令人烦躁。于是,他给每座桥定义一个愉悦度w,也就是说,Ray 经过这座桥会增加w 的愉悦度,这或许是正的也可能是负的。有时,Ray 看待同一座桥的心情也会发生改变。现在,Ray 想让你帮他计算从u 景点到v 景点能获得的总愉悦度。有时,他还想知道某段路上最美丽的桥所提供的最大愉悦度,或是某段路上最糟糕的一座桥提供的最低愉悦度。

Input

输入的第一行包含一个整数N,表示T 城中的景点个数。景点编号为 0…N − 1。接下来N − 1 行,每行三个整数u、v 和w,表示有一条u 到v,使 Ray 愉悦度增加w 的桥。桥的编号为1…N − 1。|w| <= 1000。输入的第N + 1 行包含一个整数M,表示Ray 的操作数目。接下来有M 行,每行描述了一个操作,操作有如下五种形式: C i w,表示Ray 对于经过第i 座桥的愉悦度变成了w。 N u v,表示Ray 对于经过景点u 到v 的路径上的每一座桥的愉悦度都变成原来的相反数。 SUM u v,表示询问从景点u 到v 所获得的总愉悦度。 MAX u v,表示询问从景点u 到v 的路径上的所有桥中某一座桥所提供的最大愉悦度。 MIN u v,表示询问从景点u 到v 的路径上的所有桥中某一座桥所提供的最小愉悦度。测试数据保证,任意时刻,Ray 对于经过每一座桥的愉悦度的绝对值小于等于1000。

Output

对于每一个询问(操作S、MAX 和MIN),输出答案。

Sample Input

3

0 1 1

1 2 2

8

SUM 0 2

MAX 0 2

N 0 1

SUM 0 2

MIN 0 2

C 1 3

SUM 0 2

MAX 0 2
Sample Output

3

2

1

-1

5

3
HINT

一共有10 个数据,对于第i (1 <= i <= 10) 个数据, N = M = i * 2000。

 

这题有四个操作,区间最大,区间最小,区间和,区间反转

我们该如何操作呢?

void doit()
{
    int l,r,f1,f2,sum=0,maxs=-INF,mins=INF;
    scanf("%d%d",&l,&r),l++,r++;
    f1=top[l],f2=top[r];
    while(f1!=f2)
    {
        if (dep[f1]<dep[f2]) swap(f1,f2),swap(l,r);
        if (ch=='N') res(1,1,tot2,dfs_id[f1],dfs_id[l]);
        else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[f1],dfs_id[l]);
        else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[f1],dfs_id[l]),mins);
        else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[f1],dfs_id[l]),maxs);
        l=fa[f1],f1=top[l];
    }
    if (dep[l]>dep[r]) swap(l,r);
    if (l==r)
    {
        if (ch=='S') printf("%d\n",sum);
        if (ch=='I') printf("%d\n",mins);
        if (ch=='A') printf("%d\n",maxs);
    }
    else
    {
        l=son[l];
        if (ch=='N') res(1,1,tot2,dfs_id[l],dfs_id[r]);
        else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[l],dfs_id[r]),printf("%d\n",sum);
        else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[l],dfs_id[r]),mins),printf("%d\n",mins);
        else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[l],dfs_id[r]),maxs),printf("%d\n",maxs);
    }
}

首先,找到这两个节点,把高度低的那个往上跳,跳的时候操作一下这个链,最后直到top一样或到同一个点

注意一下,假如不在一个点的话最后还要更新一下。

总代码

#include<bits/stdc++.h>
#define INF 0x7fffffff
#define M 40010
#define N 20010
using namespace std;
typedef pair<int,int> Pair;
struct node
{
    int from,to,value,next;
}e[M];
struct seg
{
    int sum,mins,maxs,mark;
}tree[4*N];
int tot,st[M],n,m,siz[N],son[N],fa[N],pre[3*N],top[N],dfs_id[N],dep[N],val[N];
char ch;
void add(int x,int y,int z)
{
    e[++tot].to=y;
    e[tot].from=x;
    e[tot].value=z;
    e[tot].next=st[x];
    st[x]=tot;
}
void dfs_1(int x,int f)
{
    siz[x]=1;
    fa[x]=f;
    for (int i=st[x];i;i=e[i].next)
        if (e[i].to!=f)
        {
            dep[e[i].to]=dep[x]+1;
            dfs_1(e[i].to,x);
            siz[x]+=siz[e[i].to];
            if (siz[e[i].to]>siz[son[x]])
                son[x]=e[i].to;
        }
}
int tot2=0;
void dfs_2(int now,int tp)
{
    pre[++tot2]=now;
    dfs_id[now]=tot2;
    top[now]=tp;
    if (son[now])
        dfs_2(son[now],tp);
    for (int i=st[now];i;i=e[i].next)
        if (e[i].to!=son[now] && e[i].to!=fa[now])
            dfs_2(e[i].to,e[i].to);
}
void re(int &a,int &b,int &c){a=-a,b=-b,c=-c;}
void pushdown(int now)
{
    if (tree[now].mark==0) return;
    tree[now<<1].mark^=1;
    tree[now<<1|1].mark^=1;
    swap(tree[now<<1].mins,tree[now<<1].maxs);
    re(tree[now<<1].sum,tree[now<<1].mins,tree[now<<1].maxs);
    swap(tree[now<<1|1].mins,tree[now<<1|1].maxs);
    re(tree[now<<1|1].sum,tree[now<<1|1].mins,tree[now<<1|1].maxs);
    tree[now].mark=0;
}
void pushup(int now)
{
    tree[now].maxs=max(tree[now<<1].maxs,tree[now<<1|1].maxs);
    tree[now].mins=min(tree[now<<1].mins,tree[now<<1|1].mins);
    tree[now].sum=tree[now<<1].sum+tree[now<<1|1].sum;
}
void build(int rt,int l,int r)
{
    if (l==r)
        tree[rt].sum=tree[rt].maxs=tree[rt].mins=val[pre[r]];
    else
    {
        int mid=(l+r)/2;
        build(rt<<1,l,mid);
        build(rt<<1|1,mid+1,r);
        pushup(rt);
    }
}
void update(int rt,int l,int r,int pos,int x)
{
    if (l==r)
    {
        tree[rt].mark=0;
        tree[rt].sum=tree[rt].maxs=tree[rt].mins=x;
        return;
    }
    pushdown(rt);
    int mid=(r+l)/2;
    if (mid>=pos)
        update(rt<<1,l,mid,pos,x);
    else
        update(rt<<1|1,mid+1,r,pos,x);
    pushup(rt);
}
void res(int rt,int l,int r,int L,int R)
{
    if (L<=l && r<=R)
    {
        tree[rt].mark^=1;
        swap(tree[rt].maxs,tree[rt].mins);
        re(tree[rt].maxs,tree[rt].mins,tree[rt].sum);
        return;
    }
    pushdown(rt);
    int mid=(l+r)/2;
    if (mid>=L)
        res(rt<<1,l,mid,L,R);
    if (mid<R)
        res(rt<<1|1,mid+1,r,L,R);
    pushup(rt);
}
int get_max(int rt,int l,int r,int L,int R)
{
    if (L<=l && r<=R)
        return tree[rt].maxs;
    pushdown(rt);
    int ans=-INF,mid=(r+l)/2;
    if (mid>=L)
        ans=max(ans,get_max(rt<<1,l,mid,L,R));
    if (mid<R)
        ans=max(ans,get_max(rt<<1|1,mid+1,r,L,R));
    return ans;
}
int get_min(int rt,int l,int r,int L,int R)
{
    if (L<=l && r<=R)
        return tree[rt].mins;
    pushdown(rt);
    int ans=INF,mid=(r+l)/2;
    if (mid>=L)
        ans=min(ans,get_min(rt<<1,l,mid,L,R));
    if (mid<R)
        ans=min(ans,get_min(rt<<1|1,mid+1,r,L,R));
    return ans;
}
int get_sum(int rt,int l,int r,int L,int R)
{
    if (L<=l && r<=R)
        return tree[rt].sum;
    pushdown(rt);
    int ans=0,mid=(r+l)/2;
    if (mid>=L)
        ans+=get_sum(rt<<1,l,mid,L,R);
    if (mid<R)
        ans+=get_sum(rt<<1|1,mid+1,r,L,R);
    return ans;
}
void doit()
{
    int l,r,f1,f2,sum=0,maxs=-INF,mins=INF;
    scanf("%d%d",&l,&r),l++,r++;
    f1=top[l],f2=top[r];
    while(f1!=f2)
    {
        if (dep[f1]<dep[f2]) swap(f1,f2),swap(l,r);
        if (ch=='N') res(1,1,tot2,dfs_id[f1],dfs_id[l]);
        else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[f1],dfs_id[l]);
        else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[f1],dfs_id[l]),mins);
        else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[f1],dfs_id[l]),maxs);
        l=fa[f1],f1=top[l];
    }
    if (dep[l]>dep[r]) swap(l,r);
    if (l==r)
    {
        if (ch=='S') printf("%d\n",sum);
        if (ch=='I') printf("%d\n",mins);
        if (ch=='A') printf("%d\n",maxs);
    }
    else
    {
        l=son[l];
        if (ch=='N') res(1,1,tot2,dfs_id[l],dfs_id[r]);
        else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[l],dfs_id[r]),printf("%d\n",sum);
        else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[l],dfs_id[r]),mins),printf("%d\n",mins);
        else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[l],dfs_id[r]),maxs),printf("%d\n",maxs);
    }
}
main()
{
    scanf("%d",&n);
    int x,y,z;
    for (int i=1;i<n;i++)
        scanf("%d%d%d",&x,&y,&z),x++,y++,
        add(x,y,z),add(y,x,z);
    dfs_1(1,0);
    dfs_2(1,1);
    for (int i=1;i<=tot;i++)
    {
        if (dep[e[i].from]>dep[e[i].to])
            swap(e[i].from,e[i].to);
        val[e[i].to]=e[i].value;
    }
    build(1,1,tot2);
    scanf("%d",&m);
    while(m--)
    {
        ch=getchar();
        while(ch!='N'&&ch!='S'&&ch!='M'&&ch!='C')
            ch=getchar();
        if (ch=='C')
        {
            scanf("%d%d",&x,&y);
            update(1,1,tot2,dfs_id[e[x<<1].to],y);
        }
        else if (ch=='S') getchar(),getchar(),doit();
        else if (ch=='N') doit();
        else if (ch=='M') ch=getchar(),getchar(),doit();
    }
}

还有一种情况就是操作子树

其实这个更简单

我们可以观察一下一棵树的dfs序

images

我们可以观察到子树是在dfs序上连续的一段:dfs_id[i]+1到dfs_id[i]+siz[i]-1

然后直接线段树就行了

例题

树链剖分模版

#include<cstdio>
#include<iostream>
#define ls k*2
#define rs k*2+1
using namespace std;
int n,m,l,hs,nl,lfs,root,mod;
int a,b,c,d,ans,com;
int s[100010],h[100010];
struct node
{
    int k,f,d,sz,ws,p,t;
}p[100010];
struct nate
{
    int s,n;
}e[200010];
struct tree
{
    int l,r,s,f;
}t[400010];
inline void in(int &ans){ans=0;bool p=false;char ch=getchar();while((ch>'9' || ch<'0')&&ch!='-') ch=getchar();if(ch=='-') p=true,ch=getchar();while(ch<='9'&&ch>='0') ans=ans*10+ch-'0',ch=getchar();if(p) ans=-ans;}
void add(int x,int y)
{
    e[++hs]=(nate){y,h[x]};h[x]=hs;
}
void pushdown(int k)
{
    t[ls].f=(t[ls].f+t[k].f)%mod;
    t[ls].s+=(t[ls].r-t[ls].l+1)*t[k].f%mod;
    t[rs].f=(t[rs].f+t[k].f)%mod;
    t[rs].s+=(t[rs].r-t[rs].l+1)*t[k].f%mod;
    t[k].f=0;
}
void build(int k,int l,int r)
{
    t[k].l=l;t[k].r=r;
    if(l==r){t[k].s=s[++nl];return;}
    int mid=(l+r)/2;
    build(ls,l,mid);
    build(rs,mid+1,r);
    t[k].s=t[ls].s+t[rs].s;
}
void change(int k,int l,int r,int v)
{
    if(t[k].l==l&&t[k].r==r)
    {
        t[k].f=(t[k].f+v)%mod;
        t[k].s+=(t[k].r-t[k].l+1)*v%mod;
        return;
    }
    if(t[k].f) pushdown(k);
    int mid=(t[k].l+t[k].r)/2;
    if(l<=mid) change(ls,l,min(r,mid),v);
    if(r>mid) change(rs,max(l,mid+1),r,v);
    t[k].s=(t[ls].s+t[rs].s)%mod;
}
int query(int k,int l,int r)
{
    if(t[k].l==l&&t[k].r==r) return t[k].s;
    if(t[k].f) pushdown(k);
    int mid=(t[k].l+t[k].r)/2,ans=0;
    if(l<=mid) ans+=query(ls,l,min(r,mid))%mod;
    if(r>mid) ans+=query(rs,max(l,mid+1),r)%mod;
    return ans%mod;
}
void dfs1(int k,int f,int d)
{
    p[k].f=f;p[k].d=d;p[k].sz=1;
    for(int i=h[k];i;i=e[i].n)
    if(e[i].s!=f){
        dfs1(e[i].s,k,d+1);
        p[k].sz+=p[e[i].s].sz;
        if(p[e[i].s].sz>p[p[k].ws].sz) p[k].ws=e[i].s;
    }
}
void dfs2(int k)
{
    s[++l]=p[k].k;p[k].p=l;
    if(p[k].ws)
    {
        p[p[k].ws].t=p[k].t;
        dfs2(p[k].ws);
    }
    for(int i=h[k];i;i=e[i].n)
        if(e[i].s!=p[k].ws&&e[i].s!=p[k].f)
        {
            p[e[i].s].t=e[i].s;
            dfs2(e[i].s);
        }
}
int main()
{
    in(n),in(m),in(root),in(mod);
    for(int i=1;i<=n;i++)
        in(p[i].k);
    for(int i=1;i<n;i++)
        in(a),in(b),add(a,b),add(b,a);
    dfs1(root,root,1);
    dfs2(root);
    build(1,1,l);
    while(m--)
    {
        in(com);
        if(com==1)
        {
            in(b),in(c),in(d);d%=mod;
            for(;p[b].t!=p[c].t;b=p[p[b].t].f)
            {
                if(p[p[b].t].d<p[p[c].t].d) swap(b,c);
                change(1,p[p[b].t].p,p[b].p,d);
            }
            if(p[b].d>p[c].d) swap(b,c);
            change(1,p[b].p,p[c].p,d);
        }
        if(com==2)
        {
            in(b),in(c);ans=0;
            for(;p[b].t!=p[c].t;b=p[p[b].t].f)
            {
                if(p[p[b].t].d<p[p[c].t].d) swap(b,c);
                ans+=query(1,p[p[b].t].p,p[b].p),ans%=mod;
            }
            if(p[b].d>p[c].d) swap(b,c);
            ans+=query(1,p[b].p,p[c].p),ans%=mod;
            printf("%d\n",ans);
        }
        if(com==3)
        {
            in(b),in(c),c%=mod;
            change(1,p[b].p,p[b].p+p[b].sz-1,c);
        }
        if(com==4)
        {
            in(b),ans=0;
            ans=query(1,p[b].p,p[b].p+p[b].sz-1)%mod;
            printf("%d\n",ans);
        }
    }
}

The end.