树结构算法总结(2) 线段树Segment Tree

Posted by SkyHigh on October 3, 2016

Previous

树结构算法总结(1) 二叉树的遍历

线段树Segment Tree

1.线段树基础

线段树的构造

什么是线段树?

线段树是一棵二叉树,他的每个节点包含了两个额外的属性start和end用于表示该节点所代表的区间。start和end都是整数,并按照如下的方式赋值:

  • 根节点的 startendbuild 方法所给出。
  • 对于节点 A 的左儿子,有 start=A.left, end=(A.left + A.right) / 2
  • 对于节点 A 的右儿子,有 start=(A.left + A.right) / 2 + 1, end=A.right
  • 如果 start 等于 end, 那么该节点是叶子节点,不再有左右儿子。

数据结构如下:

class SegmentTreeNode {
public:
    int start, end;
    SegmentTreeNode *left, *right;
    SegmentTreeNode(int start, int end) {
        this->start = start;
        this->end = end;
        this->left = this->right = NULL;
    }
}

实现一个 build 方法,接受 startend 作为参数, 然后构造一个代表区间 [start, end] 的线段树,返回这棵线段树的根。

说明
线段树(又称区间树), 是一种高级数据结构,他可以支持这样的一些操作:
查找给定的点包含在了哪些区间内
查找给定的区间包含了哪些点

见WIKI:
线段树
区间树

样例
比如给定start=1, end=6,对应的线段树为:


               [1,  6]
             /        \
      [1,  3]           [4,  6]
      /     \           /     \
   [1, 2]  [3,3]     [4, 5]   [6,6]
   /    \           /     \
[1,1]   [2,2]     [4,4]   [5,5]

同样,如果将线段树里的区间元素定义为一个数组A的索引index,并加入以下新的值max,表示数组某个区间内的最大值,则其数据结构定义如下:

class SegmentTreeNode {
public:
    int start, end, max;
    SegmentTreeNode *left, *right;
    SegmentTreeNode(int start, int end, int max) {
        this->start = start;
        this->end = end;
        this->max = max;
        this->left = this->right = NULL;
    }
}

构造方法

使用递归来完成线段树的构造。

不包含max的线段树的构造方法

# 不包含max的线段树的构造方法
/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end) {
 *         this->start = start, this->end = end;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
typedef SegmentTreeNode STN;
class Solution {
public:
    /**
     *@param start, end: Denote an segment / interval
     *@return: The root of Segment Tree
     */
    SegmentTreeNode * build(int start, int end) {
        // write your code here
        if(start>end)
            return NULL;
        STN* root = new STN(start, end);
        int mid = (start+end)/2;
        if(start<end){
            root->left = build(start, mid);
            root->right = build(mid+1, end);
        }
        return root;
    }
};

包含max的线段树的构造方法

# 包含max的线段树的构造方法
/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, max;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int max) {
 *         this->start = start;
 *         this->end = end;
 *         this->max = max;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
typedef SegmentTreeNode STN;
class Solution {
public:
    /**
     *@param A: a list of integer
     *@return: The root of Segment Tree
     */
    SegmentTreeNode * build(vector<int>& A) {
        // write your code here
        return build2(A, 0, A.size()-1);
    }
    SegmentTreeNode* build2(vector<int>& A, int l, int r){
        if(l>r)
            return NULL;
        if(l==r){
            STN* root = new STN(l, r, A[l]);
            return root;
        }
        STN* root = new STN(l, r, 0);
        int mid = (l+r)/2;
        root->left = build2(A, l, mid);
        root->right = build2(A, mid+1, r);
        root->max = max(root->left->max, root->right->max);
        return root;
    }
};

线段树的修改

一个数组对应一个线段树。现修改数组中的某个值,更新线段树。比如

                      [1, 4, max=3]
                    /                \
        [1, 2, max=2]                [3, 4, max=3]
       /              \             /             \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=3]

如果调用 modify(root, 2, 4), 返回:

                      [1, 4, max=4]
                    /                \
        [1, 2, max=4]                [3, 4, max=3]
       /              \             /             \
[1, 1, max=2], [2, 2, max=4], [3, 3, max=0], [4, 4, max=3]

或 调用 modify(root, 4, 0), 返回:

                      [1, 4, max=2]
                    /                \
        [1, 2, max=2]                [3, 4, max=0]
       /              \             /             \
[1, 1, max=2], [2, 2, max=1], [3, 3, max=0], [4, 4, max=0]

修改

/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, max;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int max) {
 *         this->start = start;
 *         this->end = end;
 *         this->max = max;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
typedef SegmentTreeNode STN;
class Solution {
public:
    /**
     *@param root, index, value: The root of segment tree and 
     *@ change the node's value with [index, index] to the new given value
     *@return: void
     */
    void modify(SegmentTreeNode *root, int index, int value) {
        // write your code here
        stack<STN*> s;
        while(root){
            s.push(root);
            int mid = (root->start+root->end)/2;
            if(mid>=index)
                root = root->left;
            else
                root = root->right;
        }
        if(!s.empty()){
            s.top()->max = value;
            s.pop();
        }
        while(!s.empty()){
            STN* cur = s.top();
            s.pop();
            cur->max = max(cur->left->max, cur->right->max);
        }
    }
};

线段树的查询

查询某个数组区间内的最大值,或者某个数组区间内的元素个数。

查询某个数组区间内的最大值

对于数组 [1, 4, 2, 3], 对应的线段树为:

                  [0, 3, max=4]
                 /             \
          [0,1,max=4]        [2,3,max=3]
          /         \        /         \
   [0,0,max=1] [1,1,max=4] [2,2,max=2], [3,3,max=3]

query(root, 1, 1), return 4

query(root, 1, 2), return 4

query(root, 2, 3), return 3

query(root, 0, 2), return 4

/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, max;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int max) {
 *         this->start = start;
 *         this->end = end;
 *         this->max = max;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
class Solution {
public:
    /**
     *@param root, start, end: The root of segment tree and 
     *                         an segment / interval
     *@return: The maximum number in the interval [start, end]
     */
    int query(SegmentTreeNode *root, int start, int end) {
        // write your code here
        if(!root)
            return 0;
        if(!root->left&&!root->right)
            return root->max;
        if(start<=root->start&&end>=root->end)
            return root->max;
        int mid = (root->start+root->end)/2;
        if(start<=mid&&end>mid){
            return max(query(root->left, start, mid), query(root->right, mid+1, end));
        }
        else if(start<=mid&&end<=mid){
            return query(root->left, start, end);
        }
        else
            return query(root->right, start, end);
    }
};

查询某个数组区间内的元素个数

对于数组 [0, 空,2, 3], 对应的线段树为:

                     [0, 3, count=3]
                     /             \
          [0,1,count=1]             [2,3,count=2]
          /         \               /            \
   [0,0,count=1] [1,1,count=0] [2,2,count=1], [3,3,count=1]

query(1, 1), return 0

query(1, 2), return 1

query(2, 3), return 2

query(0, 2), return 2

/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, count;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int count) {
 *         this->start = start;
 *         this->end = end;
 *         this->count = count;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
class Solution {
public:
    /**
     *@param root, start, end: The root of segment tree and 
     *                         an segment / interval
     *@return: The count number in the interval [start, end] 
     */
    int query(SegmentTreeNode *root, int start, int end) {
        // write your code here
        if(!root||start>end||end<root->start||start>root->end)
            return 0;
        if(!root->left&&!root->right)
            return root->count;
        if(start<=root->start&&end>=root->end)
            return root->count;
        int mid = (root->start+root->end)/2;
        if(start<=mid&&end>mid){
            return query(root->left, start, mid)+query(root->right, mid+1, end);
        }
        else if(start<=mid&&end<=mid){
            return query(root->left, start, end);
        }
        else
            return query(root->right, start, end);
    }
};

2.线段树的应用

后续……