树剖进阶总结
再来一篇
树链剖分
我们先搞棵树
如果只要子树操作?可以发现子树是DFS序上一段连续区间。。
而且区间的大小就是子树的大小(很明显啊
那么怎么处理路径问题?
我们可以发现dfs序是一些小的路径拼起来的
实际上,每个点第一个被访问的子树就会和它形成一条路径,我们就把这棵树剖成了很多条路径,询问的一条路径就可以分成这些小路径。
我们就是要找一种比较好的方法,让每条询问的路径都分成尽量少的短。
怎么?随机?感觉非常不靠谱。
那就按子树大小加权随机?用不到。直接剖较大的就可以了
一个点的子树最大的儿子叫重儿子,这条边叫做重边。一条由重边组成的链叫做重链。
显然一条路径只会有$O(\log n)$条重边。
然后我们dfs一次求出重儿子,再dfs一遍求出dfs序。这样每个重链在dfs序上就是连续的区间。用数据结构维护下dfs序就可以了。
一般的题用线段树就可以维护了。
例题
就是模版啊。。
#include <cstdio>
#include <algorithm>
#include <queue>
#define M 200010
#define N 200010
#define min(x,y) ((x<y)?(x):(y))
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
using namespace std;
typedef long long LL;
LL inline read()
{
LL x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct node
{
LL to,next;
}e[M];
LL tot,st[M],deep[N],siz[N],fa[N],son[N],val[N],top[N],v[N],id[N],pos[N];
LL n,m,x,y,z,root,mod,com,ans,ql,qr;
void add(LL x,LL y)
{
e[++tot].to=y;
e[tot].next=st[x];
st[x]=tot;
}
LL dfs_id=0;
void dfs1(LL now,LL pre,LL depth)
{
LL maxs=-0x3f3f3f3f;
deep[now]=depth;
siz[now]=1;
fa[now]=pre;
for (LL i=st[now];i;i=e[i].next)
if (e[i].to!=pre)
{
dfs1(e[i].to,now,depth+1);
siz[now]+=siz[e[i].to];
if (maxs<siz[e[i].to])
son[now]=e[i].to,maxs=siz[e[i].to];
}
}
void dfs2(LL now,LL tops)
{
top[now]=tops;
id[++dfs_id]=now;
pos[now]=dfs_id;
val[dfs_id]=v[now];
if (son[now]!=0)
dfs2(son[now],tops);
for (LL i=st[now];i;i=e[i].next)
if (e[i].to!=fa[now]&&e[i].to!=son[now])
dfs2(e[i].to,e[i].to);
}
LL seg[4*N];
LL mark[4*N];
void pushup(LL rt)
{
seg[rt]=(seg[rt<<1]+seg[rt<<1|1]);
}
void pushdown(LL rt,LL x)
{
if (mark[rt])
{
mark[rt<<1]=(mark[rt<<1]+mark[rt]);
mark[rt<<1|1]=(mark[rt<<1|1]+mark[rt]);
seg[rt<<1]=(seg[rt<<1]+mark[rt]*(x-x/2));
seg[rt<<1|1]=(seg[rt<<1|1]+mark[rt]*(x/2));
mark[rt]=0;
}
}
void build(LL rt,LL l,LL r)
{
if (l==r)
seg[rt]=val[l];
else
{
LL mid=(l+r)/2;
build(lson);
build(rson);
pushup(rt);
}
}
LL query(LL rt,LL l,LL r,LL L,LL R)
{
if (l>=L && r<=R)
return seg[rt];
else
{
pushdown(rt,r-l+1);
LL mid=(r+l)/2,x=0;
if (mid>=L)
x+=query(lson,L,R);
if (mid<R)
x+=query(rson,L,R);
return x;
}
}
void update(LL rt,LL l,LL r,LL L,LL R,LL x)
{
if (l>=L && r<=R)
seg[rt]=(seg[rt]+(r-l+1)*x),mark[rt]=(mark[rt]+x);
else
{
pushdown(rt,r-l+1);
LL mid=(r+l)/2;
if (mid>=L)
update(lson,L,R,x);
if (mid<R)
update(rson,L,R,x);
pushup(rt);
}
}
void add_path(LL x,LL y,LL z)
{
while(top[x]!=top[y])
{
if (deep[top[x]]<deep[top[y]]) swap(x,y);
update(1,1,n,pos[top[x]],pos[x],z);
x=fa[top[x]];
}
if (deep[x]<deep[y]) swap(x,y);
update(1,1,n,pos[y],pos[x],z);
}
LL sum_path(LL x,LL y)
{
LL ans=0;
while(top[x]!=top[y])
{
if (deep[top[x]]<deep[top[y]]) swap(x,y);
ans=(ans+query(1,1,n,pos[top[x]],pos[x]))%mod;
x=fa[top[x]];
}
if (deep[x]<deep[y]) swap(x,y);
ans=(ans+query(1,1,n,pos[y],pos[x]))%mod;
return ans;
}
void add_sub(LL x,LL y)
{
update(1,1,n,pos[x],pos[x]+siz[x]-1,y);
}
LL sum_sub(LL x)
{
return query(1,1,n,pos[x],pos[x]+siz[x]-1);
}
main()
{
scanf("%d%d%d%d",&n,&m,&root,&mod);
for (LL i=1;i<=n;i++)
scanf("%d",&v[i]);
for (LL i=1;i<=n-1;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
deep[0]=-1;
dfs1(root,0,1);
dfs2(root,root);
build(1,1,n);
for (LL i=1;i<=m;i++)
{
ans=0;
scanf("%d",&com);
if (com==1)
x=read(),y=read(),z=read(),add_path(x,y,z);
else if (com==2)
x=read(),y=read(),printf("%d\n",sum_path(x,y)%mod);
else if (com==3)
x=read(),y=read(),add_sub(x,y);
else if (com==4)
x=read(),printf("%d\n",sum_sub(x)%mod);
}
}
查询的是不包含x的链的权值最大值。
树链剖分后,每条链变成O(logn)个区间,那么未被这个链包含的也有logn个区间,在这logn个区间上做修改即可。
因为有删除操作,每个线段树节点开一个堆维护。
#include <cstdio>
#include <queue>
#include <algorithm>
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define N 200010
#define M 500010
using namespace std;
struct node
{
int to,next;
}e[M];
struct qst
{
int l,r;
}q[N];
int tot,st[M],deep[N],siz[N],son[N],val[N],top[N],v[N],id[N],pos[N],ed[N];
int n,m,x,y,z,root,mod,com,ans,cnt;
struct cmp{bool operator () (int a,int b) {return val[a]<val[b];}};
priority_queue<int,vector<int>,cmp> pri[4*N];
bool comp(qst a,qst b){return a.l<b.l;}
void swap(int &a,int &b){int t=a;a=b;b=t;}
void add(int x,int y)
{
e[++tot].to=y;
e[tot].next=st[x];
st[x]=tot;
}
int fa[N];
void dfs1(int now)
{
int maxs=-1;
siz[now]=1;
for (int i=st[now];i;i=e[i].next)
if (e[i].to!=fa[now])
{
deep[e[i].to]=deep[now]+1;
fa[e[i].to]=now;
dfs1(e[i].to);
siz[now]+=siz[e[i].to];
if (maxs<siz[e[i].to])
maxs=siz[e[i].to],son[now]=e[i].to;
}
}
void dfs2(int now,int tops)
{
id[++cnt]=now;
pos[now]=cnt;
// val[cnt]=v[now];
top[now]=tops;
if (son[now])
dfs2(son[now],tops);
for (int i=st[now];i;i=e[i].next)
if (e[i].to!=fa[now] && e[i].to!=son[now])
dfs2(e[i].to,e[i].to);
}
void insert(int rt,int l,int r,int L,int R,int x)
{
if (L<=l && r<=R)
{
pri[rt].push(x);
return;
}
int mid=(l+r)/2;
if (L<=mid)
insert(lson,L,R,x);
if (R>mid)
insert(rson,L,R,x);
}
int Ans;
void query(int rt,int l,int r,int x)
{
while(!pri[rt].empty())
{
int now=pri[rt].top();
if (ed[now])
pri[rt].pop();
else
{
Ans=max(Ans,val[now]);
break;
}
}
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid)
query(lson,x);
else
query(rson,x);
}
main()
{
scanf("%d%d",&n,&m);
root=1;
for (int i=1;i<n;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
dfs1(root);
dfs2(root,root);
for (int i=1;i<=m;i++)
{
int com;
scanf("%d",&com);
if (com==0)
{
scanf("%d%d%d",&x,&y,&val[i]);
int w=0;
while(top[x]!=top[y])
{
if (deep[top[x]]<deep[top[y]]) swap(x,y);
q[++w]=(qst){pos[top[x]],pos[x]};
x=fa[top[x]];
}
if (deep[x]<deep[y]) swap(x,y);
q[++w]=(qst){pos[y],pos[x]};
q[++w]=(qst){n+1,n+1};
sort(q+1,q+w+1,comp);
for (int j=1;j<=w;j++)
if (q[j-1].r<q[j].l-1)
insert(1,1,n,q[j-1].r+1,q[j].l-1,i);
}
else if (com==1)
{
scanf("%d",&x);
ed[x]=1;
}
else
{
scanf("%d",&x);
Ans=-1;
query(1,1,n,pos[x]);
printf("%d\n",Ans);
}
}
}
这题是非常有趣的一道题。我们对于每一种宗教开一棵线段树。但是不能开全,就是动态开点的线段树。每次要改的时候就加入新点。(有种主席树即视感
#include <cstdio>
#include <algorithm>
#include <queue>
#define M 200010
#define N 100010
#define min(x,y) ((x<y)?(x):(y))
#define lson ls[rt],l,mid
#define rson rs[rt],mid+1,r
using namespace std;
int inline read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct node
{
int to,next;
}e[M];
int tot,st[M],deep[N],siz[N],fa[N],son[N],val[N],top[N],v[N],id[N],pos[N];
int r[N],rel[N],rt[N];
int n,m,x,y,z,root,mod,com,ans,ql,qr;
void add(int x,int y)
{
e[++tot].to=y;
e[tot].next=st[x];
st[x]=tot;
}
int dfs_id=0;
void dfs1(int now,int pre,int depth)
{
int maxs=-0x3f3f3f3f;
deep[now]=depth;
siz[now]=1;
fa[now]=pre;
for (int i=st[now];i;i=e[i].next)
if (e[i].to!=pre)
{
dfs1(e[i].to,now,depth+1);
siz[now]+=siz[e[i].to];
if (maxs<siz[e[i].to])
son[now]=e[i].to,maxs=siz[e[i].to];
}
}
void dfs2(int now,int tops)
{
top[now]=tops;
id[++dfs_id]=now;
pos[now]=dfs_id;
val[dfs_id]=v[now];
rel[dfs_id]=r[now];
if (son[now]!=0)
dfs2(son[now],tops);
for (int i=st[now];i;i=e[i].next)
if (e[i].to!=fa[now]&&e[i].to!=son[now])
dfs2(e[i].to,e[i].to);
}
int sum[40*N],maxs[40*N],ls[40*N],rs[40*N],cnt;
void add_tree(int &now,int pre,int l,int r,int pos,int v)
{
now=++cnt;
if (l==r)
sum[now]=v,maxs[now]=v;
else
{
int mid=l+r>>1;
if (pos<=mid)
add_tree(ls[now],ls[pre],l,mid,pos,v),rs[now]=rs[pre];
else
add_tree(rs[now],rs[pre],mid+1,r,pos,v),ls[now]=ls[pre];
sum[now]=sum[ls[now]]+sum[rs[now]];
maxs[now]=max(maxs[ls[now]],maxs[rs[now]]);
}
}
int query_s(int rt,int l,int r,int L,int R)
{
if (!rt) return 0;
if (l>=L && r<=R)
return sum[rt];
else
{
int mid=(r+l)/2,x=0;
if (mid>=L)
x+=query_s(lson,L,R);
if (mid<R)
x+=query_s(rson,L,R);
return x;
}
}
int query_m(int rt,int l,int r,int L,int R)
{
if (!rt) return 0;
if (l>=L && r<=R)
return maxs[rt];
else
{
int mid=(r+l)/2,x=0;
if (mid>=L)
x=max(x,query_m(lson,L,R));
if (mid<R)
x=max(x,query_m(rson,L,R));;
return x;
}
}
int sum_path(int x,int y,int re)
{
int ans=0;
while(top[x]!=top[y])
{
if (deep[top[x]]<deep[top[y]]) swap(x,y);
ans=(ans+query_s(rt[re],1,n,pos[top[x]],pos[x]));
x=fa[top[x]];
}
if (deep[x]<deep[y]) swap(x,y);
ans=(ans+query_s(rt[re],1,n,pos[y],pos[x]));
return ans;
}
int max_path(int x,int y,int re)
{
int ans=0;
while(top[x]!=top[y])
{
if (deep[top[x]]<deep[top[y]]) swap(x,y);
ans=max(ans,query_m(rt[re],1,n,pos[top[x]],pos[x]));
x=fa[top[x]];
}
if (deep[x]<deep[y]) swap(x,y);
ans=max(ans,query_m(rt[re],1,n,pos[y],pos[x]));
return ans;
}
main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d%d",&v[i],&r[i]);
for (int i=1;i<=n-1;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
deep[0]=-1;
root=1;
dfs1(root,0,1);
dfs2(root,root);
for (int i=1;i<=n;i++)
add_tree(rt[rel[i]],rt[rel[i]],1,n,i,val[i]);
char c1,c2;int a,b;
for (int i=1;i<=m;i++)
{
getchar();
ans=0;
scanf("%c%c%d%d",&c1,&c2,&a,&b);
if (c1=='Q'&&c2=='S')
printf("%d\n",sum_path(a,b,rel[pos[a]]));
else
if (c1=='Q'&&c2=='M')
printf("%d\n",max_path(a,b,rel[pos[a]]));
else
if (c1=='C'&&c2=='C')
{
add_tree(rt[rel[pos[a]]],rt[rel[pos[a]]],1,n,pos[a],0);
rel[pos[a]]=b;
add_tree(rt[b],rt[b],1,n,pos[a],val[pos[a]]);
}
else
if (c1=='C'&&c2=='W')
{
val[pos[a]]=b;
add_tree(rt[rel[pos[a]]],rt[rel[pos[a]]],1,n,pos[a],b);
}
}
}