线段树
本文最后更新于 116 天前,如有失效请评论区留言。

先从一个例题出发
假设有一个数组array[0…n-1],里面有n个元素,现在我们要经常对这个数组做两件事:

更新数组元素的数值
求数组任意一段区间里元素的总和(或者平均值)

我有哪些方法实现上述操作?
1. 方法一:遍历一遍数组
时间复杂度:O(n)

  1. 方法二:线段树☑️
    时间复杂度:O(logn)

直接奉上线段树模板:

线段树:单点修改,区间查询

import typing

class SegTree:
    def __init__(
        self,
        op: typing.Callable[[typing.Any, typing.Any], typing.Any],
        e: typing.Any,
        v: typing.Union[int, typing.List[typing.Any]],
    ) -> None:
        """
        初始化线段树。
        :param op: 二元操作函数,用于定义线段树的合并操作,如 min, max 或 lambda。
        :param e: 操作的单位元,如对于 min 是 float('inf')。
        :param v: 初始数组,或者表示数组长度的整数,如果是整数则初始化为 e 的重复数组。
        """
        self._op = op  # 操作函数,如 min, max, lambda 等
        self._e = e  # 单位元,即操作的默认值
        if isinstance(v, int):
            v = [e] * v  # 如果 v 是一个整数,初始化 v 个单位元

        self._n = len(v)  # 原始数据的长度
        self._log = (self._n - 1).bit_length()  # 计算适用的最大层数
        self._size = 1 << self._log  # 线段树数组的大小
        self._d = [e] * (2 * self._size)  # 初始化线段树数组

        for i in range(self._n):
            self._d[self._size + i] = v[i]  # 填充初始数据
        for i in range(self._size - 1, 0, -1):
            self._update(i)  # 自底向上构建线段树

    def set(self, p: int, x: typing.Any) -> None:
        # 更新索引 p 的值为 x。
        assert 0 <= p < self._n
        p += self._size
        self._d[p] = x
        for i in range(1, self._log + 1):
            self._update(p >> i)  # 更新所有相关节点

    def get(self, p: int) -> typing.Any:
        # 获取索引 p 处的值。
        assert 0 <= p < self._n
        return self._d[p + self._size]

    def prod(self, left: int, right: int) -> typing.Any:
        # 计算区间 [left, right) 的聚合结果。
        assert 0 <= left <= right <= self._n
        sml = self._e  # 左区间的结果
        smr = self._e  # 右区间的结果
        left += self._size
        right += self._size
        while left < right:
            if left & 1:
                sml = self._op(sml, self._d[left])
                left += 1
            if right & 1:
                right -= 1
                smr = self._op(self._d[right], smr)
            left >>= 1
            right >>= 1
        return self._op(sml, smr)
        
    def all_prod(self) -> typing.Any:
       # 获取整个区间的聚合结果。
        return self._d[1]

    def max_right(self, left: int, f: typing.Callable[[typing.Any], bool]) -> int:
        """
        找到最左侧的索引,使得从 left 开始到该索引的聚合结果满足条件 f。
        :param left: 起始搜索的索引。
        :param f: 一个函数,定义聚合结果需要满足的条件。
        """
        assert 0 <= left <= self._n
        assert f(self._e)
        if left == self._n:
            return self._n

        left += self._size
        sm = self._e
        first = True

        while first or (left & -left) != left:
            first = False
            while left % 2 == 0:
                left >>= 1
            if not f(self._op(sm, self._d[left])):
                while left < self._size:
                    left *= 2
                    if f(self._op(sm, self._d[left])):
                        sm = self._op(sm, self._d[left])
                        left += 1
                return left - self._size
            sm = self._op(sm, self._d[left])
            left += 1
        return self._n

    def min_left(self, right: int, f: typing.Callable[[typing.Any], bool]) -> int:
        """
        找到最右侧的索引,使得从该索引到 right 的聚合结果满足条件 f。
        :param right: 结束搜索的索引。
        :param f: 一个函数,定义聚合结果需要满足的条件。
        """
        assert 0 <= right <= self._n
        assert f(self._e)
        if right == 0:
            return 0
        right += self._size
        sm = self._e
        first = True
        while first or (right & -right) != right:
            first = False
            right -= 1
            while right > 1 and right % 2:
                right >>= 1

            if not f(self._op(self._d[right], sm)):
                while right < self._size:
                    right = 2 * right + 1
                    if f(self._op(self._d[right], sm)):
                        sm = self._op(self._d[right], sm)
                        right -= 1
                return right + 1 - self._size
            sm = self._op(self._d[right], sm)
        return 0

    def _update(self, k: int) -> None:
        # 更新节点 k 的值
        self._d[k] = self._op(self._d[2 * k], self._d[2 * k + 1])

........
    n = min(5 * 10 ** 4, 3 * len(queries)) + 5
    st = SegTree(n)

lazy线段树:区间修改+lazy_tag标记

class LazySegmentTree:
    def __init__(self, nums: List[int]):
        # 使用数组 `nums` 初始化线段树。
        self.n = len(nums)  # 输入数组的长度。
        self.nums = nums  # 输入数组。
        self.ones = [0] * (2 << self.n.bit_length())  # 线段树节点,用于存储区间和。
        self.lazy = [False] * (2 << self.n.bit_length())  # 延迟传播标记,用于延迟更新。
        self._build(0, 0, self.n-1)  # 递归构建线段树。

    def _build(self, o: int, l: int, r: int) -> None:
        # 构建线段树节点。
        if l == r:
            # 叶子节点:直接存储数组中的值。
            self.ones[o] = self.nums[l]
            return
        left, right = 2 * o + 1, 2 * o + 2
        mid = (l + r) // 2
        self._build(left, l, mid)  # 递归构建左子树。
        self._build(right, mid + 1, r)  # 递归构建右子树。
        self.ones[o] = self.ones[left] + self.ones[right]  # 内部节点的值是其子节点的和。

    def _do(self, o: int, l: int, r: int) -> None:
        # 对节点应用延迟更新,翻转区间内的值。
        self.ones[o] = r - l + 1 - self.ones[o]  # 更新和为其在区间内的补数。
        self.lazy[o] = not self.lazy[o]  # 切换延迟标记。

    def _pushdown(self, o: int, l: int, r: int) -> None:
        # 将延迟更新向下传递至子节点。
        if self.lazy[o]:
            left, right = 2 * o + 1, 2 * o + 2
            mid = (l + r) // 2
            self._do(left, l, mid)
            self._do(right, mid + 1, r)
            self.lazy[o] = False

    def _update(self, o: int, l: int, r: int, ql: int, qr: int) -> None:
        # 更新区间 [ql, qr] 内的值。
        if ql <= l and r <= qr:
            self._do(o, l, r)
            return
        self._pushdown(o, l, r)
        left, right = 2 * o + 1, 2 * o + 2
        mid = (l + r) // 2
        if ql <= mid: self._update(left, l, mid, ql, qr)
        if mid + 1 <= qr: self._update(right, mid + 1, r, ql, qr)
        self.ones[o] = self.ones[left] + self.ones[right]

    def _query(self, o: int, l: int, r: int, ql: int, qr: int) -> int:
        # 查询区间 [ql, qr] 内的和。
        if ql <= l and r <= qr:
            return self.ones[o]
        self._pushdown(o, l, r)
        left, right = 2 * o + 1, 2 * o + 2
        mid = (l + r) // 2
        ans = 0
        if ql <= mid: ans += self._query(left, l, mid, ql, qr)
        if mid + 1 <= qr: ans += self._query(right, mid + 1, r, ql, qr)
        return ans

    def update(self, ql: int, qr: int) -> None:
        # 更新公共接口,更新区间 [ql, qr]。
        self._update(0, 0, self.n-1, ql, qr)

    def query(self, ql: int, qr: int) -> int:
        # 查询公共接口,查询区间 [ql, qr] 的和。
        return self._query(0, 0, self.n-1, ql, qr)
版权声明:除特殊说明,博客文章均为大块肌原创,依据CC BY-SA 4.0许可证进行授权,转载请附上出处链接及本声明。
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇