树链剖分算法总结
树链剖分是一种把树剖分成重链和轻链,并用dfs序储存在线段树中的算法。它可以方便的处理树上路径和子树的问题。把树上数据存在线段树中的思想值得思考。
何为树链剖分?树链,就是树上路径,剖分,就是把树链剖分成轻链和重链。
记siz[v]表示以v为根的子树的节点数,dep[v]表示v的深度,top[v]表示v所在的重链的顶端节点,fa[v]表示v的父亲,son[v]表示重儿子,dfs_id[v]v的dfs序。
先介绍几个概念:
这样,很显然的我们就能发现
1.如果(v,u)为轻边,则siz[u] * 2 < siz[v];
2.从根到某一点的路径上轻链、重链的个数都不大于logn。
这两个很好的性质就可以在logn的复杂度下遍历任意一个路径。我们可以两个点同时向上跳,假如是重链就跳到top,不是就跳到父亲。直到跳到两点的top是同一个。跳的同时就可以用线段树维护一下极值、求和啥的。
图片来自网络
如何实现呢?
我们可以通过两个dfs实现
第一个
void dfs_1(int x,int f) { siz[x]=1; fa[x]=f; for (int i=st[x];i;i=e[i].next) if (e[i].to!=f) { dep[e[i].to]=dep[x]+1; dfs_1(e[i].to,x); siz[x]+=siz[e[i].to]; if (siz[e[i].to]>siz[son[x]]) son[x]=e[i].to; } }
在这个dfs中,可以把siz,fa,dep,son求出来
第二个
int tot2=0; void dfs_2(int now,int tp) { pre[++tot2]=now; dfs_id[now]=tot2; top[now]=tp; if (son[now]) dfs_2(son[now],tp); for (int i=st[now];i;i=e[i].next) if (e[i].to!=son[now] && e[i].to!=fa[now]) dfs_2(e[i].to,e[i].to); } //基本思想就是现在是重链上的话,就用原来的top,不是重链上的就传自己作为top
这个dfs可以把top,dfs_id求出来
pre就是dfs_id的反函数。。它是当构造线段树时候用的。
构造线段树
void build(int rt,int l,int r) { if (l==r) tree[rt].sum=tree[rt].maxs=tree[rt].mins=val[pre[r]]; else { int mid=(l+r)/2; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } }
val数字存的是每个点的权值。因为线段树用的是dfs序,我们直接调用反函数就可以知道权值
例题是这道
Description
Ray 乐忠于旅游,这次他来到了T 城。T 城是一个水上城市,一共有 N 个景点,有些景点之间会用一座桥连接。为了方便游客到达每个景点但又为了节约成本,T 城的任意两个景点之间有且只有一条路径。换句话说, T 城中只有N − 1 座桥。Ray 发现,有些桥上可以看到美丽的景色,让人心情愉悦,但有些桥狭窄泥泞,令人烦躁。于是,他给每座桥定义一个愉悦度w,也就是说,Ray 经过这座桥会增加w 的愉悦度,这或许是正的也可能是负的。有时,Ray 看待同一座桥的心情也会发生改变。现在,Ray 想让你帮他计算从u 景点到v 景点能获得的总愉悦度。有时,他还想知道某段路上最美丽的桥所提供的最大愉悦度,或是某段路上最糟糕的一座桥提供的最低愉悦度。
Input
输入的第一行包含一个整数N,表示T 城中的景点个数。景点编号为 0…N − 1。接下来N − 1 行,每行三个整数u、v 和w,表示有一条u 到v,使 Ray 愉悦度增加w 的桥。桥的编号为1…N − 1。|w| <= 1000。输入的第N + 1 行包含一个整数M,表示Ray 的操作数目。接下来有M 行,每行描述了一个操作,操作有如下五种形式: C i w,表示Ray 对于经过第i 座桥的愉悦度变成了w。 N u v,表示Ray 对于经过景点u 到v 的路径上的每一座桥的愉悦度都变成原来的相反数。 SUM u v,表示询问从景点u 到v 所获得的总愉悦度。 MAX u v,表示询问从景点u 到v 的路径上的所有桥中某一座桥所提供的最大愉悦度。 MIN u v,表示询问从景点u 到v 的路径上的所有桥中某一座桥所提供的最小愉悦度。测试数据保证,任意时刻,Ray 对于经过每一座桥的愉悦度的绝对值小于等于1000。
Output
对于每一个询问(操作S、MAX 和MIN),输出答案。
Sample Input
3
0 1 1
1 2 2
8
SUM 0 2
MAX 0 2
N 0 1
SUM 0 2
MIN 0 2
C 1 3
SUM 0 2
MAX 0 2
Sample Output
3
2
1
-1
5
3
HINT
一共有10 个数据,对于第i (1 <= i <= 10) 个数据, N = M = i * 2000。
这题有四个操作,区间最大,区间最小,区间和,区间反转
我们该如何操作呢?
void doit() { int l,r,f1,f2,sum=0,maxs=-INF,mins=INF; scanf("%d%d",&l,&r),l++,r++; f1=top[l],f2=top[r]; while(f1!=f2) { if (dep[f1]<dep[f2]) swap(f1,f2),swap(l,r); if (ch=='N') res(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[f1],dfs_id[l]),mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[f1],dfs_id[l]),maxs); l=fa[f1],f1=top[l]; } if (dep[l]>dep[r]) swap(l,r); if (l==r) { if (ch=='S') printf("%d\n",sum); if (ch=='I') printf("%d\n",mins); if (ch=='A') printf("%d\n",maxs); } else { l=son[l]; if (ch=='N') res(1,1,tot2,dfs_id[l],dfs_id[r]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[l],dfs_id[r]),printf("%d\n",sum); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[l],dfs_id[r]),mins),printf("%d\n",mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[l],dfs_id[r]),maxs),printf("%d\n",maxs); } }
首先,找到这两个节点,把高度低的那个往上跳,跳的时候操作一下这个链,最后直到top一样或到同一个点
注意一下,假如不在一个点的话最后还要更新一下。
总代码
#include<bits/stdc++.h> #define INF 0x7fffffff #define M 40010 #define N 20010 using namespace std; typedef pair<int,int> Pair; struct node { int from,to,value,next; }e[M]; struct seg { int sum,mins,maxs,mark; }tree[4*N]; int tot,st[M],n,m,siz[N],son[N],fa[N],pre[3*N],top[N],dfs_id[N],dep[N],val[N]; char ch; void add(int x,int y,int z) { e[++tot].to=y; e[tot].from=x; e[tot].value=z; e[tot].next=st[x]; st[x]=tot; } void dfs_1(int x,int f) { siz[x]=1; fa[x]=f; for (int i=st[x];i;i=e[i].next) if (e[i].to!=f) { dep[e[i].to]=dep[x]+1; dfs_1(e[i].to,x); siz[x]+=siz[e[i].to]; if (siz[e[i].to]>siz[son[x]]) son[x]=e[i].to; } } int tot2=0; void dfs_2(int now,int tp) { pre[++tot2]=now; dfs_id[now]=tot2; top[now]=tp; if (son[now]) dfs_2(son[now],tp); for (int i=st[now];i;i=e[i].next) if (e[i].to!=son[now] && e[i].to!=fa[now]) dfs_2(e[i].to,e[i].to); } void re(int &a,int &b,int &c){a=-a,b=-b,c=-c;} void pushdown(int now) { if (tree[now].mark==0) return; tree[now<<1].mark^=1; tree[now<<1|1].mark^=1; swap(tree[now<<1].mins,tree[now<<1].maxs); re(tree[now<<1].sum,tree[now<<1].mins,tree[now<<1].maxs); swap(tree[now<<1|1].mins,tree[now<<1|1].maxs); re(tree[now<<1|1].sum,tree[now<<1|1].mins,tree[now<<1|1].maxs); tree[now].mark=0; } void pushup(int now) { tree[now].maxs=max(tree[now<<1].maxs,tree[now<<1|1].maxs); tree[now].mins=min(tree[now<<1].mins,tree[now<<1|1].mins); tree[now].sum=tree[now<<1].sum+tree[now<<1|1].sum; } void build(int rt,int l,int r) { if (l==r) tree[rt].sum=tree[rt].maxs=tree[rt].mins=val[pre[r]]; else { int mid=(l+r)/2; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } } void update(int rt,int l,int r,int pos,int x) { if (l==r) { tree[rt].mark=0; tree[rt].sum=tree[rt].maxs=tree[rt].mins=x; return; } pushdown(rt); int mid=(r+l)/2; if (mid>=pos) update(rt<<1,l,mid,pos,x); else update(rt<<1|1,mid+1,r,pos,x); pushup(rt); } void res(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) { tree[rt].mark^=1; swap(tree[rt].maxs,tree[rt].mins); re(tree[rt].maxs,tree[rt].mins,tree[rt].sum); return; } pushdown(rt); int mid=(l+r)/2; if (mid>=L) res(rt<<1,l,mid,L,R); if (mid<R) res(rt<<1|1,mid+1,r,L,R); pushup(rt); } int get_max(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) return tree[rt].maxs; pushdown(rt); int ans=-INF,mid=(r+l)/2; if (mid>=L) ans=max(ans,get_max(rt<<1,l,mid,L,R)); if (mid<R) ans=max(ans,get_max(rt<<1|1,mid+1,r,L,R)); return ans; } int get_min(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) return tree[rt].mins; pushdown(rt); int ans=INF,mid=(r+l)/2; if (mid>=L) ans=min(ans,get_min(rt<<1,l,mid,L,R)); if (mid<R) ans=min(ans,get_min(rt<<1|1,mid+1,r,L,R)); return ans; } int get_sum(int rt,int l,int r,int L,int R) { if (L<=l && r<=R) return tree[rt].sum; pushdown(rt); int ans=0,mid=(r+l)/2; if (mid>=L) ans+=get_sum(rt<<1,l,mid,L,R); if (mid<R) ans+=get_sum(rt<<1|1,mid+1,r,L,R); return ans; } void doit() { int l,r,f1,f2,sum=0,maxs=-INF,mins=INF; scanf("%d%d",&l,&r),l++,r++; f1=top[l],f2=top[r]; while(f1!=f2) { if (dep[f1]<dep[f2]) swap(f1,f2),swap(l,r); if (ch=='N') res(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[f1],dfs_id[l]); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[f1],dfs_id[l]),mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[f1],dfs_id[l]),maxs); l=fa[f1],f1=top[l]; } if (dep[l]>dep[r]) swap(l,r); if (l==r) { if (ch=='S') printf("%d\n",sum); if (ch=='I') printf("%d\n",mins); if (ch=='A') printf("%d\n",maxs); } else { l=son[l]; if (ch=='N') res(1,1,tot2,dfs_id[l],dfs_id[r]); else if (ch=='S') sum+=get_sum(1,1,tot2,dfs_id[l],dfs_id[r]),printf("%d\n",sum); else if (ch=='I') mins=min(get_min(1,1,tot2,dfs_id[l],dfs_id[r]),mins),printf("%d\n",mins); else if (ch=='A') maxs=max(get_max(1,1,tot2,dfs_id[l],dfs_id[r]),maxs),printf("%d\n",maxs); } } main() { scanf("%d",&n); int x,y,z; for (int i=1;i<n;i++) scanf("%d%d%d",&x,&y,&z),x++,y++, add(x,y,z),add(y,x,z); dfs_1(1,0); dfs_2(1,1); for (int i=1;i<=tot;i++) { if (dep[e[i].from]>dep[e[i].to]) swap(e[i].from,e[i].to); val[e[i].to]=e[i].value; } build(1,1,tot2); scanf("%d",&m); while(m--) { ch=getchar(); while(ch!='N'&&ch!='S'&&ch!='M'&&ch!='C') ch=getchar(); if (ch=='C') { scanf("%d%d",&x,&y); update(1,1,tot2,dfs_id[e[x<<1].to],y); } else if (ch=='S') getchar(),getchar(),doit(); else if (ch=='N') doit(); else if (ch=='M') ch=getchar(),getchar(),doit(); } }
还有一种情况就是操作子树
其实这个更简单
我们可以观察一下一棵树的dfs序
我们可以观察到子树是在dfs序上连续的一段:dfs_id[i]+1到dfs_id[i]+siz[i]-1
然后直接线段树就行了
例题
#include<cstdio> #include<iostream> #define ls k*2 #define rs k*2+1 using namespace std; int n,m,l,hs,nl,lfs,root,mod; int a,b,c,d,ans,com; int s[100010],h[100010]; struct node { int k,f,d,sz,ws,p,t; }p[100010]; struct nate { int s,n; }e[200010]; struct tree { int l,r,s,f; }t[400010]; inline void in(int &ans){ans=0;bool p=false;char ch=getchar();while((ch>'9' || ch<'0')&&ch!='-') ch=getchar();if(ch=='-') p=true,ch=getchar();while(ch<='9'&&ch>='0') ans=ans*10+ch-'0',ch=getchar();if(p) ans=-ans;} void add(int x,int y) { e[++hs]=(nate){y,h[x]};h[x]=hs; } void pushdown(int k) { t[ls].f=(t[ls].f+t[k].f)%mod; t[ls].s+=(t[ls].r-t[ls].l+1)*t[k].f%mod; t[rs].f=(t[rs].f+t[k].f)%mod; t[rs].s+=(t[rs].r-t[rs].l+1)*t[k].f%mod; t[k].f=0; } void build(int k,int l,int r) { t[k].l=l;t[k].r=r; if(l==r){t[k].s=s[++nl];return;} int mid=(l+r)/2; build(ls,l,mid); build(rs,mid+1,r); t[k].s=t[ls].s+t[rs].s; } void change(int k,int l,int r,int v) { if(t[k].l==l&&t[k].r==r) { t[k].f=(t[k].f+v)%mod; t[k].s+=(t[k].r-t[k].l+1)*v%mod; return; } if(t[k].f) pushdown(k); int mid=(t[k].l+t[k].r)/2; if(l<=mid) change(ls,l,min(r,mid),v); if(r>mid) change(rs,max(l,mid+1),r,v); t[k].s=(t[ls].s+t[rs].s)%mod; } int query(int k,int l,int r) { if(t[k].l==l&&t[k].r==r) return t[k].s; if(t[k].f) pushdown(k); int mid=(t[k].l+t[k].r)/2,ans=0; if(l<=mid) ans+=query(ls,l,min(r,mid))%mod; if(r>mid) ans+=query(rs,max(l,mid+1),r)%mod; return ans%mod; } void dfs1(int k,int f,int d) { p[k].f=f;p[k].d=d;p[k].sz=1; for(int i=h[k];i;i=e[i].n) if(e[i].s!=f){ dfs1(e[i].s,k,d+1); p[k].sz+=p[e[i].s].sz; if(p[e[i].s].sz>p[p[k].ws].sz) p[k].ws=e[i].s; } } void dfs2(int k) { s[++l]=p[k].k;p[k].p=l; if(p[k].ws) { p[p[k].ws].t=p[k].t; dfs2(p[k].ws); } for(int i=h[k];i;i=e[i].n) if(e[i].s!=p[k].ws&&e[i].s!=p[k].f) { p[e[i].s].t=e[i].s; dfs2(e[i].s); } } int main() { in(n),in(m),in(root),in(mod); for(int i=1;i<=n;i++) in(p[i].k); for(int i=1;i<n;i++) in(a),in(b),add(a,b),add(b,a); dfs1(root,root,1); dfs2(root); build(1,1,l); while(m--) { in(com); if(com==1) { in(b),in(c),in(d);d%=mod; for(;p[b].t!=p[c].t;b=p[p[b].t].f) { if(p[p[b].t].d<p[p[c].t].d) swap(b,c); change(1,p[p[b].t].p,p[b].p,d); } if(p[b].d>p[c].d) swap(b,c); change(1,p[b].p,p[c].p,d); } if(com==2) { in(b),in(c);ans=0; for(;p[b].t!=p[c].t;b=p[p[b].t].f) { if(p[p[b].t].d<p[p[c].t].d) swap(b,c); ans+=query(1,p[p[b].t].p,p[b].p),ans%=mod; } if(p[b].d>p[c].d) swap(b,c); ans+=query(1,p[b].p,p[c].p),ans%=mod; printf("%d\n",ans); } if(com==3) { in(b),in(c),c%=mod; change(1,p[b].p,p[b].p+p[b].sz-1,c); } if(com==4) { in(b),ans=0; ans=query(1,p[b].p,p[b].p+p[b].sz-1)%mod; printf("%d\n",ans); } } }
The end.