Skip to content

Segment Tree

Segment Tree (無區間更新)

  • 只支援區間查詢,單點更新。

  • 把任意區間用 \(O(\log{n})\) 個區間表示,線段樹的每個節點記錄對應區間的信息。

  • 詢問:把詢問區間拆分成 \(O(\log{n})\) 個區間,對應著線段樹的 \(O(\log{n})\) 個節點,把這 \(O(\log{n})\) 個節點的信息合並,即為答案。

  • 單點更新:有 \(O(\log{n})\) 個區間包含被修改的位置,需要更新 \(O(\log{n})\) 個節點的信息。

  • 若是葉子節點有 n 個節點,則總共需要開闢的陣列空間為 \(2^{\lceil \log{n} \rceil + 1} - 1\):

    可以知道完美二元樹時,此時樹的高度 \(h = \lceil \log{n} \rceil\),總共的節點個數為:

    \[2^0 + 2^1 + 2^2 + .... + 2^h = \frac{1(1 - 2^{h + 1})}{1 - 2} = 2^{h + 1} - 1\]

    偷懶寫法可以直接宣告 \(4n\) 的空間。

// 模板来源 https://leetcode.cn/circle/discuss/mOr1u6/
// 线段树有两个下标,一个是线段树节点的下标,另一个是线段树维护的区间的下标
// 节点的下标:从 1 开始,如果你想改成从 0 开始,需要把左右儿子下标分别改成 node*2+1 和 node*2+2
// 区间的下标:从 0 开始
template<typename T>
class SegmentTree {
    // 注:也可以去掉 template<typename T>,改在这里定义 T
    // using T = pair<int, int>;

    int n;
    vector<T> tree;

    // 合并两个 val
    T merge_val(T a, T b) const {
        return max(a, b); // **根据题目修改**
    }

    // 合并左右儿子的 val 到当前节点的 val
    void maintain(int node) {
        tree[node] = merge_val(tree[node * 2], tree[node * 2 + 1]);
    }

    // 用 a 初始化线段树
    // 时间复杂度 O(n)
    void build(const vector<T>& a, int node, int l, int r) {
        if (l == r) { // 叶子
            tree[node] = a[l]; // 初始化叶节点的值
            return;
        }
        int m = (l + r) / 2;
        build(a, node * 2, l, m); // 初始化左子树
        build(a, node * 2 + 1, m + 1, r); // 初始化右子树
        maintain(node);
    }

    //要注意這是單點更新
    void update(int node, int l, int r, int i, T val) {
        if (l == r) { // 叶子(到达目标)
            // 如果想直接替换的话,可以写 tree[node] = val
            tree[node] = merge_val(tree[node], val);
            return;
        }
        int m = (l + r) / 2;
        if (i <= m) { // i 在左子树
            update(node * 2, l, m, i, val);
        } else { // i 在右子树
            update(node * 2 + 1, m + 1, r, i, val);
        }
        maintain(node);
    }

    T query(int node, int l, int r, int ql, int qr) const {
        if (ql <= l && r <= qr) { // 当前子树完全在 [ql, qr] 内
            return tree[node];
        }
        int m = (l + r) / 2;
        if (qr <= m) { // [ql, qr] 在左子树
            return query(node * 2, l, m, ql, qr);
        }
        if (ql > m) { // [ql, qr] 在右子树
            return query(node * 2 + 1, m + 1, r, ql, qr);
        }
        T l_res = query(node * 2, l, m, ql, qr);
        T r_res = query(node * 2 + 1, m + 1, r, ql, qr);
        return merge_val(l_res, r_res);
    }

public:
    // 线段树维护一个长为 n 的数组(下标从 0 到 n-1),元素初始值为 init_val
    SegmentTree(int n, T init_val) : SegmentTree(vector<T>(n, init_val)) {}

    // 线段树维护数组 a
    SegmentTree(const vector<T>& a) : n(a.size()), tree(2 << bit_width(a.size() - 1)) {
        build(a, 1, 0, n - 1);
    }

    // 更新 a[i] 为 merge_val(a[i], val)
    // 时间复杂度 O(log n)
    void update(int i, T val) {
        update(1, 0, n - 1, i, val);
    }

    // 返回用 merge_val 合并所有 a[i] 的计算结果,其中 i 在闭区间 [ql, qr] 中
    // 时间复杂度 O(log n)
    T query(int ql, int qr) const {
        return query(1, 0, n - 1, ql, qr);
    }

    // 获取 a[i] 的值
    // 时间复杂度 O(log n)
    T get(int i) const {
        return query(1, 0, n - 1, i, i);
    }
};

Lazy Segment Tree (區間修改)