// vim_snippets/CPP/range_queries/SegmentTree.cpp.snip
// Source listing: 143 lines, 3.6 KiB, plain text; last updated 2024-03-21 13:49:58 +01:00.
#define NDEBUG
#include <stdlib.h>
#include <vector>
#include <iostream>
#include <limits>
// Forward declarations for the segment-tree family.
// A default template argument may be specified only once per template, so the
// `T = long` defaults live here and the definitions below repeat only
// `template <typename T>`.
// The derived trees are defined before SegmentTree itself. That compiles because
// SegmentTree<T> is a dependent base: it only has to be complete when a derived
// tree is actually instantiated.
template <typename T = long> class SegmentTree;
template <typename T = long> class SumSegmentTree;
template <typename T = long> class MinSegmentTree;
template <typename T = long> class MaxSegmentTree;
template <typename T = long> class DifferenceSegmentTree;
template <typename T>
class SumSegmentTree : public SegmentTree<T> {
public:
    /// @brief Range-sum tree over N elements, every element starting at 0.
    SumSegmentTree(size_t N) : SegmentTree<T>(N, 0) {}
private:
    // Node combine for the base tree: addition, whose identity (0) is the unit
    // passed to the base constructor above.
    virtual T cmp (T lhs, T rhs) override {
        return lhs + rhs;
    }
};
template <typename T>
class MinSegmentTree : public SegmentTree<T> {
public:
    /// @brief Range-minimum tree over N elements.
    /// Every slot starts at std::numeric_limits<T>::max(), the largest finite
    /// value, so it never wins a comparison against a smaller real element.
    MinSegmentTree(size_t N) : SegmentTree<T>(N, std::numeric_limits<T>::max()) {}
private:
    // Node combine: the smaller operand. When neither is strictly smaller
    // (ties or unordered values such as NaN), the right operand is returned.
    virtual T cmp (T lhs, T rhs) override {
        if (lhs < rhs)
            return lhs;
        return rhs;
    }
};
template <typename T>
class MaxSegmentTree : public SegmentTree<T> {
    // Node combine: the larger operand. When neither is strictly larger
    // (ties or unordered values such as NaN), the right operand is returned.
    virtual T cmp (T a, T b) override {
        return (a > b) ? a : b;
    }
public:
    /// @brief Range-maximum tree over N elements.
    /// Every slot starts at std::numeric_limits<T>::lowest(), the most negative
    /// finite value, so it never beats a real element. numeric_limits<T>::min()
    /// is the same for integer types, but for floating-point T it is the
    /// smallest *positive* normal value. Because queries seed their accumulator
    /// with this unit, that value would win against every negative element.
    MaxSegmentTree(size_t N) : SegmentTree<T>(N, std::numeric_limits<T>::lowest()) {}
};
template <typename T>
class DifferenceSegmentTree : public SegmentTree<T> {
    // Node combine: addition, so query(a, b) sums the stored entries.
    virtual T cmp (T lhs, T rhs) override {
        return lhs + rhs;
    }
public:
    /// @brief Tree over N entries, all starting at 0.
    DifferenceSegmentTree(size_t N) : SegmentTree<T>(N, 0) {}

    /// @brief Records "add val to every position in [first, last]" (0-based,
    /// inclusive) difference-array style.
    /// val is added at `first` and cancelled at `last + 1`, so the prefix sum
    /// query(0, k) grows by val exactly for first <= k <= last. The cancelling
    /// write is skipped when the range reaches the final position.
    void add_range(size_t first, size_t last, T val) {
        this->set(first, this->get(first) + val);
        const size_t past_end = last + 1;
        if (past_end != this->N)
            this->set(past_end, this->get(past_end) - val);
    }

    /// @brief Single-position version of add_range: same as add_range(i, i, val).
    virtual void add(size_t i, T val) override {
        add_range(i, i, val);
    }
};
template <typename T>
class SegmentTree {
    // Identity element for cmp. It fills every slot at construction and seeds
    // the accumulator in query().
    const T unit;
    // Combine operation supplied by the derived class. query() folds nodes from
    // both ends of the range into a single accumulator, so the operation should
    // be associative and commutative.
    virtual T cmp (T a, T b) = 0;
    // Implicit 1-indexed binary tree: node k has children 2k and 2k+1,
    // element i (0-based) is the leaf arr[i + N], and arr[0] is unused.
    std::vector<T> arr;
public:
    // Number of elements (leaves).
    size_t N;

    /// @brief Prints arr[1 .. 2N-1] to stdout, starting a new line after each
    /// complete level (after indices 1, 3, 7, 15, ...).
    void print() const {
        size_t newline = 2;
        for (size_t i = 1; i < 2*N; i++) {
            std::cout << arr[i] << " ";
            if (i == newline-1) {
                std::cout << std::endl;
                newline *= 2;
            }
        }
    }

    /// @brief Builds a tree over N elements, all initialised to `unit`.
    /// The bottom-up layout only touches indices [1, 2N), so 2N slots are
    /// enough for any N, including N == 0 and non-powers of two. This avoids
    /// the int bit-shifts that power-of-two sizing would need, which overflow
    /// for large N.
    SegmentTree(size_t N, T unit) : unit(unit), arr(2 * N, unit), N(N) {}

    // Polymorphic base: deleting a derived tree through a SegmentTree pointer
    // must run the derived destructor.
    virtual ~SegmentTree() = default;
    // Declaring the destructor suppresses the implicit move constructor, so
    // restore it (and the copy constructor) explicitly.
    SegmentTree(const SegmentTree&) = default;
    SegmentTree(SegmentTree&&) = default;

    /// @brief Returns the stored value of element i (0-based).
    T get(size_t i) const {
#ifndef NDEBUG
        // Reports the bad index but still performs the access.
        if (i >= N)
            std::cerr << "Tried accessing index " << i << " out of " << N-1 << " (remember, 0-indexing)" << std::endl;
#endif
        return arr[i+N];
    }

    /// @brief Read-only element access; same as get(i).
    T operator[](size_t i) const { return get(i); }

    /// @brief Overwrites element i with val. Virtual; the base implementation
    /// forwards to update_field.
    virtual void set(size_t i, T val) {
        update_field(i, val);
    }

    /// @brief Stores val at element i, then recomputes every ancestor up to
    /// the root (O(log N) cmp calls).
    void update_field(size_t i, T val) {
#ifndef NDEBUG
        // Reports the bad index but still performs the write.
        if (i >= N)
            std::cerr << "Tried updating index " << i << " out of " << N-1 << " (remember, 0-indexing)" << std::endl;
#endif
        i += N; // Put the index at the leaf nodes/original array
        arr[i] = val;
        for (i /= 2; i >= 1; i /= 2)
            arr[i] = cmp(arr[i*2], arr[i*2+1]);
    }

    /// @brief Adds val to element i (read via get, written via set).
    virtual void add(size_t i, T val) {
        set(i, val + get(i));
    }

    /// @brief Folds cmp over elements a..b (0-based, inclusive).
    /// Returns `unit` when the range is empty (a > b).
    T query(size_t a, size_t b) {
#ifndef NDEBUG
        if (a >= N || b >= N || a > b)
            std::cerr << "Tried querying the range: [" << a << ", " << b << "]" << " of max index " << N-1 << std::endl;
#endif
        a += N; b += N;
        // Accumulate in T. A size_t accumulator would truncate floating-point
        // values and round-trip negative values through unsigned conversion.
        T val = unit;
        while (a <= b) {
            // A right child at the left edge (odd a) is only partly covered by
            // its parent, so take it now and step past it. Symmetrically for a
            // left child (even b) at the right edge.
            if (a % 2 == 1) val = cmp(val, arr[a++]);
            if (b % 2 == 0) val = cmp(val, arr[b--]);
            a /= 2; b /= 2;
        }
        return val;
    }
};