Segment Tree | 线段树

Albert Wang / 2023-03-03 / 500 Words/has been Read   Times


背景 #

这篇博客主要介绍线段树的相关知识。假如我们需要对区间进行操作,比如需要多次查询某一个区间的和,或者多次查询某一个区间的最大值,或者多次在某些区间上加一个值。很明显,我们操作一次的时间复杂度是O(n),操作 m 次的时间复杂度是 O(mn)。

对于$O(n^2)$ 这样的时间复杂度来说,我们的数据量一般不能超过$10^5$,但是如果我们有一种手段把它的时间复杂度变成nlogn级别,那效率就会提升很多。

线段树 #

线段树就是一种很好的解决方案。它的图形化表达如下图所示;

img

假如我们想要查询某一个区间的和,首先我们知道总的区间是[1, 10],然后我们把他进行折半,分成[1, 5] 和 [6, 10] 两个区间。然后对这两个子区间再折半,不断地递归下去,知道这个区间已经只有一个元素,也就是对应着叶子节点。

这时对于叶子节点我们很容易知道它的和就是这个元素本身,因为它只有一个元素。对于任意一个非叶子节点来说,它的区间和就是先对它的子结点求和,再把每一个子结点的和相加。

代码实现 #

不难看出,线段树其实是一棵完全二叉树,所以我们直接用一个一维数组就可以保存树节点信息,类似于堆的存储方式。同样的,对于下标从 0 开始的节点,第 i 棵树必然满足下面的性质

  • 左子节点下标 2 * i + 1;
  • 右子节点下标 2 * i + 2

我们先来定义每一个节点的的类,用 left, right, sum 分别表示左右子结点和当前区间的和。

class Node {
    int left;
    int right;
    int sum;

    public Node (int left, int right, int sum) {
        this.left = left;
        this.right = right;
        this.sum = sum;
    }
}

下面就是线段树的实现,它的成员变量就是一个一维数组,用来保存这棵树的节点信息。为了保证空间足够,这里我们按照 num * 4 的规格来创建数组空间。

在得到树的节点之后我们需要对这棵树进行初始化,build 函数就是用来做树的初始化操作的,参数idx 表示当前递归的下标位置,最开始的下标从 0 开始, [l, r] 表示当前节点对应的区间,val 数组是最开始传入的信息。

在 build 函数里我们首先会判断当前节点是不是叶子节点,如果是叶子节点,那它的和就是参数原来的值;如果不是,则先找出这个区间的中点,然后分别向左子树和右子树递归建树。完成之后再将当前节点的 sum 值进行更新。

public class SegmentTree {
    Node[] nodes;

    public SegmentTree(int num, int[] val) {
        nodes = new Node[num * 4];
        for (int i = 0; i < nodes.length; i++) {
            nodes[i] = new Node(0, 0, 0);
        }
        build(0, 0, num - 1, val);
    }

    public void build(int idx, int l, int r, int[] val) {
        nodes[idx].left = l;
        nodes[idx].right = r;
        if (l == r) {
            nodes[idx].sum = val[l];
            return;
        }
        int mid = l + (r - l) / 2;
        if (l <= mid) {
            build(idx * 2 + 1, l, mid, val);
        }
        if (mid < r){
            build(idx * 2 + 2, mid + 1, r, val);
        }
        nodes[idx].sum = nodes[idx * 2 + 1].sum + nodes[idx * 2 + 2].sum;
    }
}

建完树之后如果我们希望查询某一个区间的和,那就要从根节点开始递归向下在线段树中查询。当递归到某一个节点的时候可能就会存在两种情况,

  • 当前节点所表示的区间是查询区间的子集,直接返回 sum;
  • 查询区间分别在当前节点表示区间的左右子树里,这时就需要分别计算左右子树的 sum,然后再把它们加起来。

下面代码就是这两种情况的实现

    public int query(int idx, int l, int r) {
        if ((l <= nodes[idx].left) && (nodes[idx].right <= r)) { // 当前节点表示的区间在询问的区间里面
            return nodes[idx].sum;
        }

        int sum = 0;
        int mid = nodes[idx].left + (nodes[idx].right - nodes[idx].left) / 2;
        if (l <= mid) { // 查询区间的左端点要在当前节点的左子树中去找
            sum += query(idx * 2 + 1, l, r);
        }
        if (r > mid) {
            sum += query(idx * 2 + 2, l, r);
        }
        return sum;
    }

我们考虑假如现在想要修改某一个节点所表示的值,该怎么做呢?很自然地我们应该想到先查到这个节点在树的哪个位置,然后把它的值改掉就行了嘛。确实是这样,但是我们同样得注意到我们的修改可能会改变这个节点所有父结点的 sum 值,所以我们每次修改都需要进行 pushUp,把它的父结点的 sum 值做一次更新。

下面就是修改操作的代码实现,其中 modify 函数和查询在写法上很相似,只是我们在每次对子结点进行修改之后都会对当前节点做 pushUp。

    public void pushUp(int idx) {
        nodes[idx].sum = nodes[idx * 2 + 1].sum + nodes[idx * 2 + 2].sum;
    }

    public void modify(int idx, int x, int val) { // x 表示要修改的元素下标, val 表示要修改的值
        if ((nodes[idx].left == x) && (nodes[idx].right == x)) {
            nodes[idx].sum = val;
            return;
        }

        int mid = nodes[idx].left + (nodes[idx].right - nodes[idx].left) / 2;
        if (x <= mid) {
            modify(idx * 2 + 1, x, val);
        } else {
            modify(idx * 2 + 2, x, val);
        }
        pushUp(idx);
    }

实战演练 #

下面我们用力扣第 307 题区域和检索 - 数组可修改 来说明线段树该怎么用,这道题题目和我们的例子几乎一模一样,这里不多做解释。

题目需要我们实现下面的模板,我们只需要把线段树套进去就可以顺利解决这道题目。

image-20230304160328605

如下面代码所示:

class NumArray {
    SegmentTree st;

    public NumArray(int[] nums) {
        st = new SegmentTree(nums.length, nums);
    }
    
    public void update(int index, int val) {
        st.modify(0, index, val);
    }
    
    public int sumRange(int left, int right) {
        return st.query(0, left, right);
    }
}

Last modified on 2023-03-03