Previous
线段树Segment Tree
1.线段树基础
线段树的构造
什么是线段树?
线段树是一棵二叉树,他的每个节点包含了两个额外的属性start和end用于表示该节点所代表的区间。start和end都是整数,并按照如下的方式赋值:
- 根节点的 start 和 end 由 build 方法所给出。
- 对于节点 A 的左儿子,有 start=A.left, end=(A.left + A.right) / 2。
- 对于节点 A 的右儿子,有 start=(A.left + A.right) / 2 + 1, end=A.right。
- 如果 start 等于 end, 那么该节点是叶子节点,不再有左右儿子。
数据结构如下:
class SegmentTreeNode {
public:
int start, end;
SegmentTreeNode *left, *right;
SegmentTreeNode(int start, int end) {
this->start = start;
this->end = end;
this->left = this->right = NULL;
}
}
实现一个 build 方法,接受 start 和 end 作为参数, 然后构造一个代表区间 [start, end] 的线段树,返回这棵线段树的根。
说明
线段树(又称区间树), 是一种高级数据结构,他可以支持这样的一些操作:
查找给定的点包含在了哪些区间内
查找给定的区间包含了哪些点
样例
比如给定start=1, end=6,对应的线段树为:
[1, 6]
/ \
[1, 3] [4, 6]
/ \ / \
[1, 2] [3,3] [4, 5] [6,6]
/ \ / \
[1,1] [2,2] [4,4] [5,5]
同样,如果将线段树里的区间元素定义为一个数组A的索引index,并加入以下新的值max,表示数组某个区间内的最大值,则其数据结构定义如下:
class SegmentTreeNode {
public:
int start, end, max;
SegmentTreeNode *left, *right;
SegmentTreeNode(int start, int end, int max) {
this->start = start;
this->end = end;
this->max = max;
this->left = this->right = NULL;
}
}
构造方法
使用递归来完成线段树的构造。
不包含max的线段树的构造方法
# 不包含max的线段树的构造方法
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end) {
* this->start = start, this->end = end;
* this->left = this->right = NULL;
* }
* }
*/
typedef SegmentTreeNode STN;
class Solution {
public:
/**
*@param start, end: Denote an segment / interval
*@return: The root of Segment Tree
*/
SegmentTreeNode * build(int start, int end) {
// write your code here
if(start>end)
return NULL;
STN* root = new STN(start, end);
int mid = (start+end)/2;
if(start<end){
root->left = build(start, mid);
root->right = build(mid+1, end);
}
return root;
}
};
包含max的线段树的构造方法
# 包含max的线段树的构造方法
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, max;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int max) {
* this->start = start;
* this->end = end;
* this->max = max;
* this->left = this->right = NULL;
* }
* }
*/
typedef SegmentTreeNode STN;
class Solution {
public:
/**
*@param A: a list of integer
*@return: The root of Segment Tree
*/
SegmentTreeNode * build(vector<int>& A) {
// write your code here
return build2(A, 0, A.size()-1);
}
SegmentTreeNode* build2(vector<int>& A, int l, int r){
if(l>r)
return NULL;
if(l==r){
STN* root = new STN(l, r, A[l]);
return root;
}
STN* root = new STN(l, r, 0);
int mid = (l+r)/2;
root->left = build2(A, l, mid);
root->right = build2(A, mid+1, r);
root->max = max(root->left->max, root->right->max);
return root;
}
};
线段树的修改
一个数组对应一个线段树。现修改数组中的某个值,更新线段树。比如
[1, 4, max=3]
/ \
[1, 2, max=2] [3, 4, max=3]
/ \ / \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=3]
如果调用 modify(root, 2, 4), 返回:
[1, 4, max=4]
/ \
[1, 2, max=4] [3, 4, max=3]
/ \ / \
[1, 1, max=2], [2, 2, max=4], [3, 3, max=0], [4, 4, max=3]
或 调用 modify(root, 4, 0), 返回:
[1, 4, max=2]
/ \
[1, 2, max=2] [3, 4, max=0]
/ \ / \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=0]
修改
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, max;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int max) {
* this->start = start;
* this->end = end;
* this->max = max;
* this->left = this->right = NULL;
* }
* }
*/
typedef SegmentTreeNode STN;
class Solution {
public:
/**
*@param root, index, value: The root of segment tree and
*@ change the node's value with [index, index] to the new given value
*@return: void
*/
void modify(SegmentTreeNode *root, int index, int value) {
// write your code here
stack<STN*> s;
while(root){
s.push(root);
int mid = (root->start+root->end)/2;
if(mid>=index)
root = root->left;
else
root = root->right;
}
if(!s.empty()){
s.top()->max = value;
s.pop();
}
while(!s.empty()){
STN* cur = s.top();
s.pop();
cur->max = max(cur->left->max, cur->right->max);
}
}
};
线段树的查询
查询某个数组区间内的最大值,或者某个数组区间内的元素个数。
查询某个数组区间内的最大值
对于数组 [1, 4, 2, 3]
, 对应的线段树为:
[0, 3, max=4]
/ \
[0,1,max=4] [2,3,max=3]
/ \ / \
[0,0,max=1] [1,1,max=4] [2,2,max=2], [3,3,max=3]
query(root, 1, 1)
, return 4
query(root, 1, 2)
, return 4
query(root, 2, 3)
, return 3
query(root, 0, 2)
, return 4
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, max;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int max) {
* this->start = start;
* this->end = end;
* this->max = max;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
/**
*@param root, start, end: The root of segment tree and
* an segment / interval
*@return: The maximum number in the interval [start, end]
*/
int query(SegmentTreeNode *root, int start, int end) {
// write your code here
if(!root)
return 0;
if(!root->left&&!root->right)
return root->max;
if(start<=root->start&&end>=root->end)
return root->max;
int mid = (root->start+root->end)/2;
if(start<=mid&&end>mid){
return max(query(root->left, start, mid), query(root->right, mid+1, end));
}
else if(start<=mid&&end<=mid){
return query(root->left, start, end);
}
else
return query(root->right, start, end);
}
};
查询某个数组区间内的元素个数
对于数组 [0, 空,2, 3]
, 对应的线段树为:
[0, 3, count=3]
/ \
[0,1,count=1] [2,3,count=2]
/ \ / \
[0,0,count=1] [1,1,count=0] [2,2,count=1], [3,3,count=1]
query(1, 1)
, return 0
query(1, 2)
, return 1
query(2, 3)
, return 2
query(0, 2)
, return 2
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, count;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int count) {
* this->start = start;
* this->end = end;
* this->count = count;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
/**
*@param root, start, end: The root of segment tree and
* an segment / interval
*@return: The count number in the interval [start, end]
*/
int query(SegmentTreeNode *root, int start, int end) {
// write your code here
if(!root||start>end||end<root->start||start>root->end)
return 0;
if(!root->left&&!root->right)
return root->count;
if(start<=root->start&&end>=root->end)
return root->count;
int mid = (root->start+root->end)/2;
if(start<=mid&&end>mid){
return query(root->left, start, mid)+query(root->right, mid+1, end);
}
else if(start<=mid&&end<=mid){
return query(root->left, start, end);
}
else
return query(root->right, start, end);
}
};
2.线段树的应用
后续……