本文最后更新于 116 天前,如有失效请评论区留言。
先从一个例题出发
假设有一个数组array[0…n-1],里面有n个元素,现在我们要经常对这个数组做两件事:
更新数组元素的数值
求数组任意一段区间里元素的总和(或者平均值)
我有哪些方法实现上述操作?
1. 方法一:遍历一遍数组
时间复杂度:O(n)
- 方法二:线段树☑️
时间复杂度:O(logn)
直接奉上线段树模板:
线段树:单点修改,区间查询
import typing
class SegTree:
def __init__(
self,
op: typing.Callable[[typing.Any, typing.Any], typing.Any],
e: typing.Any,
v: typing.Union[int, typing.List[typing.Any]],
) -> None:
"""
初始化线段树。
:param op: 二元操作函数,用于定义线段树的合并操作,如 min, max 或 lambda。
:param e: 操作的单位元,如对于 min 是 float('inf')。
:param v: 初始数组,或者表示数组长度的整数,如果是整数则初始化为 e 的重复数组。
"""
self._op = op # 操作函数,如 min, max, lambda 等
self._e = e # 单位元,即操作的默认值
if isinstance(v, int):
v = [e] * v # 如果 v 是一个整数,初始化 v 个单位元
self._n = len(v) # 原始数据的长度
self._log = (self._n - 1).bit_length() # 计算适用的最大层数
self._size = 1 << self._log # 线段树数组的大小
self._d = [e] * (2 * self._size) # 初始化线段树数组
for i in range(self._n):
self._d[self._size + i] = v[i] # 填充初始数据
for i in range(self._size - 1, 0, -1):
self._update(i) # 自底向上构建线段树
def set(self, p: int, x: typing.Any) -> None:
# 更新索引 p 的值为 x。
assert 0 <= p < self._n
p += self._size
self._d[p] = x
for i in range(1, self._log + 1):
self._update(p >> i) # 更新所有相关节点
def get(self, p: int) -> typing.Any:
# 获取索引 p 处的值。
assert 0 <= p < self._n
return self._d[p + self._size]
def prod(self, left: int, right: int) -> typing.Any:
# 计算区间 [left, right) 的聚合结果。
assert 0 <= left <= right <= self._n
sml = self._e # 左区间的结果
smr = self._e # 右区间的结果
left += self._size
right += self._size
while left < right:
if left & 1:
sml = self._op(sml, self._d[left])
left += 1
if right & 1:
right -= 1
smr = self._op(self._d[right], smr)
left >>= 1
right >>= 1
return self._op(sml, smr)
def all_prod(self) -> typing.Any:
# 获取整个区间的聚合结果。
return self._d[1]
def max_right(self, left: int, f: typing.Callable[[typing.Any], bool]) -> int:
"""
找到最左侧的索引,使得从 left 开始到该索引的聚合结果满足条件 f。
:param left: 起始搜索的索引。
:param f: 一个函数,定义聚合结果需要满足的条件。
"""
assert 0 <= left <= self._n
assert f(self._e)
if left == self._n:
return self._n
left += self._size
sm = self._e
first = True
while first or (left & -left) != left:
first = False
while left % 2 == 0:
left >>= 1
if not f(self._op(sm, self._d[left])):
while left < self._size:
left *= 2
if f(self._op(sm, self._d[left])):
sm = self._op(sm, self._d[left])
left += 1
return left - self._size
sm = self._op(sm, self._d[left])
left += 1
return self._n
def min_left(self, right: int, f: typing.Callable[[typing.Any], bool]) -> int:
"""
找到最右侧的索引,使得从该索引到 right 的聚合结果满足条件 f。
:param right: 结束搜索的索引。
:param f: 一个函数,定义聚合结果需要满足的条件。
"""
assert 0 <= right <= self._n
assert f(self._e)
if right == 0:
return 0
right += self._size
sm = self._e
first = True
while first or (right & -right) != right:
first = False
right -= 1
while right > 1 and right % 2:
right >>= 1
if not f(self._op(self._d[right], sm)):
while right < self._size:
right = 2 * right + 1
if f(self._op(self._d[right], sm)):
sm = self._op(self._d[right], sm)
right -= 1
return right + 1 - self._size
sm = self._op(self._d[right], sm)
return 0
def _update(self, k: int) -> None:
# 更新节点 k 的值
self._d[k] = self._op(self._d[2 * k], self._d[2 * k + 1])
........
n = min(5 * 10 ** 4, 3 * len(queries)) + 5
st = SegTree(n)
lazy线段树:区间修改+lazy_tag标记
class LazySegmentTree:
def __init__(self, nums: List[int]):
# 使用数组 `nums` 初始化线段树。
self.n = len(nums) # 输入数组的长度。
self.nums = nums # 输入数组。
self.ones = [0] * (2 << self.n.bit_length()) # 线段树节点,用于存储区间和。
self.lazy = [False] * (2 << self.n.bit_length()) # 延迟传播标记,用于延迟更新。
self._build(0, 0, self.n-1) # 递归构建线段树。
def _build(self, o: int, l: int, r: int) -> None:
# 构建线段树节点。
if l == r:
# 叶子节点:直接存储数组中的值。
self.ones[o] = self.nums[l]
return
left, right = 2 * o + 1, 2 * o + 2
mid = (l + r) // 2
self._build(left, l, mid) # 递归构建左子树。
self._build(right, mid + 1, r) # 递归构建右子树。
self.ones[o] = self.ones[left] + self.ones[right] # 内部节点的值是其子节点的和。
def _do(self, o: int, l: int, r: int) -> None:
# 对节点应用延迟更新,翻转区间内的值。
self.ones[o] = r - l + 1 - self.ones[o] # 更新和为其在区间内的补数。
self.lazy[o] = not self.lazy[o] # 切换延迟标记。
def _pushdown(self, o: int, l: int, r: int) -> None:
# 将延迟更新向下传递至子节点。
if self.lazy[o]:
left, right = 2 * o + 1, 2 * o + 2
mid = (l + r) // 2
self._do(left, l, mid)
self._do(right, mid + 1, r)
self.lazy[o] = False
def _update(self, o: int, l: int, r: int, ql: int, qr: int) -> None:
# 更新区间 [ql, qr] 内的值。
if ql <= l and r <= qr:
self._do(o, l, r)
return
self._pushdown(o, l, r)
left, right = 2 * o + 1, 2 * o + 2
mid = (l + r) // 2
if ql <= mid: self._update(left, l, mid, ql, qr)
if mid + 1 <= qr: self._update(right, mid + 1, r, ql, qr)
self.ones[o] = self.ones[left] + self.ones[right]
def _query(self, o: int, l: int, r: int, ql: int, qr: int) -> int:
# 查询区间 [ql, qr] 内的和。
if ql <= l and r <= qr:
return self.ones[o]
self._pushdown(o, l, r)
left, right = 2 * o + 1, 2 * o + 2
mid = (l + r) // 2
ans = 0
if ql <= mid: ans += self._query(left, l, mid, ql, qr)
if mid + 1 <= qr: ans += self._query(right, mid + 1, r, ql, qr)
return ans
def update(self, ql: int, qr: int) -> None:
# 更新公共接口,更新区间 [ql, qr]。
self._update(0, 0, self.n-1, ql, qr)
def query(self, ql: int, qr: int) -> int:
# 查询公共接口,查询区间 [ql, qr] 的和。
return self._query(0, 0, self.n-1, ql, qr)