树套树


树套树

在学习 二维线段树 时,介绍了线段树套线段树。

同时确保您已经学会 平衡树

P3380 【模板】二逼平衡树(树套树)

  • 查询 $k$ 在区间内的排名

  • 查询区间内排名为 $k$ 的值

  • 修改某一位值上的数值

  • 查询 $k$ 在区间内的前驱(前驱定义为严格小于 $x$,且最大的数,若不存在输出 $-2147483647$)

  • 查询 $k$ 在区间内的后继(后继定义为严格大于 $x$,且最小的数,若不存在输出 $2147483647$)

题目让我们维护的 $5$ 个操作中,但如果没有要求区间,那么就是平衡树模板题,但是要求了区间后,我们就要找区间之间的关系,可以用线段树维护区间之间直接的关系。和二维线段树相同的,先开一个线段树,然后每个节点维护一个平衡树。然后就是基础的模板操作了(只是难写了亿点点)。

代码

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
int r_r(){//快读 
	int k=0,f=1;
	char c=getchar();
	while(!isdigit(c)){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(isdigit(c)){
		k=(k<<1)+(k<<3)+(c^48);
		c=getchar();
	}
	return k*f;
}
const int o_o=5e4+10;
const int m_a=2147483647;//一定设到极限大(0x3f3f3f3f 会被卡) 
struct sp{
	int s_z;//树的大小 
	int n;//值相同节点数量 
	int v_l;//节点价值 
	int s_n[2];//左右儿子 
	int f_a;//父节点 
}t_s[o_o*40];//平衡树 
struct tr{
	int l;//左儿子 
	int r;//右儿子 
	int g_g;//平衡树根节点 
}t_r[o_o<<2];//线段树 
int a_a[o_o];//原序列 
int n,m,x_p;
int l_l(int k){//左子树 
	return k<<1;
}
int r_r(int k){//右子树 
	return k<<1|1;
}
void u_p(int x){//更新子树大小 
	t_s[x].s_z=t_s[t_s[x].s_n[0]].s_z+t_s[t_s[x].s_n[1]].s_z+t_s[x].n;
} 
void t_n(int x){//旋转平衡树 
	int f=t_s[x].f_a;
	int f_f=t_s[f].f_a;
	int k_k=(t_s[f].s_n[1]==x);//判断节点在父节点的方位 
	
	//旋转操作 
	t_s[f_f].s_n[(t_s[f_f].s_n[1]==f)]=x;
	t_s[x].f_a=f_f;
	t_s[t_s[x].s_n[k_k^1]].f_a=f;
	t_s[f].s_n[k_k]=t_s[x].s_n[k_k^1];
	t_s[x].s_n[k_k^1]=f;
	t_s[f].f_a=x;
	
	//更新新子节点 
	u_p(f);
	//更新新父节点 
	u_p(x);
}
void s_p(int x,int g_g,int k){
	while(t_s[x].f_a!=g_g){//没有旋转到根 
		int f=t_s[x].f_a;//父节点 
		int f_f=t_s[f].f_a;//爷节点 
		if(f_f!=g_g)(t_s[f_f].s_n[1]==f)^(t_s[f].s_n[1]==x)?t_n(f):t_n(x);
		//同方位儿子转节点父亲,否则转接节点 
		
		t_n(x);//旋转节点 
	}
	if(g_g==0)t_r[k].g_g=x;//更新根节点 
}
int yv(int v,int f_a){//初始化节点 
	t_s[++x_p].v_l=v;//初值 
	t_s[x_p].f_a=f_a;//祖先 
	t_s[x_p].n=1;//初始化相同节点数量 
	u_p(x_p);//更新子树大小 
	return x_p;
}
void a_d(int v,int k){
	int n_n=t_r[k].g_g;//记录根节点 
	int f_a=0;//初始化父节点 
	if(!n_n){//初始化平衡树的根 
		n_n=yv(v,0);//初始化节点 
		t_r[k].g_g=n_n;//更新根 
		return ;
	}
	while(n_n&&(t_s[n_n].v_l!=v)){//找到与当前值相等的点 
		f_a=n_n;
		if(t_s[n_n].v_l<v)n_n=t_s[n_n].s_n[1];
		else n_n=t_s[n_n].s_n[0];
	}
	if(v==t_s[n_n].v_l&&n_n)t_s[n_n].n++;//找到节点相同的点 
	else{
		n_n=yv(v,f_a);//初始化新节点 
		if(f_a){//判断左右儿子 
			if(t_s[f_a].v_l<v)t_s[f_a].s_n[1]=n_n;
			else t_s[f_a].s_n[0]=n_n;
		}
	}
	s_p(n_n,0,k);//旋转平衡树
}
void f_i(int v,int k){
	int n_n=t_r[k].g_g;//根节点 
	if(!n_n)return ;
	
	//找排名 
	while(t_s[n_n].s_n[t_s[n_n].v_l<v]&&t_s[n_n].v_l!=v){//子树有节点,并且找前驱所以值不能相等 
		if(t_s[n_n].v_l<v)n_n=t_s[n_n].s_n[1];
		else n_n=t_s[n_n].s_n[0];
	}
	s_p(n_n,0,k);//旋转平衡树 
}
int n_t(int x,int b_b,int k){//b_b 0 后继,b_b 1 前驱 
	f_i(x,k);//找 x 的排名 
	int n_n=t_r[k].g_g;//记录根节点 
	if((b_b&&t_s[n_n].v_l<x)||(!b_b&&t_s[n_n].v_l>x))return n_n;//达到边界,找到目标 
	n_n=t_s[n_n].s_n[b_b^1];//跳过边界,往回找 
	while(t_s[n_n].s_n[b_b])n_n=t_s[n_n].s_n[b_b];//逼近目标 
	return n_n;
}
void d_l(int x,int k){
	int n_n=t_r[k].g_g;//记录根节点 
	int q_q=n_t(x,1,k);//前驱 
	int h_j=n_t(x,0,k);//后继 
	s_p(h_j,0,k);//将后继变为根节点 
	s_p(q_q,h_j,k);//将前驱变为根节点子节点 
	int k_k=t_s[q_q].s_n[1];//目标节点 
	if(t_s[k_k].n>1)--t_s[k_k].n,s_p(k_k,0,k);//有多个相同值,减去一个并将目标节点转到根节点 
	else t_s[q_q].s_n[1]=0;//清空节点 
	u_p(q_q);//更新节点信息 
}
void b_t(int k,int l,int r){
	//控制边界 
	a_d(m_a,k);//加最大节点 
	a_d(-m_a,k);//加最小节点 
	
	if(l==r)return ;//叶子节点 
	int m_i=(l+r)>>1;
	b_t(l_l(k),l,m_i);//左子树 
	b_t(r_r(k),m_i+1,r);//右子树 
}
void s_d(int k,int l,int r,int i,int v_l){
	int m_i=(l+r)>>1;
	a_d(v_l,k);//节点平衡树加点 
	if(l==r)return ;//到叶子节点返回 
	if(m_i>=i)s_d(l_l(k),l,m_i,i,v_l);//左子树 
	else s_d(r_r(k),m_i+1,r,i,v_l);//右子树 
} 
int s_p(int k,int l,int r,int i,int x,int y){
	if(l>y||r<x)return 0;//超过边界 
	if(l>=x&&r<=y){//在范围内 
		f_i(i,k);//找排名 
		int n_n=t_r[k].g_g;
		
		//根据子树大小输出排名 
		if(t_s[n_n].v_l>=i)return t_s[t_s[n_n].s_n[0]].s_z-1;
		else return t_s[t_s[n_n].s_n[0]].s_z+t_s[n_n].n-1;
	}
	int m_i=(l+r)>>1;
	return s_p(l_l(k),l,m_i,i,x,y)+s_p(r_r(k),m_i+1,r,i,x,y);//统计左右子树排名 
}
void s_g(int k,int l,int r,int k_l,int v_l){
	d_l(a_a[k_l],k);//删旧点 
	a_d(v_l,k);//补新点 
	if(l==r&&l==k_l){//达到目标节点 
		a_a[k_l]=v_l;//更新 
		return ;
	}
	int m_i=(l+r)>>1;
	if(m_i>=k_l)s_g(l_l(k),l,m_i,k_l,v_l);//左子树 
	else s_g(r_r(k),m_i+1,r,k_l,v_l);//右子树 
}
int s_q(int k,int l,int r,int x,int y,int i){
	if(l>y||r<x)return -m_a;//不在范围内 
	if(l>=x&&r<=y)return t_s[n_t(i,1,k)].v_l;//在范围内 
	int m_i=(l+r)>>1;
	return max(s_q(l_l(k),l,m_i,x,y,i),s_q(r_r(k),m_i+1,r,x,y,i));//返回最大值(越大越逼近) 
}
int s_h(int k,int l,int r,int x,int y,int i){
	if(l>y||r<x)return m_a;//不在范围内 
	if(l>=x&&r<=y)return t_s[n_t(i,0,k)].v_l;//在范围内 
	int m_i=(l+r)>>1;
	return min(s_h(l_l(k),l,m_i,x,y,i),s_h(r_r(k),m_i+1,r,x,y,i));//返回最小值(越小越逼近) 
}
int s_k(int x,int y,int i){
	int l=0,r=1e8,m_i,a_s;
	while(l<=r){//二分找值 
		m_i=(l+r)>>1;
		int b_b=s_p(1,1,n,m_i,x,y)+1;//查数的排名 
		if(b_b>i)r=m_i-1;//超过目标排名 
		else l=m_i+1,a_s=m_i;//记录目前情况 
	}
	return a_s;
}
int main(){
	n=r_r(),m=r_r();
	b_t(1,1,n);//建树 
	for(int i=1;i<=n;++i)a_a[i]=r_r(),s_d(1,1,n,i,a_a[i]);//加点 
	for(int i=1;i<=m;++i){
		int op=r_r(),l=r_r(),r=r_r(),k;
		if(op==1)k=r_r(),printf("%d\n",s_p(1,1,n,k,l,r)+1);//查排名 
		if(op==2)k=r_r(),printf("%d\n",s_k(l,r,k));//查值 
		if(op==3)s_g(1,1,n,l,r);//修改 
		if(op==4)k=r_r(),printf("%d\n",s_q(1,1,n,l,r,k));//前驱 
		if(op==5)k=r_r(),printf("%d\n",s_h(1,1,n,l,r,k));//后继 
	}
	return 0;
}

注意最导致初始化的时候要保证值足够大。

会发现不开 $O_2$ 只能过 $3,4$ 个点,开了 $O_2$ 仍有一个点会被卡掉。那就只能玄学优化了。

我们将建树的过程更改:

void b_t(int k,int l,int r){
	//控制边界 
	a_d(m_a,k);//加最大节点 
	a_d(-m_a,k);//加最小节点 
	
    for(int i=l;i<=r;i++)a_d(a_a[i],k);//读入范围内节点 
	if(l==r)return ;//叶子节点 
	int m_i=(l+r)>>1;
	b_t(k<<1,l,m_i);//左子树 
	b_t(k<<1|1,m_i+1,r);//右子树 
}

直接将每个线段树节点的值范围办函的点全部读入。

再加两个宏:

#define l_s(x)t_s[x].s_n[0]
#define r_s(x)t_s[x].s_n[1]

来访问左右儿子。

注意多次调用函数,会使效率降低,所以删去

int l_l(int k){//左子树 
	return k<<1;
}
int r_r(int k){//右子树 
	return k<<1|1;
}

直接计算。

还有判断“优化”:

if(t_s[n_n].v_l<v)n_n=t_s[n_n].s_n[1];
else n_n=t_s[n_n].s_n[0];

可以写成:

n_n=t_s[n_n].s_n[t_s[n_n].v_l<v];

最后函数前加上 inline

记得开 $O_2$

AC 代码

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
#define il inline
int r_r(){//快读 
	long long k=0,f=1;
	char c=getchar();
	while(!isdigit(c)){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(isdigit(c)){
		k=(k<<1)+(k<<3)+(c^48);
		c=getchar();
	}
	return k*f;
}
const int N=5e4+5;
#define l_s(x)t_s[x].s_n[0]
#define r_s(x)t_s[x].s_n[1]
const int m_a=2147483647;
struct ts{
    int s_z,n,v_l,s_n[2],f_a;
}t_s[N*50];
struct t_r{
    int l,r,g_g;
}t_r[N*4];
int a_a[N],n,m,x_p;
il void u_p(int x){
    t_s[x].s_z=t_s[l_s(x)].s_z+t_s[r_s(x)].s_z+t_s[x].n;
} 
il void t_n(int x){
    int f=t_s[x].f_a;
	int f_f=t_s[f].f_a;
	int b_b=(t_s[f].s_n[1]==x);
    t_s[f_f].s_n[(t_s[f_f].s_n[1]==f)]=x;
	t_s[x].f_a=f_f;
    t_s[t_s[x].s_n[b_b^1]].f_a=f;
	t_s[f].s_n[b_b]=t_s[x].s_n[b_b^1];
    t_s[x].s_n[b_b^1]=f;
	t_s[f].f_a=x;
    u_p(f);
	u_p(x);
}
il void s_p(int x,int g_g,int k){
    while(t_s[x].f_a!=g_g){
        int f=t_s[x].f_a,f_f=t_s[f].f_a;
        if(f_f!=g_g)(t_s[f_f].s_n[1]==f)^(t_s[f].s_n[1]==x)?t_n(x):t_n(f);
        t_n(x);
    }
    if(g_g==0)t_r[k].g_g=x;
}
il int yv(int v,int f_a){
    t_s[++x_p].v_l=v;
	t_s[x_p].f_a=f_a;
	t_s[x_p].n=1;
    u_p(x_p);
	return x_p;
}
il void a_d(int x,int k){
    int n_n=t_r[k].g_g;
	int f_a=0;
    if(!n_n){
		n_n=yv(x,0);
		t_r[k].g_g=n_n;
		return ;
	}
    while(n_n&&(t_s[n_n].v_l != x))f_a=n_n,n_n=t_s[n_n].s_n[t_s[n_n].v_l<x];
    if(x==t_s[n_n].v_l&&n_n)t_s[n_n].n ++;
    else {
        n_n=yv(x,f_a);
        if(f_a)t_s[f_a].s_n[t_s[f_a].v_l<x]=n_n;
    }
    s_p(n_n,0,k );
}
il void f_i(int x,int k){
    int n_n=t_r[k].g_g;
	if(!n_n)return ;
    while(t_s[n_n].s_n[t_s[n_n].v_l<x]&&t_s[n_n].v_l!=x)
        n_n=t_s[n_n].s_n[t_s[n_n].v_l<x];
    s_p(n_n,0,k);
}
il int n_t(int x,int b_b,int k){
    f_i(x,k);
	int n_n=t_r[k].g_g;
    if((b_b&&t_s[n_n].v_l<x)||(!b_b&&t_s[n_n].v_l>x))return n_n;
    n_n=t_s[n_n].s_n[b_b^1];
    while(t_s[n_n].s_n[b_b])n_n=t_s[n_n].s_n[b_b];
    return n_n;
}
il void d_l(int x,int k){
    int n_n=t_r[k].g_g;
	int q_q=n_t(x,1,k);
	int h_j=n_t(x,0,k);
    s_p(h_j,0,k);
	s_p(q_q,h_j,k);
    int k_l=t_s[q_q].s_n[1];
    if(t_s[k_l].n>1){
    	--t_s[k_l].n;
		s_p(k_l,0,k);
	}else t_s[q_q].s_n[1]=0;
    u_p(q_q);
}
void b_t(int k,int l,int r){
    a_d(m_a,k),a_d(-m_a,k);
    for(int i=l;i<=r;i++)a_d(a_a[i],k);
    if(l==r)return;
    int m_i=(l+r)>>1;
	b_t(k*2,l,m_i);
	b_t(k*2+1,m_i+1,r);
}
il int s_p(int k,int l,int r,int i,int x,int y){
    if(l>y||r<x)return 0;
    if(l>=x&&r<=y){
        f_i(i,k);
		int n_n=t_r[k].g_g;
        if(t_s[n_n].v_l>=i)return t_s[l_s(n_n)].s_z-1;
        else return t_s[l_s(n_n)].s_z+t_s[n_n].n-1;
    }
    int m_i=(l+r)>>1;
    return s_p(k*2,l,m_i,i,x,y)+s_p(k*2+1,m_i+1,r,i,x,y);
}
il void s_g(int k,int l,int r,int k_l,int v_l){
    d_l(a_a[k_l],k);
	a_d(v_l,k);
    if(l==r&&l==k_l){
		a_a[k_l]=v_l;
		return ;
	}
    int m_i=(l+r)>>1;
    if(m_i>=k_l)s_g(k*2,l,m_i,k_l,v_l);
    else s_g(k*2+1,m_i+1,r,k_l,v_l);
}
il int s_q(int k,int l,int r,int x,int y,int i){ 
    if(l>y||r<x)return -m_a;
    if(l>=x&&r<=y)return t_s[n_t(i,1,k)].v_l;
    int m_i=(l+r)>>1;
    return max(s_q(k*2,l,m_i,x,y,i),s_q(k*2+1,m_i+1,r,x,y,i));
}
il int s_h(int k,int l,int r,int x,int y,int i){
    if(l>y||r<x)return m_a;
    if(l>=x&&r<=y)return t_s[n_t(i,0,k)].v_l;
    int m_i=(l+r)>>1;
    return min(s_h(k*2,l,m_i,x,y,i),s_h(k*2+1,m_i+1,r,x,y,i));
}
il int s_k(int x,int y,int i){
    int l=0,r=1e8,m_i,b_b,a_s;
    while(l<=r){
        m_i=(l+r)>>1;
        b_b=s_p(1,1,n,m_i,x,y)+1;
        if(b_b>i)r=m_i-1;
        else l=m_i+1,a_s=m_i;
    }
    return a_s;
}
int main(){
    n=r_r(),m=r_r();
    int op,l,r,k;
    for(int i=1;i<=n;++i)a_a[i]=r_r();
    b_t(1,1,n);
    for(int i=1;i<=m;++i){
        op=r_r(),l=r_r(),r=r_r();
        if(op==1)k=r_r(),printf("%d\n",s_p(1,1,n,k,l,r)+1);
        if(op==2)k=r_r(),printf("%d\n",s_k(l,r,k));
        if(op==3)s_g(1,1,n,l,r);
        if(op==4)k=r_r(),printf("%d\n",s_q(1,1,n,l,r,k));
        if(op==5)k=r_r(),printf("%d\n",s_h(1,1,n,l,r,k));
    }
    return 0;
}

本题可以用分块的方法写,而且跑的非常快。虽然数据点卡常,但是为了练手,建议写一写,还可以树状数组套平衡树,这里不再赘述,感兴趣可以自己试一试。


文章作者: 王大神——A001
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 王大神——A001 !
  目录