这玩意有毒

线段树 学习笔记

什么是线段树

线段树是一种运用于区间修改和区间查询的数据结构,这棵树的每一个节点,都是一个线段,也就是数组中的一段子数组。由于二叉结构的特性,时间复杂度可以优化到nlogn。

线段树的原理

把1到n分成一定数量的子区间,可证区间数量不超过4*n(线段树是一棵完全二叉树,总节点数应该是最后一层节点的2倍,但是为了保险设成四倍)。线段树通过对子区间的修改和统计来完成对大区间的修改和统计。
所以,用线段树来统计的东西必须符合区间加法(包括区间乘),即大区间由小区间决定,不然没有办法用线段树。
符合区间加法的例子:

  1. 数字的和
  2. 最大公约数((总gcd=左区间gcd,右区间gcd))
  3. 求max
    其实,是要一个问题能够化成对一些连续点的修改和统计问题,就能用线段树来解决了。

对于懒标记的解释

懒标记是什么

懒标记是在说明,本节点已经被修改过了,但是本节点的子节点仍然需要修改。

为什么要有懒标记

节约时间。如果一次推到底的话,那么和O(n)也没啥区别了。

懒标记怎么用

每递归到一个区间,先把它带着的标记算出来,再给它的子节点打上标记。这样,向上返回的仍然是正确的值,向下的标记能保证结果正确。

线段树的实现

这里以最简单的区间加法为例,即把某个区间都加上x,然后进行查询。

建树

void build(int l,int r,int num)
//当前操作区间的左右区间,num为当前节点的编号
{
    if (l==r)
    {
        sum[num]=a[l];
        return;
    }//节点为单点时直接赋值
    int mid=(l+r)>>1;
    build(l,mid,num<<1);
    build(mid+1,r,num<<1|1);
    sum[num]=sum[num<<1]+sum[num<<1|1];
}

下推标记

void pushdown(int num,int ln,int rn)
//向下传递标记 
{
    if (add[num])
    {
        add[num<<1]+=add[num];
        add[num<<1|1]+=add[num];
        sum[num<<1]+=add[num]*ln;
        sum[num<<1|1]+=add[num]*rn;
        add[num]=0;
    }
}

区间修改

void change(int L,int R,int C,int l,int r,int num)
{
    if (L<=l && r<=R)
    {
        sum[num]+=C*(r-l+1);
        add[num]+=C;
        return;
    }//如果当前区间就在操作区间之内,直接修改
    int m=(l+r)>>1;
    pushdown(num,m-l+1,r-m);//结算当前节点的标记,然后继续往下推
    if (L<=m) change(L,R,C,l,m,num<<1);
    if (R>m) change(L,R,C,m+1,r,num<<1|1);
    sum[num]=sum[num<<1]+sum[num<<1|1];
}

区间查询

long long ques(int L,int R,int l,int r,int num)
{
    if (L<=l && r<=R)
        return sum[num];//当前区间就在查询区间之内,直接返回值			
    int mid=(l+r)>>1;
    pushdown(num,mid-l+1,r-mid);//结算当前节点的标记
    long long ans=0;
    if (L<=mid) ans+=ques(L,R,l,mid,num<<1);
    if (R>mid) ans+=ques(L,R,mid+1,r,num<<1|1);
    return ans;
}

线段树例题

luogu 3373 线段树2

这道题还是很毒瘤的,其中对于两种标记的处理值得推敲。因为两种标记谁先谁后无法确定,所以需要仔细处理,详细见代码注释。

#include<cstdio>
#include<iostream>
#define maxn 100007
using namespace std;
long long sum[maxn<<2],add[maxn<<2],cdd[maxn<<2];
int a[maxn],n,m,number,x,y,k,p;
void build(int l,int r,int num) 
//l,r表示当前节点区间,rt表示当前节点编号
{
    add[num]=0;cdd[num]=1;
    //加标记一开始为0,乘标记一开始为1
    if(l==r)
    {
        sum[num]=a[l];
        return;
    }
    int m=(l+r)>>1;
    build(l,m,num<<1);
    build(m+1,r,num<<1|1);
    sum[num]=(sum[num<<1]+sum[num<<1|1])%p;
} 
void pushdown(int num,int ln,int rn)
//下推标记 
{
    if (cdd[num]!=1)
    {
		cdd[num<<1]=(cdd[num<<1]*cdd[num])%p;
		cdd[num<<1|1]=(cdd[num<<1|1]*cdd[num])%p;
		add[num<<1]=(add[num<<1]*cdd[num])%p;
		add[num<<1|1]=(add[num<<1|1]*cdd[num])%p;
		sum[num<<1]=(sum[num<<1]*cdd[num])%p;
		sum[num<<1|1]=(sum[num<<1|1]*cdd[num])%p;
		//这里乘cdd[num]
		//因为某个点的标记是它的子节点要进行的操作
		//而这个节点的父节点的标记才是这个节点要进行的操作
		//这就是为什么叫懒标记的原因
		cdd[num]=1;
	} 
	if (add[num]!=0)
	{
		add[num<<1]=(add[num<<1]+add[num])%p;
		add[num<<1|1]=(add[num<<1|1]+add[num])%p;
		sum[num<<1]=(sum[num<<1]+ln*add[num])%p;
		sum[num<<1|1]=(sum[num<<1|1]+rn*add[num])%p;
		add[num]=0;
	}
}
void chdate(int L,int R,int c,int l,int r,int num)
//乘法
{
    if(L<=l && r<=R)
    {
        sum[num]=sum[num]*c%p;
        
        cdd[num]=(cdd[num]*c)%p;
        add[num]=(add[num]*c)%p;
        //该区间每个数*c
		//如果它的子区间原先需要每个数乘x,那么现在就需要每个数乘x*c
		//如果它的子区间原先需要每个数加y,那么现在需要每个数加y*c 
		//这个地方是整个程序的灵魂
		//因为这个程序最难的地方是加法标记和乘法标记谁先结算的问题
		//这个地方的存在可以让我们放心地在pushdown中先结算乘法标记 
		return;
    }
    int m=(l+r)>>1;
    pushdown(num,m-l+1,r-m);
    if(L<=m) chdate(L,R,c,l,m,num<<1);
    if(R>m) chdate(L,R,c,m+1,r,num<<1|1);
    sum[num]=(sum[num<<1]+sum[num<<1|1])%p;
}
void update(int L,int R,int c,int l,int r,int num)
//加法 
{
    if(L<=l && r<=R)
    {
        sum[num]=(sum[num]+c*(r-l+1))%p; // 每个数加c,整个区间一共要加c*(l-r+1) 
        add[num]=(add[num]+c)%p; 
        //加一个数不对乘法标记造成影响 
        return;
    }
    int m=(l+r)>>1;
    pushdown(num,m-l+1,r-m); 
    if(L<=m) update(L,R,c,l,m,num<<1);
    if(R>m) update(L,R,c,m+1,r,num<<1|1);
    // 判断左右子树跟[L,R]有无交集,有交集才递归 
    sum[num]=(sum[num<<1]+sum[num<<1|1])%p;
} 
long long query(int L,int R,int l,int r,int num)
{
//区间查询 
    if(L<=l && r<=R)
        return sum[num];
    int m=(l+r)>>1;
    pushdown(num,m-l+1,r-m);
    long long ans=0;
    if(L<=m) ans+=query(L,R,l,m,num<<1)%p;
    if(R>m) ans+=query(L,R,m+1,r,num<<1|1)%p;
    return ans%p;
} 
int main()
{
    scanf("%d%d%d",&n,&m,&p);
    for(int i=1;i<=n;i++)
        scanf("%lld",&a[i]);
    build(1,n,1);
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&number);
        if(number==1)
        {
            scanf("%d%d%d",&x,&y,&k);
            chdate(x,y,k,1,n,1);
        }
        if(number==2)
        {
            scanf("%d%d%d",&x,&y,&k);
            update(x,y,k,1,n,1);
        }
        if(number==3)
        {
            scanf("%d%d",&x,&y);
            printf("%lld\n",query(x,y,1,n,1)%p);
        }
    }
}

luogu 2982 慢下来

这是一道能让你真正理解建树好题
在这里先反思一下自己的错误。之前我对建树的理解一直很模糊,因为好像这个环节是线段树里面最轻松的——其实不然。
这里再详细阐释一下建树。线段树是由一个数组或者说线性的东西变身而来的,所以,线段树的区间里的下标应该是顺序的,即一个节点(也就是一段区间)的两个子树分别代表着这段区间的左边和右边,所以递归的时候才可以放心地采用二分的方式。对于这个题而言,我们要把它建成一棵树,就不能用广搜,而要用深搜,这样才能保证它下标的顺序,而这颗搜索树并不是我们的线段树。线段树还是原汁原味的线段树,这颗搜索树只是确定了线段树里面每个节点所盛放的区间是哪一段区间。理解了这一点,这道题就是一个区间修改+单点查询的线段树。
然后中间出了一个小错误,就是修改(或者查询)的时候应该是所求区间不动,而总区间不断二分,把所求区间凑出来,这个一定要搞清楚。
理解了这两点,这个题就比较好办了。
代码

#include<iostream>
#include<cstdio>
#include<cstring>
#define rint register int
#define inv inline void
#define ini inline int
using namespace std;
struct node
{
	int next,to;
}ljb[400001];
int head[400001],add[400001],n,map[400001],cnt,dfn[400001],tot,size[400001];
inv adedge(int x,int y)
{
	ljb[++cnt].next=head[x];
	ljb[cnt].to=y;
	head[x]=cnt;
}
inv dfs(int x,int fa)
{
	dfn[x]=++tot;
	size[x]=1;
	for (rint i=head[x];i;i=ljb[i].next)
	{
		int y=ljb[i].to;
		if (y!=fa)
		{
			dfs(y,x);
			size[x]+=size[y];
		}
	}
}
inv pushdown(int num)
{
	if (add[num]!=0)
	{
		add[num<<1]+=add[num];
		add[num<<1|1]+=add[num];
		add[num]=0;
	}
}
inv update(int l,int r,int x,int y,int num)
{
	if (x<=l && r<=y)
	{
		add[num]++;
		return;
	}
	pushdown(num);
	int m=l+r>>1;
	if (x<=m) update(l,m,x,y,num<<1);
	if (y>m) update(m+1,r,x,y,num<<1|1);
}
ini ques(int l,int r,int num,int ed)
{
	if (l==r) return add[num];
	pushdown(num);
	int m=l+r>>1;int ans=0;
	if (ed<=m) ans+=ques(l,m,num<<1,ed);
	if (ed>m) ans+=ques(m+1,r,num<<1|1,ed);
	return ans;
}
int main()
{
	scanf("%d",&n);
	for (rint i=1;i<=n-1;i++)
	{
		int x,y;scanf("%d%d",&x,&y);
		adedge(x,y);adedge(y,x);
	}
	dfs(1,0);
	for (rint i=1;i<=n;i++)
	{
		int p;scanf("%d",&p);
		printf("%d\n",ques(1,n,1,dfn[p]));
		update(1,n,dfn[p],dfn[p]+size[p]-1,1);
	}
}

TO BE CONTINUED

赞 赏
蒟蒻的邮箱:xmzl200201@126.com