树链剖分

本文介绍树链剖分算法

例题

按照惯例 先介绍一道题目
HDOJ3966
抽象题意:题目给了一颗树,每个节点都有一个权值,给出以下三种操作:

  1. D l r v 表示从l节点到r节点的简单路径上的所有节点同时减少v的权值
  2. I l r v 表示从l节点到r节点的简单路径上的所有节点同时增加v的权值
  3. Q p 查询p节点上的权值

    朴素做法

    比较朴素的做法是用一个数组维护每一个节点的权值,更新路径上节点信息的时候从l节点向上更新至l和r节点的lca,从r节点向上更新至l和r节点的lca
    考虑复杂度,每次更新路径的复杂度最坏是O(n)的,n是节点总个数,如果要更新m次,复杂度就是O(n*m),这个复杂度是很差的,必然会TLE
    有什么办法优化呢?

    线性数组的延申

    先考虑线性数组,通常区段更新和单点查询我们会采用树状数组或者线段树的数据结构来进行维护,复杂度可以降低到对数级别
    那么树上怎么套树状数组(线段树)呢
    首先我们得将树展开成链
    通常的办法是树链剖分

    树链剖分

    naive的做法

    把树展开成链是很简单的 一遍dfs就能搞定 在路径分岔口注意申请多条链就行了

    最优复杂度的剖分

    介绍一下轻重边剖分策略
    为了方便叙述 记sz[i]为以i节点为根的子树节点个数
    轻边性质:对于轻边u->v,有sz[v]<=sz[u/2]
    重边反之
    那么对于整颗树而言 从树根到叶子节点所经过的轻边/重路径不会超过$\log(n)$条
    证明是显然的,因为构造一颗平衡树的高度总是为$\log(n)$
    那么 一条边(u->v)为重边的条件即为对于节点u的所有节点vi都有sz[v]>=sz[vi] (通俗的讲就是子树节点最多的儿子)

    实现

    分两步dfs
    第一步dfs:记录一下所有节点的深度 父亲 子树大小 子树节点最多的儿子(重儿子)
    例题代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    //v->当前考虑的点 u->v的父亲节点 dep->深度
    //g->邻接表
    //fa->父亲数组
    void dfs1(int v,int u,int dep){
    depth[v]=dep;
    fa[v]=u;
    sz[v]=1;
    for(int i=0;i<g[v].size();i++){
    int nxt=g[v][i];
    if(nxt==u) continue;
    dfs1(nxt,v,dep+1);
    sz[v]+=sz[nxt];
    //maxchild数组一开始初始化为-1
    if(maxchild[v]==-1||maxchild[v]<sz[nxt]) maxchild[v]=nxt;
    }
    }

第二步dfs:记录一下所有节点的dfs序 与此同时 我们要做好树链剖分的工作:

  1. 如果正在考虑的节点v 满足是树根节点或者不是它父亲节点的重儿子 那么新建一条链 把v当作这条链的链头并做记录
  2. 若不然 把正在考虑的节点v连到它父亲节点所在链的尾部
    ps:有时候并不需要记录整条链,有时候需要记录其他信息,对于前文的例题来说 我们记录一下每一个节点v所在链的链头top[v]
    在第二步dfs中 如果需要使用树状数组等维护连续数据的数据结构 我们需要保持同一条链的dfs序是连续的 通过第二步dfs的[1]可以知道 只需要在进入儿子节点的栈堆帧之前 先访问重儿子即可(不要重复访问重儿子)
    例题代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    //v->当前考虑的点 ts->dfs时间戳,用来记录dfs序
    //chaincnt->链计数器
    void dfs2(int v){
    dfn[v]=++ts;
    if(fa[v]==-1||v!=maxchild[fa[v]]){
    chaincnt++;
    top[v]=v;
    }
    else{
    top[v]=top[fa[v]];
    }
    if(maxchild[v]==-1) return;
    dfs2(maxchild[v]);
    for(int i=0;i<g[v].size();i++){
    int u=g[v][i];
    if(u!=maxchild[v]&&u!=fa[v]) dfs2(u);
    }
    }

好了我们剖分完了 此时的dfn数组中保存的是每一个节点的dfs序 如果节点存在相同的链中 那么他们的dfn值是连续的

区段更新

回到题目 对于操作1和2 我们要更新简单路径上的节点权值
分两种情况:

  1. l r在同一条链上
  2. l r在不同链上
    假设我们用树状数组来维护数据
    对于情况1 那好办 由于同一条链上的dfn值是连续的 那么只要做操作 add(dfn[l],v),add(dfn[r]+1,-v) 即可 注意这里的l得在r的上面(也就是说 更靠近链头)
    对于操作2 有很多办法转化成操作1 有人喜欢先求lca 其实不用这么麻烦:

    begin
    更新节点所在链链头深度比较大的节点不妨设l
    做操作 add(dfn[top[l]],v),add(dfn[l]+1,-v)
    把l跳到fa[top[l]]
    goto begin

这样就把情况2规约成了许多情况1的叠加(建议画个图更容易理解)
最后在注意一下 所有在树状数组上的操作下标需要使用每个节点的dfn值

例题代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include<bits/stdc++.h>

using namespace std;
const int N=50000+10;

int n,m,p;
vector<int> g[N];
int maxchild[N],dfn[N],sz[N];
int fa[N];
int depth[N];
int top[N];
int val[N];
int chaincnt,ts;
int c[N];

//--------树状数组代码---------
inline int lowbit(int x){
return x&(-x);
}

inline void add(int pos,int v){
for(;pos<=n;pos+=lowbit(pos)){
c[pos]+=v;
}
}

inline int ask(int pos){
int ret=0;
for(;pos>0;pos-=lowbit(pos)) ret+=c[pos];
return ret;
}
//----------------------------

//初始化
void init(){
for(int i=1;i<=n;i++){
g[i].clear();
}
memset(maxchild,-1,sizeof(maxchild));
memset(sz,0,sizeof(sz));
memset(top,-1,sizeof(top));
memset(fa,-1,sizeof(fa));
memset(c,0,sizeof(c));
chaincnt=ts=0;
}

//第一次dfs
void dfs1(int v,int u,int dep){
depth[v]=dep;
fa[v]=u;
sz[v]=1;
for(int i=0;i<g[v].size();i++){
int nxt=g[v][i];
if(nxt==u) continue;
dfs1(nxt,v,dep+1);
sz[v]+=sz[nxt];
if(maxchild[v]==-1||maxchild[v]<sz[nxt]) maxchild[v]=nxt;
}
}

//第二次dfs
void dfs2(int v){
dfn[v]=++ts;
if(fa[v]==-1||v!=maxchild[fa[v]]){
chaincnt++;
top[v]=v;
}
else{
top[v]=top[fa[v]];
}
if(maxchild[v]==-1) return;
dfs2(maxchild[v]);
for(int i=0;i<g[v].size();i++){
int u=g[v][i];
if(u!=maxchild[v]&&u!=fa[v]) dfs2(u);
}
}

int main(){
#ifndef ONLINE_JUDGE
freopen("input.in","r",stdin);
freopen("output.o","w+",stdout);
#endif
while(~scanf("%d %d %d",&n,&m,&p)){
init();
for(int i=1;i<=n;i++) scanf("%d",val+i);
for(int i=1;i<=m;i++){
int from,to;
scanf("%d %d",&from,&to);
g[from].push_back(to);
g[to].push_back(from);
}
dfs1(1,-1,1);
dfs2(1);
for(int i=1;i<=n;i++){
add(dfn[i],val[i]);
add(dfn[i]+1,-val[i]);
}
while(p--){
char op;
cin>>op;
//数据量挺大 为什么用cin呢...因为...
//后台数据有点问题 getchar处理不好回车
//这也是许多人RE/TLE的主要原因
if(op=='Q'){
int poss;
scanf("%d",&poss);
poss=dfn[poss];
printf("%d\n",ask(poss));
}
else{
//区段更新
int l,r,inc;
scanf("%d %d %d",&l,&r,&inc);
if(op=='D') inc=-inc;
int f1=top[l];
int f2=top[r];
while(f1!=f2){
if(depth[f1]<depth[f2]){
swap(f1,f2);
swap(l,r);
}
add(dfn[f1],inc);
add(dfn[l]+1,-inc);
l=fa[f1];
f1=top[l];
}
if(depth[l]>depth[r]) swap(l,r);
add(dfn[l],inc);
add(dfn[r]+1,-inc);
}
}
}
return 0;
}