树链剖分介绍
常见题目,已知一棵包含 $N$ 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
- $1$ $x$ $y$ $z$,表示将树从 $x$ 到 $y$ 结点最短路径上所有节点的值都加上 $z$。
- $2$ $x$ $y$,表示求树从 $x$ 到 $y$ 结点最短路径上所有节点的值之和。
- $3$ $x$ $z$,表示将以 $x$ 为根节点的子树内所有节点值都加上 $z$。
- $4$ $x$ 表示求以 $x$ 为根节点的子树内所有节点值之和
操作分析
操作 $1$,$2$非常像线段树的区间和,区间查询操作。而操作 $3$,$4$用线段树也可以实现,但是肯定要给他们一种特殊的存储和特殊的查询。
特殊的存储——剖分
首先介绍几个名词:重儿子,重边,轻儿子,轻链,轻边,重链。
- 重儿子:每个节点所有子节点中,儿子最多的节点。
- 重边:节点和重儿子连的边。
- 轻儿子:除了重儿子,其他子节点。
- 轻边:节点和轻儿子连的边。
- 重链:连续的重边连成的链。
- 轻链:连续的轻边连成的链。
存储节点状态
关于节点,我们要存:父节点,节点大小(重量),重儿子,深度,链头,新编号。
- 链头:重链中,最靠近根节点的节点。
- 新编号:我们要用线段树维护,所以要根据重链重新个所有节点附一个新的编号,将重链上的节点都变成连续的编号,可以方便区间操作。
我们可以跑两遍深搜,来处理所有存储信息。
void d_1(int k,int f,int d_p){
q_q[k].d_p=d_p;//节点深度
q_q[k].f=f;//存父节点
q_q[k].s_z=1;//节点大小(重量)初值
int m_a=-1;//找重儿子
for(int i=h_d[k];i;i=p_p[i].n_t){
int v=p_p[i].v;
if(v==f)continue;
d_1(v,k,d_p+1);//遍历儿子
q_q[k].s_z+=q_q[v].s_z;//更新当前节点重量
if(m_a<q_q[v].s_z)m_a=q_q[v].s_z,q_q[k].b_s=v;//更新重儿子
}
}
- 节点深度:方便后面找操作 $1$,$2$ 最近公共祖先。
要先找到所有的重儿子,才能找到重链。
void d_2(int k,int t_p){
q_q[k].t_p=t_p;//链头
q_q[k].i_d=++x_n;//赋新节点编号
n_w[x_n]=a_a[k];//用新节点存储当前节点的价值
if(!q_q[k].b_s)return ;//没有种儿子,说明没有儿子
d_2(q_q[k].b_s,t_p);//先遍历重儿子,形成重链
for(int i=h_d[k];i;i=p_p[i].n_t){
int v=p_p[i].v;
if(v==q_q[k].f||v==q_q[k].b_s)continue;//重儿子已经遍历过
d_2(v,v);//轻儿子是新重链的链头
}
}
建树
点的基础信息存完后就可以开始根据新节点编号建树。
void b_t(int l,int r,int k){//建树
if(l==r){
t_t[k]=n_w[l];//用新编号存储节点
if(t_t[k]>p)t_t[k]=n_w[l]%o_o;
return ;
}
int m_i=(l+r)>>1;
b_t(l,m_i,l_l(k));
b_t(m_i+1,r,r_r(k));
u_p(k);//更新节点
}
经典的线段树建树操作。
处理操作
操作中的的线段树经典操作太长就不放了。
操作一
两点的路径上的加处理,我们可以根据提前存的链来处理。因为每个重链的节点编号是连续的,所以可以不断遍历这些重链,再用线段数的区间加操作就行。
void a_p(int l,int r,int v){//区间加
v%=p;
while(q_q[l].t_p!=q_q[r].t_p){//链头不相等(不在同一个链上)
if(q_q[q_q[l].t_p].d_p<q_q[q_q[r].t_p].d_p)swap(l,r);//深度更深的链先处理区间加
u_d(1,1,n,q_q[q_q[l].t_p].i_d,q_q[l].i_d,v);//处理范围:链头编号到当前节点编号
l=q_q[q_q[l].t_p].f;//更新范围,变成链头的父亲,进入新链
}
if(q_q[l].d_p>q_q[r].d_p)swap(l,r);
u_d(1,1,n,q_q[l].i_d,q_q[r].i_d,v);//同一个链上直接处理编号就行
}
操作二
两点上的历经求和处理,和加处理相似,通过重链遍历,直接套线段树的区间求和就行。
int s_p(int l,int r){//区间和
int a_s=0;//结果初始化
while(q_q[l].t_p!=q_q[r].t_p){//链头不相等(不在同一个链上)
if(q_q[q_q[l].t_p].d_p<q_q[q_q[r].t_p].d_p)swap(l,r);//深度更深的链先处理区间和
r_s=0;//结果赋初值
f_i(1,1,n,q_q[q_q[l].t_p].i_d,q_q[l].i_d);//处理范围:链头编号到当前节点编号
a_s=(a_s+r_s)%p;//累加重链结果
l=q_q[q_q[l].t_p].f;//更新范围,变成链头的父亲,进入新链
}
if(q_q[l].i_d>q_q[r].i_d)swap(l,r);
r_s=0;//赋初值
f_i(1,1,n,q_q[l].i_d,q_q[r].i_d);//同一个链上直接处理编号
a_s=(a_s+r_s)%p;//更新
return a_s;
}
操作三
子树全体增值,先放代码再解释。
void a_z(int k,int v){
//由于按重儿子遍历,且先遍历重链,所以编号是先跟着重链遍历到树底,在回溯的顺序赋予轻节点编号。
//所以每个节点的子树的编号是连续的,可以用它的编号到它的编号加它的大小减一的区间进行操作。
u_d(1,1,n,q_q[k].i_d,q_q[k].i_d+q_q[k].s_z-1,v);
}
举个例子:
1
/ \
/ \
/ \
2 3
/ \
/ \
4 5
/ / \
6 7 8
我们给它剖一下。
1
新编号: 1
/ \
/ \
/ \
2 3
新编号: 8 2
/ \
/ \
4 5
新编号: 6 3
/ / \
6 7 8
新编号: 7 4 5
我们不难发现没个子数的所有编号都是连续的。
比如:节点 $3$ 的子树编号 $[3,5]$,节点 $2$ 的子树编号 $[2,7]$。
原因在于存点的编号是按重链优先的,所以编号一定是先跑到树底,才从树底慢慢往回跑赋的值,所以树的节点都是连续的。
操作四
和操作三非常相似,从区间加变成了区间求和。
int s_z(int k){
r_s=0;//结果赋初值
//和子树加同理
f_i(1,1,n,q_q[k].i_d,q_q[k].i_d+q_q[k].s_z-1);
return r_s;
}
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int r_r(){//快读
int f=1,x=0;
char c=getchar();
while(!isdigit(c)){
if(c=='-')f=-1;
c=getchar();
}
while(isdigit(c)){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
const int o_o=1e6+10;
int n,m,r,p;//题目描述:节点数量,操作次数,根节点,取模
int a_a[o_o];//节点初值
struct po{//链式前向星存边
int v;
int n_t;
}p_p[o_o];
int h_d[o_o],x_p;
void a_d(int u,int v){//链式前向星存边
p_p[++x_p].v=v;
p_p[x_p].n_t=h_d[u];
h_d[u]=x_p;
}
struct pp{
int s_z;//节点价值(重量)
int f;//父节点
int d_p;//深度
int b_s;//重儿子
int t_p;//链头
int i_d;//新编号
}q_q[o_o];
int n_w[o_o],x_n;//存贮新编号节点,赋新编号
int t_t[o_o],l_a[o_o],r_s=0;//树,懒标记,区间查询结果
void d_1(int k,int f,int d_p){
q_q[k].d_p=d_p;//节点深度
q_q[k].f=f;//存父节点
q_q[k].s_z=1;//节点大小(重量)初值
int m_a=-1;//找重儿子
for(int i=h_d[k];i;i=p_p[i].n_t){
int v=p_p[i].v;
if(v==f)continue;
d_1(v,k,d_p+1);//遍历儿子
q_q[k].s_z+=q_q[v].s_z;//更新当前节点重量
if(m_a<q_q[v].s_z)m_a=q_q[v].s_z,q_q[k].b_s=v;//更新重儿子
}
}
void d_2(int k,int t_p){
q_q[k].t_p=t_p;//链头
q_q[k].i_d=++x_n;//赋新节点编号
n_w[x_n]=a_a[k];//用新节点存储当前节点的价值
if(!q_q[k].b_s)return ;//没有种儿子,说明没有儿子
d_2(q_q[k].b_s,t_p);//先遍历重儿子,形成重链
for(int i=h_d[k];i;i=p_p[i].n_t){
int v=p_p[i].v;
if(v==q_q[k].f||v==q_q[k].b_s)continue;//重儿子已经遍历过
d_2(v,v);//轻儿子是新重链的链头
}
}
int l_l(int k){//左儿子
return k<<1;
}
int r_r(int k){//右儿子
return k<<1|1;
}
void u_p(int k){//更新节点
t_t[k]=(t_t[l_l(k)]+t_t[r_r(k)])%p;
}
void p_d(int k,int l_n){//下传懒标记,更新值
l_a[l_l(k)]+=l_a[k];
l_a[r_r(k)]+=l_a[k];
t_t[l_l(k)]+=l_a[k]*(l_n-(l_n>>1));
t_t[r_r(k)]+=l_a[k]*(l_n>>1);
t_t[l_l(k)]%=p;
t_t[r_r(k)]%=p;
l_a[k]=0;
}
void u_d(int k,int l,int r,int x,int y,int v){//区间加
if(x<=l&&r<=y){
l_a[k]+=v;//懒标记
t_t[k]+=v*(r-l+1);//处理当前节点
return ;
}
int m_i=(l+r)>>1;
if(l_a[k])p_d(k,r-l+1);//处理懒标记
if(x<=m_i)u_d(l_l(k),l,m_i,x,y,v);//处理左边
if(y>m_i)u_d(r_r(k),m_i+1,r,x,y,v);//处理右边
u_p(k);//更新节点
}
void f_i(int k,int l,int r,int x,int y){//区间查询
if(x<=l&&r<=y){
r_s+=t_t[k];
r_s%=p;
return ;
}
int m_i=(l+r)>>1;
if(l_a[k])p_d(k,r-l+1);//将懒标记释放
if(x<=m_i)f_i(l_l(k),l,m_i,x,y);//处理左边
if(y>m_i)f_i(r_r(k),m_i+1,r,x,y);//处理右边
}
void b_t(int l,int r,int k){//建树
if(l==r){
t_t[k]=n_w[l];//用新编号存储节点
if(t_t[k]>p)t_t[k]=n_w[l]%o_o;
return ;
}
int m_i=(l+r)>>1;
b_t(l,m_i,l_l(k));
b_t(m_i+1,r,r_r(k));
u_p(k);//更新节点
}
void a_p(int l,int r,int v){//区间加
v%=p;
while(q_q[l].t_p!=q_q[r].t_p){//链头不相等(不在同一个链上)
if(q_q[q_q[l].t_p].d_p<q_q[q_q[r].t_p].d_p)swap(l,r);//深度更深的链先处理区间加
u_d(1,1,n,q_q[q_q[l].t_p].i_d,q_q[l].i_d,v);//处理范围:链头编号到当前节点编号
l=q_q[q_q[l].t_p].f;//更新范围,变成链头的父亲,进入新链
}
if(q_q[l].d_p>q_q[r].d_p)swap(l,r);
u_d(1,1,n,q_q[l].i_d,q_q[r].i_d,v);//同一个链上直接处理编号就行
}
int s_p(int l,int r){//区间和
int a_s=0;//结果初始化
while(q_q[l].t_p!=q_q[r].t_p){//链头不相等(不在同一个链上)
if(q_q[q_q[l].t_p].d_p<q_q[q_q[r].t_p].d_p)swap(l,r);//深度更深的链先处理区间和
r_s=0;//结果赋初值
f_i(1,1,n,q_q[q_q[l].t_p].i_d,q_q[l].i_d);//处理范围:链头编号到当前节点编号
a_s=(a_s+r_s)%p;//累加重链结果
l=q_q[q_q[l].t_p].f;//更新范围,变成链头的父亲,进入新链
}
if(q_q[l].i_d>q_q[r].i_d)swap(l,r);
r_s=0;//赋初值
f_i(1,1,n,q_q[l].i_d,q_q[r].i_d);//同一个链上直接处理编号
a_s=(a_s+r_s)%p;//更新
return a_s;
}
void a_z(int k,int v){
//由于按重儿子遍历,且先遍历重链,所以编号是先跟着重链遍历到树底,在回溯的顺序赋予轻节点编号。
//所以每个节点的子树的编号是连续的,可以用它的编号到它的编号加它的大小减一的区间进行操作。
u_d(1,1,n,q_q[k].i_d,q_q[k].i_d+q_q[k].s_z-1,v);
}
int s_z(int k){
r_s=0;//结果赋初值
//和子树加同理
f_i(1,1,n,q_q[k].i_d,q_q[k].i_d+q_q[k].s_z-1);
return r_s;
}
int main(){
n=r_r(),m=r_r(),r=r_r(),p=r_r();
for(int i=1;i<=n;i++)a_a[i]=r_r();//存节点初值
for(int i=1;i<n;i++){
int a=r_r(),b=r_r();
a_d(a,b);//连边
a_d(b,a);//连边
}
d_1(r,0,1);//从根节点开始处理点
d_2(r,r);//从根节点开始处理点
b_t(1,n,1);//建树
while(m--){
int op=r_r();
if(op==1){//树上两点路径上加处理
int x=r_r(),y=r_r(),z=r_r();
a_p(x,y,z);
}else if(op==2){//树上两点路径上求和处理
int x=r_r(),y=r_r();
printf("%d\n",s_p(x,y));
}else if(op==3){//节点子树加
int x=r_r(),z=r_r();
a_z(x,z);
}else {//节点子树求和
int x=r_r();
printf("%d\n",s_z(x));
}
}
return 0;
}