线段树高阶总结
第二篇总结。。。
先讲一些小技巧
1.用位运算代替乘除2
x>>1 //x/2 x<<1 //x*2
具体见 位运算技巧
2.用define节省代码量
#define lson rt<<1,l,mid #define rson rt<<1|1,mid+1,r
调用时就可以用lson,rson代替了
具体见下方代码
3.build时直接读入
具体见代码
基本操作
复习一下
(例子为区间加)
0.上推 下推
void pushup(int rt) { seg[rt]=seg[rt<<1]+seg[rt<<1|1]; } void pushdown(int rt,int x) { if (mark[rt]) { seg[rt<<1]+=mark[rt]*(x-x/2); seg[rt<<1|1]+=mark[rt]*(x/2); mark[rt<<1]+=mark[rt]; mark[rt<<1|1]+=mark[rt]; mark[rt]=0; } }
1.建立线段树
void build(int rt,int l,int r) { if (l==r) scanf("%d",seg[rt]); else { int mid=(l+r)>>1; build(lson); build(rson); pushup(rt); } }
2.查询区间
int query(int rt,int l,int r,int L,int R) { if (l>=L && r<=R) return seg[rt]; else { pushdown(rt,r-l+1); int mid=(r+l)>>1,ans=0; if (mid>=L) ans+=query(lson,L,R); if (mid<R) ans+=query(rson,L,R); return ans; } }
3.更新区间
void update(int rt,int l,int r,int L,int R,int x) { if (l>=L && r<=R) seg[rt]+=(r-l+1)*x,mark[rt]+=x; else { pushdown(rt,r-l+1); int mid=(r-l)>>1; if (mid>=L) update(lson,L,R,x); if (mid<R) update(rson,L,R,x); pushup(rt); } }
例题
#include <cstdio> #define lson rt<<1,l,mid #define rson rt<<1|1,mid+1,r #define Maxn 500000 #define ll long long using namespace std; ll seg[Maxn]; ll mark[Maxn]; void pushup(ll rt) { seg[rt]=seg[rt<<1]+seg[rt<<1|1]; } void pushdown(ll rt,ll x) { if (mark[rt]) { mark[rt<<1]+=mark[rt]; mark[rt<<1|1]+=mark[rt]; seg[rt<<1]+=mark[rt]*(x-x/2); seg[rt<<1|1]+=mark[rt]*(x/2); mark[rt]=0; } } void build(ll rt,ll l,ll r) { if (l==r) scanf("%I64d",&seg[rt]); else { ll mid=(l+r)/2; build(lson); build(rson); pushup(rt); } } ll query(ll rt,ll l,ll r,ll L,ll R) { if (l>=L && r<=R) return seg[rt]; else { pushdown(rt,r-l+1); ll mid=(r+l)/2,ans=0; if (mid>=L) ans+=query(lson,L,R); if (mid<R) ans+=query(rson,L,R); return ans; } } void update(ll rt,ll l,ll r,ll L,ll R,ll x) { if (l>=L && r<=R) seg[rt]+=(r-l+1)*x,mark[rt]+=x; else { pushdown(rt,r-l+1); ll mid=(r+l)/2; if (mid>=L) update(lson,L,R,x); if (mid<R) update(rson,L,R,x); pushup(rt); } } main() { ll n,m; char s[20]; ll x,y,z; while(~scanf("%I64d%I64d",&n,&m)) { build(1,1,n); while(m--) { scanf("%s",s); if(s[0] == 'C') { scanf("%I64d%I64d%I64d",&x,&y,&z); update(1,1,n,x,y,z); } else { scanf("%I64d%I64d",&x,&y); printf("%I64dn",query(1,1,n,x,y)); } } } }
下一个操作:区间乘
就是对于一个数列,可以在区间上乘以一个数
原来的数列是
a1+a2+a3+a4
那么我们给每个数乘上x
a1x+a2x+a3x+a4x
可以得到
x(a1+a2+a3+a4)
就相当于原来的区间和乘x
假如原来已经有乘法标记了,那就变成这样
xy(a1+a2+a3+a4)
(原来的为y,现在的为x)
直接乘上x即可
假如原来有加法标记了
(a1+y)x+(a2+y)x+(a2+y)x+(a2+y)x
=xyn+x(a1+a2+a3+a4)
n为区间数的数量
所以程序如下
#include <cstdio> #define lson rt<<1,l,mid #define rson rt<<1|1,mid+1,r #define Maxn 500000 #define ll long long #define lld lld using namespace std; ll seg[Maxn]; ll marka[Maxn]; ll markm[Maxn]; ll mod; void pushup(ll rt) { seg[rt]=(seg[rt<<1]+seg[rt<<1|1])%mod; } void pushdown(ll rt,ll x) { marka[rt<<1]=(marka[rt]+marka[rt<<1]*markm[rt])%mod; marka[rt<<1|1]=(marka[rt]+marka[rt<<1|1]*markm[rt])%mod; markm[rt<<1]=(markm[rt]*markm[rt<<1])%mod; markm[rt<<1|1]=(markm[rt]*markm[rt<<1|1])%mod; seg[rt<<1]=(marka[rt]*(x-x/2)+seg[rt<<1]*markm[rt])%mod; seg[rt<<1|1]=(marka[rt]*(x/2)+seg[rt<<1|1]*markm[rt])%mod; marka[rt]=0; markm[rt]=1; } void build(ll rt,ll l,ll r) { markm[rt]=1; if (l==r) scanf("%lld",&seg[rt]); else { ll mid=(l+r)/2; build(lson); build(rson); pushup(rt); } } ll query(ll rt,ll l,ll r,ll L,ll R) { if (l>=L && r<=R) return seg[rt]; else { pushdown(rt,r-l+1); ll mid=(r+l)/2,ans=0; if (mid>=L) ans=(ans+query(lson,L,R))%mod; if (mid<R) ans=(ans+query(rson,L,R))%mod; return ans%mod; } } void update_a(ll rt,ll l,ll r,ll L,ll R,ll x) { if (l>=L && r<=R) seg[rt]=(seg[rt]+(r-l+1)*x)%mod,marka[rt]=(x+marka[rt])%mod; else { pushdown(rt,r-l+1); ll mid=(r+l)/2; if (mid>=L) update_a(lson,L,R,x); if (mid<R) update_a(rson,L,R,x); pushup(rt); } } void update_m(ll rt,ll l,ll r,ll L,ll R,ll x) { if (l>=L && r<=R) seg[rt]=(seg[rt]*x)%mod,marka[rt]=(x*marka[rt])%mod,markm[rt]=(x*markm[rt])%mod; else { pushdown(rt,r-l+1); ll mid=(r+l)/2; if (mid>=L) update_m(lson,L,R,x); if (mid<R) update_m(rson,L,R,x); pushup(rt); } } main() { ll n,m,x,y,z,c; scanf("%lld%lld",&n,&mod); { build(1,1,n); scanf("%lld",&m); while(m--) { scanf("%lld",&c); if(c==1) { scanf("%lld%lld%lld",&x,&y,&z); update_m(1,1,n,x,y,z); } else if(c==2) { scanf("%lld%lld%lld",&x,&y,&z); update_a(1,1,n,x,y,z); } else if (c==3) { scanf("%lld%lld",&x,&y); printf("%lldn",query(1,1,n,x,y)); } } } }
现在来点更高阶的
题意:对于一个数量,求区间的方差与平均数
题解:
我们把方差公式展开
所以只需要维护一个区间平方和和区间和
当我们更新一个区间加时
所以我们只需要维护一个mark就可以了
代码:
#include <cstdio> #include <iostream> #define lson rt<<1,l,mid #define rson rt<<1|1,mid+1,r #define Maxn 300010 using namespace std; double sega[Maxn],segb[Maxn]; double mark[Maxn]; void pushup(int x) { sega[x]=sega[x<<1]+sega[x<<1|1]; segb[x]=segb[x<<1]+segb[x<<1|1]; } void pushdown(int rt,int x) { if (mark[rt]) { segb[rt<<1]+=2*mark[rt]*sega[rt<<1]+(x-x/2)*mark[rt]*mark[rt]; segb[rt<<1|1]+=2*mark[rt]*sega[rt<<1|1]+(x/2)*mark[rt]*mark[rt]; sega[rt<<1]+=(x-x/2)*mark[rt]; sega[rt<<1|1]+=(x/2)*mark[rt]; mark[rt<<1]+=mark[rt]; mark[rt<<1|1]+=mark[rt]; mark[rt]=0; } } void build(int rt,int l,int r) { if (l==r) cin>>sega[rt],segb[rt]=sega[rt]*sega[rt]; else { int mid=(l+r)/2; build(lson); build(rson); pushup(rt); } } double query_a(int rt,int l,int r,int L,int R) { //--L--l--r--R-- if (l>=L && r<=R) return sega[rt]; else { pushdown(rt,r-l+1); int mid=(r+l)/2; double ret=0; if (mid>=L) ret+=query_a(lson,L,R); if (mid<R) ret+=query_a(rson,L,R); return ret; } } double query_b(int rt,int l,int r,int L,int R) { //--L--l--r--R-- if (l>=L && r<=R) return segb[rt]; else { pushdown(rt,r-l+1); int mid=(r+l)/2; double ret=0; if (mid>=L) ret+=query_b(lson,L,R); if (mid<R) ret+=query_b(rson,L,R); return ret; } } void update(int rt,int l,int r,int L,int R,double x) { if (l>=L && r<=R) mark[rt]+=x,segb[rt]+=2*x*sega[rt]+x*x*(r-l+1),sega[rt]+=(r-l+1)*x; else { pushdown(rt,r-l+1); int mid=(r+l)/2; if (mid>=L) update(lson,L,R,x); if (mid<R) update(rson,L,R,x); pushup(rt); } } double sqr(double x) { return x*x; } main() { int n,m,x,y,c; double z; scanf("%d %d",&n,&m); build(1,1,n); for (int i=1;i<=m;i++) { scanf("%d",&c); if (c==2) scanf("%d%d",&x,&y),printf("%.4lfn",query_a(1,1,n,x,y)/(y-x+1)); if (c==1) scanf("%d%d",&x,&y),cin>>z,update(1,1,n,x,y,z); if (c==3) { scanf("%d%d",&x,&y); double sum1=query_b(1,1,n,x,y)/(y-x+1),sum2=query_a(1,1,n,x,y)/(y-x+1); double ans=sum1-sum2*sum2; printf("%.4lfn",ans); } } }
最后一道,写完都虚了。。。
这道题坑在有三种询问:set , add , mul。所以lazy标记要有三个,如果三个标记同时出现的处理方法——当更新set操作时,就把add标记和mul标记全部取消;当更新mul操作时,如果当前节点add标记存在,就把add标记改为:add * mul。这样的话就可以在PushDown()操作中先执行set,然后mul,最后add。
麻烦在有三种询问:和 , 平方和 , 立方和。对于set和mul操作来说,这三种询问都比较好弄。
对于add操作,和的话就比较好弄,按照正常方法就可以;
平方和这样来推:(a + c)2 = a2 + c2 + 2ac , 即seg2[rt] = seg2[rt] + (r – l + 1) * c * c + 2 * seg1[rt] * c;
立方和这样推:(a + c)3 = a3 + c3 + 3a(a2 + ac) , 即seg3[rt] = seg3[rt] + (r – l + 1) * c * c * c + 3 * c * (seg2[rt] + seg1[rt] * c);
几个注意点:add标记取消的时候是置0,mul标记取消的时候是置1;在PushDown()中也也要注意取消标记,如set操作中取消add和mul,mul操作中更新add; 在add操作中要注意seg3 , seg2 , seg1的先后顺序,一定是先seg3 , 然后seg2 , 最后seg1; int容易爆,还是用long long要保险一点; 最后就是膜运算较多,不要漏掉东西。
#include<cstdio> using namespace std; #define LL long long #define lson rt<<1,l,mid #define rson rt<<1|1,mid+1,r const int mod=10007; const int maxn=100000+5; LL add[maxn<<2],set[maxn<<2],mul[maxn<<2]; LL seg1[maxn<<2],seg2[maxn<<2],seg3[maxn<<2]; void PushUp(int rt) { seg1[rt]=(seg1[rt<<1]+seg1[rt<<1|1])%mod; seg2[rt]=(seg2[rt<<1]+seg2[rt<<1|1])%mod; seg3[rt]=(seg3[rt<<1]+seg3[rt<<1|1])%mod; } void build(int rt,int l,int r) { add[rt]=set[rt]=0; mul[rt]=1; if(l==r) { seg1[rt]=seg2[rt]=seg3[rt]=0; return; } int mid=(l+r)>>1; build(lson); build(rson); PushUp(rt); } void PushDown(int rt,int x) { if(set[rt]) { set[rt<<1]=set[rt<<1|1]=set[rt]; add[rt<<1]=add[rt<<1|1]=0; mul[rt<<1]=mul[rt<<1|1]=1; LL tmp=((set[rt]*set[rt])%mod)*set[rt]%mod; seg1[rt<<1]=((x-(x>>1))%mod)*(set[rt]%mod)%mod; seg1[rt<<1|1]=((x>>1)%mod)*(set[rt]%mod)%mod; seg2[rt<<1]=((x-(x>>1))%mod)*((set[rt]*set[rt])%mod)%mod; seg2[rt<<1|1]=((x>>1)%mod)*((set[rt]*set[rt])%mod)%mod; seg3[rt<<1]=((x-(x>>1))%mod)*tmp%mod; seg3[rt<<1|1]=((x>>1)%mod)*tmp%mod; set[rt]=0; } if(mul[rt]!=1) { mul[rt<<1]=(mul[rt<<1]*mul[rt])%mod; mul[rt<<1|1]=(mul[rt<<1|1]*mul[rt])%mod; add[rt<<1]=(add[rt<<1]*mul[rt])%mod; add[rt<<1|1]=(add[rt<<1|1]*mul[rt])%mod; LL tmp=(((mul[rt]*mul[rt])%mod*mul[rt])%mod); seg1[rt<<1]=(seg1[rt<<1]*mul[rt])%mod; seg1[rt<<1|1]=(seg1[rt<<1|1]*mul[rt])%mod; seg2[rt<<1]=(seg2[rt<<1]%mod)*((mul[rt]*mul[rt])%mod)%mod; seg2[rt<<1|1]=(seg2[rt<<1|1]%mod)*((mul[rt]*mul[rt])%mod)%mod; seg3[rt<<1]=(seg3[rt<<1]%mod)*tmp%mod; seg3[rt<<1|1]=(seg3[rt<<1|1]%mod)*tmp%mod; mul[rt]=1; } if(add[rt]) { add[rt<<1]+=add[rt]; add[rt<<1|1]+=add[rt]; LL tmp=(add[rt]*add[rt]%mod)*add[rt]%mod; seg3[rt<<1]=(seg3[rt<<1]+(tmp*(x-(x>>1))%mod)+3*add[rt]*((seg2[rt<<1]+seg1[rt<<1]*add[rt])%mod))%mod; seg3[rt<<1|1]=(seg3[rt<<1|1]+(tmp*(x>>1)%mod)+3*add[rt]*((seg2[rt<<1|1]+seg1[rt<<1|1]*add[rt])%mod))%mod; seg2[rt<<1]=(seg2[rt<<1]+((add[rt]*add[rt]%mod)*(x-(x>>1))%mod)+(2*seg1[rt<<1]*add[rt]%mod))%mod; seg2[rt<<1|1]=(seg2[rt<<1|1]+(((add[rt]*add[rt]%mod)*(x>>1))%mod)+(2*seg1[rt<<1|1]*add[rt]%mod))%mod; seg1[rt<<1]=(seg1[rt<<1]+(x-(x>>1))*add[rt])%mod; seg1[rt<<1|1]=(seg1[rt<<1|1]+(x>>1)*add[rt])%mod; add[rt]=0; } } void update(int rt,int l,int r,int L,int R,int c,int ch) { if(L<=l&&R>=r) { if(ch==3) { set[rt]=c; add[rt]=0; mul[rt]=1; seg1[rt]=((r-l+1)*c)%mod; seg2[rt]=((r-l+1)*((c*c)%mod))%mod; seg3[rt]=((r-l+1)*(((c*c)%mod)*c%mod))%mod; } else if(ch==2) { mul[rt]=(mul[rt]*c)%mod; add[rt]=(add[rt]*c)%mod; seg1[rt]=(seg1[rt]*c)%mod; seg2[rt]=(seg2[rt]*(c*c%mod))%mod; seg3[rt]=(seg3[rt]*((c*c%mod)*c%mod))%mod; } else if(ch==1) { add[rt]+=c; LL tmp=(((c*c)%mod*c)%mod*(r-l+1))%mod; //(r-l+1)*c^3 seg3[rt]=(seg3[rt]+tmp+3*c*((seg2[rt]+seg1[rt]*c)%mod))%mod; seg2[rt]=(seg2[rt]+(c*c%mod*(r-l+1)%mod)+2*seg1[rt]*c)%mod; seg1[rt]=(seg1[rt]+(r-l+1)*c)%mod; } return; } PushDown(rt,r-l+1); int mid=(l+r)>>1; if (mid>=L) update(lson,L,R,c,ch); if (mid<R) update(rson,L,R,c,ch); PushUp(rt); } LL query(int rt,int l,int r,int L,int R,int c) { if(L<=l&&R>=r) if(c==1) return seg1[rt]%mod; else if(c==2) return seg2[rt]%mod; else return seg3[rt]%mod; PushDown(rt,r-l+1); LL mid=(l+r)>>1,ans=0; if (mid>=L) ans=(ans+query(lson,L,R,c))%mod; if (mid<R) ans=(ans+query(rson,L,R,c))%mod; return ans%mod; } main() { int n,m,a,b,c,ch; while(~scanf("%d%d",&n,&m)) { if(n==0&&m==0) break; build(1,1,n); while(m--) { scanf("%d%d%d%d",&ch,&a,&b,&c); if(ch!=4) update(1,1,n,a,b,c,ch); else printf("%I64dn",query(1,1,n,a,b,c)); } } }
The end.