Non Recursive Segment Tree

"""
A non-recursive segment tree implementation supporting range queries and single-element
updates. It works with virtually any list of elements of the same type, given an
associative and "commutative" combiner.

Explanation:
https://www.geeksforgeeks.org/iterative-segment-tree-range-minimum-query/
https://www.geeksforgeeks.org/segment-tree-efficient-implementation/

>>> SegmentTree([1, 2, 3], lambda a, b: a + b).query(0, 2)
6
>>> SegmentTree([3, 1, 2], min).query(0, 2)
1
>>> SegmentTree([2, 3, 1], max).query(0, 2)
3
>>> st = SegmentTree([1, 5, 7, -1, 6], lambda a, b: a + b)
>>> st.update(1, -1)
>>> st.update(2, 3)
>>> st.query(1, 2)
2
>>> st.query(1, 1)
-1
>>> st.update(4, 1)
>>> st.query(3, 4)
0
>>> st = SegmentTree([[1, 2, 3], [3, 2, 1], [1, 1, 1]], lambda a, b: [a[i] + b[i] for i
...                                                                   in range(len(a))])
>>> st.query(0, 1)
[4, 4, 4]
>>> st.query(1, 2)
[4, 3, 2]
>>> st.update(1, [-1, -1, -1])
>>> st.query(1, 2)
[0, 0, 0]
>>> st.query(0, 2)
[1, 2, 3]
"""
from __future__ import annotations

from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")  # type of the elements stored in (and combined by) a SegmentTree


class SegmentTree(Generic[T]):
    """
    Iterative (bottom-up) segment tree over a fixed-size list.

    The tree lives in a flat list ``st`` of length ``2 * N``. Leaves occupy
    indices ``N .. 2N - 1`` and every internal node ``p`` (``1 <= p < N``)
    holds ``fn(st[2p], st[2p + 1])``. Index 0 is never read.

    The combiner must be associative. Commutativity is not required: ``query``
    keeps its left and right partial results apart, so operands are always
    combined in array order.
    """

    def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None:
        """
        Segment Tree constructor.
        :param arr: list of elements for the segment tree (copied, not aliased)
        :param fnc: associative function used to combine two elements

        >>> SegmentTree(['a', 'b', 'c'], lambda a, b: f'{a}{b}').query(0, 2)
        'abc'
        >>> SegmentTree([(1, 2), (2, 3), (3, 4)],
        ...             lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2)
        (6, 9)
        >>> SegmentTree(['a', 'b', 'c', 'd'], lambda a, b: a + b).query(0, 2)
        'abc'
        """
        any_type: Any | T = None

        self.N: int = len(arr)
        # Slots 0..N-1 are placeholders for internal nodes; build() fills 1..N-1.
        self.st: list[T] = [any_type for _ in range(self.N)] + arr
        self.fn = fnc
        self.build()

    def build(self) -> None:
        """
        (Re)compute every internal node from its two children, bottom-up.
        """
        for p in range(self.N - 1, 0, -1):
            self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])

    def update(self, p: int, v: T) -> None:
        """
        Update an element in log(N) time
        :param p: position to be updated (0 <= p < N)
        :param v: new value
        :raises IndexError: if p is outside [0, N)

        >>> st = SegmentTree([3, 1, 2, 4], min)
        >>> st.query(0, 3)
        1
        >>> st.update(2, -1)
        >>> st.query(0, 3)
        -1
        >>> st.update(4, 0)
        Traceback (most recent call last):
            ...
        IndexError: position 4 out of range for segment tree of size 4
        """
        if not 0 <= p < self.N:
            # A negative p would otherwise silently overwrite an internal node.
            raise IndexError(
                f"position {p} out of range for segment tree of size {self.N}"
            )
        p += self.N
        self.st[p] = v
        # Walk up to the root, recombining each ancestor from its children.
        while p > 1:
            p //= 2
            self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])

    def query(self, l: int, r: int) -> T | None:  # noqa: E741
        """
        Get range query value in log(N) time
        :param l: left element index (inclusive)
        :param r: right element index (inclusive)
        :return: elements in [l, r] combined left to right, or None if l > r
        :raises IndexError: if l <= r and [l, r] is not inside [0, N)

        >>> st = SegmentTree([1, 2, 3, 4], lambda a, b: a + b)
        >>> st.query(0, 2)
        6
        >>> st.query(1, 2)
        5
        >>> st.query(0, 3)
        10
        >>> st.query(2, 3)
        7
        >>> st.query(2, 1) is None
        True
        >>> SegmentTree(['a', 'b', 'c', 'd'], lambda a, b: a + b).query(1, 3)
        'bcd'
        """
        if l > r:
            return None
        if l < 0 or r >= self.N:
            raise IndexError(
                f"range [{l}, {r}] out of range for segment tree of size {self.N}"
            )
        l, r = l + self.N, r + self.N  # noqa: E741

        # Nodes taken from the left edge are appended to res_left and nodes
        # taken from the right edge are prepended to res_right. This keeps
        # operands in array order, so non-commutative combiners work too.
        res_left: T | None = None
        res_right: T | None = None
        while l <= r:  # noqa: E741
            if l % 2 == 1:  # l is a right child: consume it, then step past it
                node = self.st[l]
                res_left = node if res_left is None else self.fn(res_left, node)
            if r % 2 == 0:  # r is a left child: consume it, then step before it
                node = self.st[r]
                res_right = node if res_right is None else self.fn(node, res_right)
            l, r = (l + 1) // 2, (r - 1) // 2

        if res_left is None:
            return res_right
        if res_right is None:
            return res_left
        return self.fn(res_left, res_right)


if __name__ == "__main__":
    # Brute-force self-check. Every range query of a min, a max and a sum tree
    # is compared against a direct scan of the list: first on the initial data,
    # then again after each single-element update.
    test_array = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12]

    # index -> replacement value; applied one at a time, in index order.
    test_updates = dict(enumerate([7, 2, 6, -14, 5, 4, 7, -10, 9, 10, 12, 1]))

    min_segment_tree = SegmentTree(test_array, min)
    max_segment_tree = SegmentTree(test_array, max)
    sum_segment_tree = SegmentTree(test_array, lambda a, b: a + b)

    def test_all_segments() -> None:
        """
        Check every segment [lo, hi] of test_array against all three trees.
        """
        size = len(test_array)
        for lo in range(size):
            for hi in range(lo, size):
                window = test_array[lo : hi + 1]
                expected_min, expected_max = min(window), max(window)
                expected_sum = sum(window)
                assert expected_min == min_segment_tree.query(lo, hi)
                assert expected_max == max_segment_tree.query(lo, hi)
                assert expected_sum == sum_segment_tree.query(lo, hi)

    test_all_segments()

    for index, value in test_updates.items():
        test_array[index] = value
        for tree in (min_segment_tree, max_segment_tree, sum_segment_tree):
            tree.update(index, value)
        test_all_segments()
Alger logo

© Alger 2022

About us

We are a group of programmers helping each other build new things, whether it be writing complex encryption programs or simple ciphers. Our goal is to work together to document and model beautiful, helpful and interesting algorithms using code. We are an open-source community - anyone can contribute. We check each other's work, communicate, and collaborate to solve problems. We strive to be welcoming and respectful, yet make sure that our code follows the latest programming guidelines.