二维线段树
二维线段树目前有两种较为常见的写法:四分树,树套树。这种算法空间要求极大。
注意:请务必保证您已经学会了 线段树,并有一定程度自己的见解。
四分树
将线段树从一维变成二维可以加两个对于 $y$ 轴位置描述条件(就像描述 $x$ 轴的两个描述一样)即可。
原来的树的节点 $x$ 轴有两个儿子:左右儿子。现在由于多了一维,儿子的数会变多,但是仍可以描述(复杂至极)。现在我们节点 $y$ 轴看也有两个儿子,那么一共就是 $4$ 个儿子,这些儿子的称呼长度也增加了,为了区分必须要加一维的描述,才能知道说的是哪一个。
需要注意的是:在找中位数的时候(确定范围时,不断二分)我们要两维都满足是叶子,才是真正的叶子。那么就会出现其中一维是叶子结点,但另一维没有达到叶子结点,这是要判断已经确定叶子结点的边界。
每回二分区间:$[l,mid]$ 和 $[mid+1,r]$,但是当一维已经确定是叶子节点时: $l=r$,所以要判断边界,否则会出现类似:$[4,3]$,$[8,7]$,$…$ 一类的情况(别问我是怎么知道的)。这直接就会导致死循环。
代码
四分树
实现:区间修改,区间求和,区间最大值。
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
long long r_r(){//快读
long long k=0,f=1;
char c=getchar();
while(!isdigit(c)){
if(c=='-')f=-1;
c=getchar();
}
while(isdigit(c)){
k=(k<<1)+(k<<3)+(c^48);
c=getchar();
}
return k*f;
}
const int o_o=4e3+10;
int n,m;
int m_k;
struct po{
long long m_a;//记录区间最大值
long long s_m;//记录区间和
long long l_n;//懒标记
}t_r[o_o][o_o];//树
long long a_a[o_o][o_o];//矩阵原值
int l_l(int k){//左儿子
return k<<1;
}
int r_r(int k){//右儿子
return k<<1|1;
}
void u_p(int x,int y){//节点最大值,区结和更新
t_r[x][y].s_m=t_r[l_l(x)][l_l(y)].s_m+t_r[l_l(x)][r_r(y)].s_m+t_r[r_r(x)][l_l(y)].s_m+t_r[r_r(x)][r_r(y)].s_m;//更新区间和
t_r[x][y].m_a=max(t_r[l_l(x)][l_l(y)].m_a,t_r[l_l(x)][r_r(y)].m_a);//更新最大值
t_r[x][y].m_a=max(t_r[r_r(x)][l_l(y)].m_a,t_r[x][y].m_a);
t_r[x][y].m_a=max(t_r[r_r(x)][r_r(y)].m_a,t_r[x][y].m_a);
}
void b_t(int l_x,int r_x,int l_y,int r_y,int k_x,int k_y){
if(l_x==r_x&&l_y==r_y){//在目标范围内
t_r[k_x][k_y].s_m=a_a[l_x][l_y];//初始化节点值
t_r[k_x][k_y].m_a=a_a[l_x][l_y];//初始化最大值
return ;
}
m_k=max(m_k,k_x);
m_k=max(m_k,k_y);
int m_x=(l_x+r_x)>>1,m_y=(l_y+r_y)>>1;
//注意 if 特判边界,防止死循环
b_t(l_x,m_x,l_y,m_y,l_l(k_x),l_l(k_y));//x 轴左儿子,y 轴左儿子(左上角)
if(m_y<r_y)b_t(l_x,m_x,m_y+1,r_y,l_l(k_x),r_r(k_y));//x 轴左儿子,y 轴右儿子(右上角)(注意判断边界)
if(m_x<r_x)b_t(m_x+1,r_x,l_y,m_y,r_r(k_x),l_l(k_y));//x 轴右儿子,y 轴右儿子(左下角)(注意判断边界)
if(m_x<r_x&&m_y<r_y)b_t(m_x+1,r_x,m_y+1,r_y,r_r(k_x),r_r(k_y));//x 轴右儿子,y 轴右儿子(右下角)(注意判断边界)
u_p(k_x,k_y);
}
void p_d(int l_x,int r_x,int l_y,int r_y,int k_x,int k_y){
if(!t_r[k_x][k_y].l_n)return ;//没有标记,直接跳过
int m_x=(l_x+r_x)>>1,m_y=(l_y+r_y)>>1;
//更新子节点区间和
t_r[l_l(k_x)][l_l(k_y)].s_m+=t_r[k_x][k_y].l_n*(m_x-l_x+1)*(m_y-l_y+1);
t_r[l_l(k_x)][r_r(k_y)].s_m+=t_r[k_x][k_y].l_n*(m_x-l_x+1)*(r_y-m_y);
t_r[r_r(k_x)][l_l(k_y)].s_m+=t_r[k_x][k_y].l_n*(r_x-m_x)*(m_y-l_y+1);
t_r[r_r(k_x)][r_r(k_y)].s_m+=t_r[k_x][k_y].l_n*(r_x-m_x)*(r_y-m_y);
//更新子节点区间最大值
t_r[l_l(k_x)][l_l(k_y)].m_a+=t_r[k_x][k_y].l_n;
t_r[l_l(k_x)][r_r(k_y)].m_a+=t_r[k_x][k_y].l_n;
t_r[r_r(k_x)][l_l(k_y)].m_a+=t_r[k_x][k_y].l_n;
t_r[r_r(k_x)][r_r(k_y)].m_a+=t_r[k_x][k_y].l_n;
//更新子节点区间懒标记
t_r[l_l(k_x)][l_l(k_y)].l_n+=t_r[k_x][k_y].l_n;
t_r[l_l(k_x)][r_r(k_y)].l_n+=t_r[k_x][k_y].l_n;
t_r[r_r(k_x)][l_l(k_y)].l_n+=t_r[k_x][k_y].l_n;
t_r[r_r(k_x)][r_r(k_y)].l_n+=t_r[k_x][k_y].l_n;
t_r[k_x][k_y].l_n=0;
}
void a_d(int l_x,int r_x,int x_l,int x_r,int l_y,int r_y,int y_l,int y_r,int k_x,int k_y,int v){
if(x_l<=l_x&&r_x<=x_r&&y_l<=l_y&&r_y<=y_r){//在目标范围内
t_r[k_x][k_y].s_m+=v*(r_x-l_x+1)*(r_y-l_y+1);//更新区间和
t_r[k_x][k_y].m_a+=v;//更新区间最大值
t_r[k_x][k_y].l_n+=v;//更新懒标记
return ;
}
p_d(l_x,r_x,l_y,r_y,k_x,k_y);//解放当前点懒标记
int m_x=(l_x+r_x)>>1,m_y=(l_y+r_y)>>1;
//分别更新子节点
if(m_x>=x_l){
if(m_y>=y_l)a_d(l_x,m_x,x_l,x_r,l_y,m_y,y_l,y_r,l_l(k_x),l_l(k_y),v);
if(m_y<y_r)a_d(l_x,m_x,x_l,x_r,m_y+1,r_y,y_l,y_r,l_l(k_x),r_r(k_y),v);
}
if(m_x<x_r){
if(m_y>=y_l)a_d(m_x+1,r_x,x_l,x_r,l_y,m_y,y_l,y_r,r_r(k_x),l_l(k_y),v);
if(m_y<y_r)a_d(m_x+1,r_x,x_l,x_r,m_y+1,r_y,y_l,y_r,r_r(k_x),r_r(k_y),v);
}
u_p(k_x,k_y);//更新节点最大值和区间和
}
long long q_s(int l_x,int r_x,int x_l,int x_r,int l_y,int r_y,int y_l,int y_r,int k_x,int k_y){
long long a_s=0;//初始化
p_d(l_x,r_x,l_y,r_y,k_x,k_y);//释放懒标记
if(x_l<=l_x&&r_x<=x_r&&y_l<=l_y&&r_y<=y_r)return t_r[k_x][k_y].s_m;//返回区间统计的值
int m_x=(l_x+r_x)>>1,m_y=(l_y+r_y)>>1;
//分别统计子节点
if(m_x>=x_l){
if(m_y>=y_l)a_s+=q_s(l_x,m_x,x_l,x_r,l_y,m_y,y_l,y_r,l_l(k_x),l_l(k_y));
if(m_y<y_r&&m_y<r_y)a_s+=q_s(l_x,m_x,x_l,x_r,m_y+1,r_y,y_l,y_r,l_l(k_x),r_r(k_y));
}
if(m_x<x_r&&m_x<r_x){
if(m_y>=y_l)a_s+=q_s(m_x+1,r_x,x_l,x_r,l_y,m_y,y_l,y_r,r_r(k_x),l_l(k_y));
if(m_y<y_r&&m_y<r_y)a_s+=q_s(m_x+1,r_x,x_l,x_r,m_y+1,r_y,y_l,y_r,r_r(k_x),r_r(k_y));
}
return a_s;
}
int q_m(int l_x,int r_x,int x_l,int x_r,int l_y,int r_y,int y_l,int y_r,int k_x,int k_y){
int a_s=-0x3f3f3f3f;//最小化最大值
p_d(l_x,r_x,l_y,r_y,k_x,k_y);//释放懒标记
if(x_l<=l_x&&r_x<=x_r&&y_l<=l_y&&r_y<=y_r)return t_r[k_x][k_y].m_a;//返回区间统计的值
int m_x=(l_x+r_x)>>1,m_y=(l_y+r_y)>>1;
//分别统计子节点
if(m_x>=x_l){
if(m_y>=y_l)a_s=max(a_s,q_m(l_x,m_x,x_l,x_r,l_y,m_y,y_l,y_r,l_l(k_x),l_l(k_y)));
if(m_y<y_r&&m_y<r_y)a_s=max(a_s,q_m(l_x,m_x,x_l,x_r,m_y+1,r_y,y_l,y_r,l_l(k_x),r_r(k_y)));
}
if(m_x<x_r&&m_x<r_x){
if(m_y>=y_l)a_s=max(a_s,q_m(m_x+1,r_x,x_l,x_r,l_y,m_y,y_l,y_r,r_r(k_x),l_l(k_y)));
if(m_y<y_r&&m_y<r_y)a_s=max(a_s,q_m(m_x+1,r_x,x_l,x_r,m_y+1,r_y,y_l,y_r,r_r(k_x),r_r(k_y)));
}
return a_s;
}
int main(){
freopen("rd.txt","r",stdin);//随机数据生成文件
freopen("x_d.txt","w",stdout);//线段树方法结果
n=r_r(),m=r_r();
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)a_a[i][j]=r_r();//读入原始矩阵
b_t(1,n,1,m,1,1);//建树
int q=r_r();
for(int i=1;i<=q;i++){
int op=r_r();
if(op==1){//区间加
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r(),v=r_r();
a_d(1,n,x_1,x_2,1,m,y_1,y_2,1,1,v);
}else if(op==2){//区间求和
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r();
printf("%lld\n",q_s(1,n,x_1,x_2,1,m,y_1,y_2,1,1));
}else {//区间最大值
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r();
printf("%lld\n",q_m(1,n,x_1,x_2,1,m,y_1,y_2,1,1));
}
}
return 0;
}
/*
3 3
1 2 3
4 5 6
7 8 9
7
2 1 1 2 3
3 2 2 3 3
1 1 1 2 2 5
3 1 3 2 3
1 2 2 3 3 -1
2 1 2 2 3
3 2 2 2 3
*/
但是它的复杂度并不优秀。
暴力
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
long long r_r(){//快读
long long k=0,f=1;
char c=getchar();
while(!isdigit(c)){
if(c=='-')f=-1;
c=getchar();
}
while(isdigit(c)){
k=(k<<1)+(k<<3)+(c^48);
c=getchar();
}
return k*f;
}
const int o_o=2e3+10;
int n,m;
long long a_a[o_o][o_o];
int main(){
freopen("rd.txt","r",stdin);//随机数据生成文件
freopen("b_l.txt","w",stdout);//暴力方法结果
n=r_r(),m=r_r();
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)a_a[i][j]=r_r();//读入原始矩阵
int q=r_r();
for(int i=1;i<=q;i++){
int op=r_r();
if(op==1){//区间加
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r(),v=r_r();
for(int i=x_1;i<=x_2;i++)
for(int j=y_1;j<=y_2;j++)a_a[i][j]+=v;
}else if(op==2){//区间求和
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r();
long long a_s=0;
for(int i=x_1;i<=x_2;i++)
for(int j=y_1;j<=y_2;j++)a_s+=a_a[i][j];
printf("%lld\n",a_s);
}else {//区间最大值
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r();
long long m_a=-0x3f3f3f3f;
for(int i=x_1;i<=x_2;i++)
for(int j=y_1;j<=y_2;j++)m_a=max(m_a,a_a[i][j]);
printf("%lld\n",m_a);
}
}
return 0;
}
经过测试发现,小数据和暴力时间差不多,大数据空间直接爆炸……但是毕竟还是快一点。
生成数据
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
long long r_r(){//快读
long long k=0,f=1;
char c=getchar();
while(!isdigit(c)){
if(c=='-')f=-1;
c=getchar();
}
while(isdigit(c)){
k=(k<<1)+(k<<3)+(c^48);
c=getchar();
}
return k*f;
}
const int o_o=1e3;
mt19937 m_t(time(0));
int main(){
freopen("rd.txt","w",stdout);//生成随机数文件
int n=m_t()%o_o+1,m=m_t()%o_o+1;
cout<<n<<" "<<m<<endl;//矩形大小
for(int i=1;i<=n;i++){//初始矩阵
for(int j=1;j<=m;j++){
cout<<int(m_t()%o_o)-o_o/2<<" ";
}
cout<<endl;
}
int q=m_t()%o_o+1;//询问次数
cout<<q<<endl;
while(q--){
int op=m_t()%3+1,x_1=m_t()%n+1,x_2=m_t()%n+1,y_1=m_t()%m+1,y_2=m_t()%m+1;
if(x_2<x_1)swap(x_1,x_2);
if(y_2<y_1)swap(y_1,y_2);
cout<<op<<" "<<x_1<<" "<<y_1<<" "<<x_2<<" "<<y_2<<" ";//询问区间
if(op==1)cout<<int(m_t()%100)-50;//修改的值
cout<<endl;
}
return 0;
}
对拍:
#include<iostream>
#include<cstdio>
#include<windows.h>//用来调用 system 的头文件
using namespace std;
int main(){
/*
while(1){
system("rd.exe > rd.txt");//将随机数据放入 rd.txt 中
system("x_d.exe < rd.txt > x_d.txt");//rd.txt 中的数据放到 x_d.exe 中运行,结果放入 x_d.txt
system("b_l.exe < rd.txt > b_l.txt");//rd.txt 中的数据放到 b_l.exe 中运行,结果放入 b_l.txt
if(system("fc x_d.txt b_l.txt"))break;//比对输出
}
*/
while(1){
//分别运行代码对应可执行文件
system("rd.exe");
system("x_d.exe");
system("b_l.exe");
if(system("fc x_d.txt b_l.txt"))break;//比对输出
}
return 0;
}
树套树
个人认为这才是真正二维线段树写法。
先建一棵普通线段树,然后将它的每个节点当成一个新的线段树的根(没错,如果再换一层就是三维线段树)。这是将矩形分为 $x$ 维和 $y$ 维分开考虑。这回先将其中一维确定,在另一位找真正位置,然后更新,而不是同时考虑两维的叶节点。
维护单点修改,区间最大值,最小值。
代码
注意:
建树时,是每个节点都是一棵线段树,所以所有节点都要遍历第二维。
建树时,第一维的左右节点要保留参数到第二维,判断更新的是叶子节点(赋值),还是“树”节点(更新最大值最小值)。
单点更改时,节点最大值最小值更新,要赋特殊值,重新遍历“更新”第二维,否则会更新不全,思考为什么?
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
long long r_r(){//快读
long long k=0,f=1;
char c=getchar();
while(!isdigit(c)){
if(c=='-')f=-1;
c=getchar();
}
while(isdigit(c)){
k=(k<<1)+(k<<3)+(c^48);
c=getchar();
}
return k*f;
}
const int o_o=3e3+10;
int n,m,a_a[o_o][o_o],m_a,m_i;
struct po{
int m_a;//最大值
int m_i;//最小值
}t_r[o_o][o_o];
int l_l(int k){//左子树
return k<<1;
}
int r_r(int k){//右子树
return k<<1|1;
}
void u_y(int x,int y){//更新第二位
t_r[x][y].m_a=max(t_r[x][l_l(y)].m_a,t_r[x][r_r(y)].m_a);
t_r[x][y].m_i=min(t_r[x][l_l(y)].m_i,t_r[x][r_r(y)].m_i);
}
void u_x(int x,int y){//更新第二位
t_r[x][y].m_a=max(t_r[l_l(x)][y].m_a,t_r[r_r(x)][y].m_a);
t_r[x][y].m_i=min(t_r[l_l(x)][y].m_i,t_r[r_r(x)][y].m_i);
}
void b_2(int k_x,int x,int y,int k_y,int l,int r){
if(l!=r){//非叶子节点
int m_i=(l+r)>>1;
b_2(k_x,x,y,l_l(k_y),l,m_i);
b_2(k_x,x,y,r_r(k_y),m_i+1,r);
u_y(k_x,k_y);//更新第二维的最大最小值
}else {
if(x==y)t_r[k_x][k_y].m_a=t_r[k_x][k_y].m_i=a_a[x][l];//矩形小单位
else u_x(k_x,k_y);//更新第一维最大最小值
}
}
void b_1(int k_x,int l,int r){
if(l!=r){//非叶子节点
int m_i=(l+r)>>1;
b_1(l_l(k_x),l,m_i);
b_1(r_r(k_x),m_i+1,r);
}
//处理第二维
b_2(k_x,l,r,1,1,n);
}
void q_2(int k_y,int l,int r,int x,int y,int k_x){
if(l>=x&&r<=y){//在范围内,更新结果
m_a=max(m_a,t_r[k_x][k_y].m_a);
m_i=min(m_i,t_r[k_x][k_y].m_i);
}else {//找范围
int m_i=(l+r)>>1;
if(x<=m_i)q_2(l_l(k_y),l,m_i,x,y,k_x);
if(y>m_i)q_2(r_r(k_y),m_i+1,r,x,y,k_x);
}
}
void q_1(int k_x,int l,int r,int x,int y,int x_2,int y_2){//第一维查找
if(l>=x&&r<=y)q_2(1,1,n,x_2,y_2,k_x);//在范围内,找第二维
else{//找范围
int m_i=(l+r)>>1;
if(x<=m_i)q_1(l_l(k_x),l,m_i,x,y,x_2,y_2);
if(y>m_i)q_1(r_r(k_x),m_i+1,r,x,y,x_2,y_2);
}
}
void u_2(int k_y,int l,int r,int q,int v,int k_x){
if(l==r){//找到节点
if(v==-1)u_x(k_x,k_y);//更新第一维
else t_r[k_x][k_y].m_a=t_r[k_x][k_y].m_i=v;//更新节点
return;
}
//找范围
int m_i=(l+r)>>1;
if(m_i>=q)u_2(l_l(k_y),l,m_i,q,v,k_x);
else u_2(r_r(k_y),m_i+1,r,q,v,k_x);
u_y(k_x,k_y);
}
void u_1(int k_x,int l,int r,int x,int y,int v){
if(l==r)u_2(1,1,n,y,v,k_x);//找到要改的点的第一维
else{//找范围
int m_i=(l+r)>>1;
if(x<=m_i)u_1(l_l(k_x),l,m_i,x,y,v);
else u_1(r_r(k_x),m_i+1,r,x,y,v);
u_2(1,1,n,y,-1,k_x);//更新节点
}
}
int main(){
n=r_r();
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)a_a[i][j]=r_r();//读入初始矩阵
b_1(1,1,n);//建树
m=r_r();
for(int i=1;i<=m;++i){
char c=getchar();
while(c!='c'&&c!='q')c=getchar();
if(c=='q'){//区间最大,最小值
int x_1=r_r(),y_1=r_r(),x_2=r_r(),y_2=r_r();
m_i=0x3f3f3f3f;
m_a=0;
q_1(1,1,n,x_1,x_2,y_1,y_2);
printf("%d %d\n",m_a,m_i);
}else{//单点更新
int x=r_r(),y=r_r(),v=r_r();
u_1(1,1,n,x,y,v);
}
}
return 0;
}
虽然这道题暴力能过,但是建议用来练手。题解区有大佬用的四分树,但是时间要比树套树慢的多。
个人推荐第二种写法。