首页 > 笔记 > 线段树高阶总结

线段树高阶总结

第二篇总结。。。

先讲一些小技巧

1.用位运算代替乘除2

x>>1   //x/2
x<<1   //x*2

具体见 位运算技巧

2.用define节省代码量

#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r

调用时就可以用lson,rson代替了

具体见下方代码

3.build时直接读入

具体见代码

基本操作

复习一下

(例子为区间加)

0.上推 下推

void pushup(int rt)
{
    seg[rt]=seg[rt<<1]+seg[rt<<1|1];
}
void pushdown(int rt,int x)
{
    if (mark[rt])
    {
        seg[rt<<1]+=mark[rt]*(x-x/2);
        seg[rt<<1|1]+=mark[rt]*(x/2);
        mark[rt<<1]+=mark[rt];
        mark[rt<<1|1]+=mark[rt];
        mark[rt]=0;
    }
}

1.建立线段树

void build(int rt,int l,int r)
{
    if (l==r)
        scanf("%d",seg[rt]);
    else
    {
        int mid=(l+r)>>1;
        build(lson);
        build(rson);
        pushup(rt);
    }
}

2.查询区间

int query(int rt,int l,int r,int L,int R)
{
    if (l>=L && r<=R)
        return seg[rt];
    else
    {
        pushdown(rt,r-l+1);
        int mid=(r+l)>>1,ans=0;
        if (mid>=L)
            ans+=query(lson,L,R);
        if (mid<R)
            ans+=query(rson,L,R);
        return ans;
    }
}

3.更新区间

void update(int rt,int l,int r,int L,int R,int x)
{
    if (l>=L && r<=R)
        seg[rt]+=(r-l+1)*x,mark[rt]+=x;
    else
    {
        pushdown(rt,r-l+1);
        int mid=(r-l)>>1;
        if (mid>=L)
            update(lson,L,R,x);
        if (mid<R)
            update(rson,L,R,x);
        pushup(rt);
    }
}

例题

poj3486

#include <cstdio>
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define Maxn 500000
#define ll long long
using namespace std;
ll seg[Maxn];
ll mark[Maxn];
void pushup(ll rt)
{
    seg[rt]=seg[rt<<1]+seg[rt<<1|1];
}
void pushdown(ll rt,ll x)
{
    if (mark[rt])
    {
        mark[rt<<1]+=mark[rt];
        mark[rt<<1|1]+=mark[rt];
        seg[rt<<1]+=mark[rt]*(x-x/2);
        seg[rt<<1|1]+=mark[rt]*(x/2);
        mark[rt]=0;
    }
}
void build(ll rt,ll l,ll r)
{
    if (l==r)
        scanf("%I64d",&seg[rt]);
    else
    {
        ll mid=(l+r)/2;
        build(lson);
        build(rson);
        pushup(rt);
    }
}
ll query(ll rt,ll l,ll r,ll L,ll R)
{
    if (l>=L && r<=R)
        return seg[rt];
    else
    {
        pushdown(rt,r-l+1);
        ll mid=(r+l)/2,ans=0;
        if (mid>=L)
            ans+=query(lson,L,R);
        if (mid<R)
            ans+=query(rson,L,R);
        return ans;
    }
}
void update(ll rt,ll l,ll r,ll L,ll R,ll x)
{
    if (l>=L && r<=R)
        seg[rt]+=(r-l+1)*x,mark[rt]+=x;
    else
    {
        pushdown(rt,r-l+1);
        ll mid=(r+l)/2;
        if (mid>=L)
            update(lson,L,R,x);
        if (mid<R)
            update(rson,L,R,x);
        pushup(rt);
    }
}
main()
{
    ll n,m;
    char s[20];
    ll x,y,z;
    while(~scanf("%I64d%I64d",&n,&m))
    {
        build(1,1,n);
        while(m--)
        {
            scanf("%s",s);
            if(s[0] == 'C')
            {
                scanf("%I64d%I64d%I64d",&x,&y,&z);
                update(1,1,n,x,y,z);
            }
            else
            {
                scanf("%I64d%I64d",&x,&y);
                printf("%I64dn",query(1,1,n,x,y));
            }
        }
    }
}

下一个操作:区间乘

就是对于一个数列,可以在区间上乘以一个数

原来的数列是

a1+a2+a3+a4

那么我们给每个数乘上x

a1x+a2x+a3x+a4x

可以得到

x(a1+a2+a3+a4)

就相当于原来的区间和乘x

假如原来已经有乘法标记了,那就变成这样

xy(a1+a2+a3+a4)

(原来的为y,现在的为x)

直接乘上x即可

假如原来有加法标记了

(a1+y)x+(a2+y)x+(a2+y)x+(a2+y)x

=xyn+x(a1+a2+a3+a4)

n为区间数的数量

所以程序如下

bzoj 1798

#include <cstdio>
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define Maxn 500000
#define ll long long
#define lld lld
using namespace std;
ll seg[Maxn];
ll marka[Maxn];
ll markm[Maxn];
ll mod;
void pushup(ll rt)
{
    seg[rt]=(seg[rt<<1]+seg[rt<<1|1])%mod;
}
void pushdown(ll rt,ll x)
{
    marka[rt<<1]=(marka[rt]+marka[rt<<1]*markm[rt])%mod;
    marka[rt<<1|1]=(marka[rt]+marka[rt<<1|1]*markm[rt])%mod;

    markm[rt<<1]=(markm[rt]*markm[rt<<1])%mod;
    markm[rt<<1|1]=(markm[rt]*markm[rt<<1|1])%mod;

    seg[rt<<1]=(marka[rt]*(x-x/2)+seg[rt<<1]*markm[rt])%mod;
    seg[rt<<1|1]=(marka[rt]*(x/2)+seg[rt<<1|1]*markm[rt])%mod;

    marka[rt]=0;
    markm[rt]=1;
}
void build(ll rt,ll l,ll r)
{
    markm[rt]=1;
    if (l==r)
        scanf("%lld",&seg[rt]);
    else
    {
        ll mid=(l+r)/2;
        build(lson);
        build(rson);
        pushup(rt);
    }
}
ll query(ll rt,ll l,ll r,ll L,ll R)
{
    if (l>=L && r<=R)
        return seg[rt];
    else
    {
        pushdown(rt,r-l+1);
        ll mid=(r+l)/2,ans=0;
        if (mid>=L)
            ans=(ans+query(lson,L,R))%mod;
        if (mid<R)
            ans=(ans+query(rson,L,R))%mod;
        return ans%mod;
    }
}
void update_a(ll rt,ll l,ll r,ll L,ll R,ll x)
{
    if (l>=L && r<=R)
        seg[rt]=(seg[rt]+(r-l+1)*x)%mod,marka[rt]=(x+marka[rt])%mod;
    else
    {
        pushdown(rt,r-l+1);
        ll mid=(r+l)/2;
        if (mid>=L)
            update_a(lson,L,R,x);
        if (mid<R)
            update_a(rson,L,R,x);
        pushup(rt);
    }
}
void update_m(ll rt,ll l,ll r,ll L,ll R,ll x)
{
    if (l>=L && r<=R)
        seg[rt]=(seg[rt]*x)%mod,marka[rt]=(x*marka[rt])%mod,markm[rt]=(x*markm[rt])%mod;
    else
    {
        pushdown(rt,r-l+1);
        ll mid=(r+l)/2;
        if (mid>=L)
            update_m(lson,L,R,x);
        if (mid<R)
            update_m(rson,L,R,x);
        pushup(rt);
    }
}
main()
{
    ll n,m,x,y,z,c;
    scanf("%lld%lld",&n,&mod);
    {
        build(1,1,n);
        scanf("%lld",&m);
        while(m--)
        {
            scanf("%lld",&c);
            if(c==1)
            {
                scanf("%lld%lld%lld",&x,&y,&z);
                update_m(1,1,n,x,y,z);
            }
            else
            if(c==2)
            {
                scanf("%lld%lld%lld",&x,&y,&z);
                update_a(1,1,n,x,y,z);
            }
            else
            if (c==3)
            {
                scanf("%lld%lld",&x,&y);
                printf("%lldn",query(1,1,n,x,y));
            }
        }
    }
}

现在来点更高阶的

方差luogu1417

题意:对于一个数量,求区间的方差与平均数

images

题解:

我们把方差公式展开

images

所以只需要维护一个区间平方和和区间和

当我们更新一个区间加时

images

所以我们只需要维护一个mark就可以了

代码:

#include <cstdio>
#include <iostream>
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define Maxn 300010
using namespace std;
double sega[Maxn],segb[Maxn];
double mark[Maxn];
void pushup(int x)
{
    sega[x]=sega[x<<1]+sega[x<<1|1];
    segb[x]=segb[x<<1]+segb[x<<1|1];
}
void pushdown(int rt,int x)
{
    if (mark[rt])
    {
        segb[rt<<1]+=2*mark[rt]*sega[rt<<1]+(x-x/2)*mark[rt]*mark[rt];
        segb[rt<<1|1]+=2*mark[rt]*sega[rt<<1|1]+(x/2)*mark[rt]*mark[rt];

        sega[rt<<1]+=(x-x/2)*mark[rt];
        sega[rt<<1|1]+=(x/2)*mark[rt];

        mark[rt<<1]+=mark[rt];
        mark[rt<<1|1]+=mark[rt];

        mark[rt]=0;
    }
}
void build(int rt,int l,int r)
{
    if (l==r)
        cin>>sega[rt],segb[rt]=sega[rt]*sega[rt];
    else
    {
        int mid=(l+r)/2;
        build(lson);
        build(rson);
        pushup(rt);
    }
}
double query_a(int rt,int l,int r,int L,int R)
{
    //--L--l--r--R--
    if (l>=L && r<=R)
        return sega[rt];
    else
    {
        pushdown(rt,r-l+1);
        int mid=(r+l)/2;
        double ret=0;
        if (mid>=L)
            ret+=query_a(lson,L,R);
        if (mid<R)
            ret+=query_a(rson,L,R);
        return ret;
    }
}
double query_b(int rt,int l,int r,int L,int R)
{
    //--L--l--r--R--
    if (l>=L && r<=R)
        return segb[rt];
    else
    {
        pushdown(rt,r-l+1);
        int mid=(r+l)/2;
        double ret=0;
        if (mid>=L)
            ret+=query_b(lson,L,R);
        if (mid<R)
            ret+=query_b(rson,L,R);
        return ret;
    }
}
void update(int rt,int l,int r,int L,int R,double x)
{
    if (l>=L && r<=R)
        mark[rt]+=x,segb[rt]+=2*x*sega[rt]+x*x*(r-l+1),sega[rt]+=(r-l+1)*x;
    else
    {
        pushdown(rt,r-l+1);
        int mid=(r+l)/2;
        if (mid>=L)
            update(lson,L,R,x);
        if (mid<R)
            update(rson,L,R,x);
        pushup(rt);
    }
}
double sqr(double x)
{
    return x*x;
}
main()
{
    int n,m,x,y,c;
    double z;
    scanf("%d %d",&n,&m);
    build(1,1,n);
    for (int i=1;i<=m;i++)
    {
        scanf("%d",&c);
        if (c==2)
            scanf("%d%d",&x,&y),printf("%.4lfn",query_a(1,1,n,x,y)/(y-x+1));
        if (c==1)
            scanf("%d%d",&x,&y),cin>>z,update(1,1,n,x,y,z);
        if (c==3)
        {
            scanf("%d%d",&x,&y);
            double sum1=query_b(1,1,n,x,y)/(y-x+1),sum2=query_a(1,1,n,x,y)/(y-x+1);
            double ans=sum1-sum2*sum2;
            printf("%.4lfn",ans);
        }
    }
}

最后一道,写完都虚了。。。

hdu4578

这道题坑在有三种询问:set , add , mul。所以lazy标记要有三个,如果三个标记同时出现的处理方法——当更新set操作时,就把add标记和mul标记全部取消;当更新mul操作时,如果当前节点add标记存在,就把add标记改为:add * mul。这样的话就可以在PushDown()操作中先执行set,然后mul,最后add。

麻烦在有三种询问:和 , 平方和 , 立方和。对于set和mul操作来说,这三种询问都比较好弄。

对于add操作,和的话就比较好弄,按照正常方法就可以;

平方和这样来推:(a + c)2 = a2 + c2 + 2ac  , 即seg2[rt] = seg2[rt] + (r – l + 1) * c * c + 2 * seg1[rt] * c;

立方和这样推:(a + c)3 = a3 + c3 + 3a(a2 + ac) , 即seg3[rt] = seg3[rt] + (r – l + 1) * c * c * c + 3 * c * (seg2[rt] + seg1[rt] * c);

几个注意点:add标记取消的时候是置0,mul标记取消的时候是置1;在PushDown()中也也要注意取消标记,如set操作中取消add和mul,mul操作中更新add; 在add操作中要注意seg3 , seg2 , seg1的先后顺序,一定是先seg3 , 然后seg2 , 最后seg1; int容易爆,还是用long long要保险一点; 最后就是膜运算较多,不要漏掉东西。

#include<cstdio>
using namespace std;
#define LL long long
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
const int mod=10007;
const int maxn=100000+5;
LL add[maxn<<2],set[maxn<<2],mul[maxn<<2];
LL seg1[maxn<<2],seg2[maxn<<2],seg3[maxn<<2];
void PushUp(int rt)
{
    seg1[rt]=(seg1[rt<<1]+seg1[rt<<1|1])%mod;
    seg2[rt]=(seg2[rt<<1]+seg2[rt<<1|1])%mod;
    seg3[rt]=(seg3[rt<<1]+seg3[rt<<1|1])%mod;
}
void build(int rt,int l,int r)
{
    add[rt]=set[rt]=0;
    mul[rt]=1;
    if(l==r)
    {
        seg1[rt]=seg2[rt]=seg3[rt]=0;
        return;
    }
    int mid=(l+r)>>1;
    build(lson);
    build(rson);
    PushUp(rt);
}
void PushDown(int rt,int x)
{
    if(set[rt])
    {
        set[rt<<1]=set[rt<<1|1]=set[rt];
        add[rt<<1]=add[rt<<1|1]=0;
        mul[rt<<1]=mul[rt<<1|1]=1;
        LL tmp=((set[rt]*set[rt])%mod)*set[rt]%mod;
        seg1[rt<<1]=((x-(x>>1))%mod)*(set[rt]%mod)%mod;
        seg1[rt<<1|1]=((x>>1)%mod)*(set[rt]%mod)%mod;
        seg2[rt<<1]=((x-(x>>1))%mod)*((set[rt]*set[rt])%mod)%mod;
        seg2[rt<<1|1]=((x>>1)%mod)*((set[rt]*set[rt])%mod)%mod;
        seg3[rt<<1]=((x-(x>>1))%mod)*tmp%mod;
        seg3[rt<<1|1]=((x>>1)%mod)*tmp%mod;
        set[rt]=0;
    }
    if(mul[rt]!=1)
    {
        mul[rt<<1]=(mul[rt<<1]*mul[rt])%mod;
        mul[rt<<1|1]=(mul[rt<<1|1]*mul[rt])%mod;
        add[rt<<1]=(add[rt<<1]*mul[rt])%mod;
        add[rt<<1|1]=(add[rt<<1|1]*mul[rt])%mod;
        LL tmp=(((mul[rt]*mul[rt])%mod*mul[rt])%mod);
        seg1[rt<<1]=(seg1[rt<<1]*mul[rt])%mod;
        seg1[rt<<1|1]=(seg1[rt<<1|1]*mul[rt])%mod;
        seg2[rt<<1]=(seg2[rt<<1]%mod)*((mul[rt]*mul[rt])%mod)%mod;
        seg2[rt<<1|1]=(seg2[rt<<1|1]%mod)*((mul[rt]*mul[rt])%mod)%mod;
        seg3[rt<<1]=(seg3[rt<<1]%mod)*tmp%mod;
        seg3[rt<<1|1]=(seg3[rt<<1|1]%mod)*tmp%mod;
        mul[rt]=1;
    }
    if(add[rt])
    {
        add[rt<<1]+=add[rt];
        add[rt<<1|1]+=add[rt];
        LL tmp=(add[rt]*add[rt]%mod)*add[rt]%mod;
        seg3[rt<<1]=(seg3[rt<<1]+(tmp*(x-(x>>1))%mod)+3*add[rt]*((seg2[rt<<1]+seg1[rt<<1]*add[rt])%mod))%mod;
        seg3[rt<<1|1]=(seg3[rt<<1|1]+(tmp*(x>>1)%mod)+3*add[rt]*((seg2[rt<<1|1]+seg1[rt<<1|1]*add[rt])%mod))%mod;
        seg2[rt<<1]=(seg2[rt<<1]+((add[rt]*add[rt]%mod)*(x-(x>>1))%mod)+(2*seg1[rt<<1]*add[rt]%mod))%mod;
        seg2[rt<<1|1]=(seg2[rt<<1|1]+(((add[rt]*add[rt]%mod)*(x>>1))%mod)+(2*seg1[rt<<1|1]*add[rt]%mod))%mod;
        seg1[rt<<1]=(seg1[rt<<1]+(x-(x>>1))*add[rt])%mod;
        seg1[rt<<1|1]=(seg1[rt<<1|1]+(x>>1)*add[rt])%mod;
        add[rt]=0;
    }
}
void update(int rt,int l,int r,int L,int R,int c,int ch)
{
    if(L<=l&&R>=r)
    {
        if(ch==3)
        {
            set[rt]=c;
            add[rt]=0;
            mul[rt]=1;
            seg1[rt]=((r-l+1)*c)%mod;
            seg2[rt]=((r-l+1)*((c*c)%mod))%mod;
            seg3[rt]=((r-l+1)*(((c*c)%mod)*c%mod))%mod;
        }
        else
        if(ch==2)
        {
            mul[rt]=(mul[rt]*c)%mod;
                add[rt]=(add[rt]*c)%mod;
            seg1[rt]=(seg1[rt]*c)%mod;
            seg2[rt]=(seg2[rt]*(c*c%mod))%mod;
            seg3[rt]=(seg3[rt]*((c*c%mod)*c%mod))%mod;
        }
        else
        if(ch==1)
        {
            add[rt]+=c;
            LL tmp=(((c*c)%mod*c)%mod*(r-l+1))%mod; //(r-l+1)*c^3
            seg3[rt]=(seg3[rt]+tmp+3*c*((seg2[rt]+seg1[rt]*c)%mod))%mod;
            seg2[rt]=(seg2[rt]+(c*c%mod*(r-l+1)%mod)+2*seg1[rt]*c)%mod;
            seg1[rt]=(seg1[rt]+(r-l+1)*c)%mod;
        }
        return;
    }
    PushDown(rt,r-l+1);
    int mid=(l+r)>>1;
    if (mid>=L)
        update(lson,L,R,c,ch);
    if (mid<R)
        update(rson,L,R,c,ch);
    PushUp(rt);
}
LL query(int rt,int l,int r,int L,int R,int c)
{
    if(L<=l&&R>=r)
        if(c==1)
            return seg1[rt]%mod;
        else if(c==2)
            return seg2[rt]%mod;
        else
            return seg3[rt]%mod;
    PushDown(rt,r-l+1);
    LL mid=(l+r)>>1,ans=0;
    if (mid>=L)
        ans=(ans+query(lson,L,R,c))%mod;
    if (mid<R)
        ans=(ans+query(rson,L,R,c))%mod;
    return ans%mod;
}
main()
{
    int n,m,a,b,c,ch;
    while(~scanf("%d%d",&n,&m))
    {
        if(n==0&&m==0)
            break;
        build(1,1,n);
        while(m--)
        {
            scanf("%d%d%d%d",&ch,&a,&b,&c);
            if(ch!=4)
                update(1,1,n,a,b,c,ch);
            else
                printf("%I64dn",query(1,1,n,a,b,c));
        }
    }
}

The end.