143 lines
3.6 KiB
Plaintext
143 lines
3.6 KiB
Plaintext
#define NDEBUG
|
|
|
|
#include <stdlib.h>
|
|
#include <vector>
|
|
#include <iostream>
|
|
#include <limits>
|
|
|
|
// Forward declarations of every tree type. Each defaults to `long` elements.
// The concrete trees below are defined before the SegmentTree base they
// derive from, so the base must be declared up front.
template <class T = long> class SegmentTree;
template <class T = long> class SumSegmentTree;
template <class T = long> class MinSegmentTree;
template <class T = long> class MaxSegmentTree;
template <class T = long> class DifferenceSegmentTree;
|
|
|
|
template <typename T>
|
|
class SumSegmentTree : public SegmentTree<T> {
|
|
virtual T cmp (T a, T b) override {
|
|
return a + b;
|
|
}
|
|
|
|
public:
|
|
SumSegmentTree(size_t N) : SegmentTree<T>(N, 0) {};
|
|
};
|
|
|
|
template <typename T>
|
|
class MinSegmentTree : public SegmentTree<T> {
|
|
virtual T cmp (T a, T b) override {
|
|
return (a < b) ? a : b;
|
|
}
|
|
|
|
public:
|
|
MinSegmentTree(size_t N) : SegmentTree<T>(N, std::numeric_limits<T>::max()) {};
|
|
};
|
|
|
|
template <typename T>
|
|
class MaxSegmentTree : public SegmentTree<T> {
|
|
virtual T cmp (T a, T b) override {
|
|
return (a > b) ? a : b;
|
|
}
|
|
|
|
public:
|
|
MaxSegmentTree(size_t N) : SegmentTree<T>(N, std::numeric_limits<T>::min()) {};
|
|
};
|
|
|
|
template <typename T>
|
|
class DifferenceSegmentTree : public SegmentTree<T> {
|
|
virtual T cmp (T a, T b) override {
|
|
return a+b;
|
|
}
|
|
|
|
public:
|
|
virtual void add(size_t i, T val) override {
|
|
this->set(i, this->get(i) + val);
|
|
if (i != this->N-1)
|
|
this->set(i+1, this->get(i+1) - val);
|
|
}
|
|
|
|
void add_range(size_t i, size_t j, T val) {
|
|
this->set(i, this->get(i) + val);
|
|
if (j+1 != this->N)
|
|
this->set(j+1, this->get(j+1)-val);
|
|
}
|
|
|
|
DifferenceSegmentTree(size_t N) : SegmentTree<T>(N, 0) {};
|
|
};
|
|
|
|
/// Abstract iterative (bottom-up) segment tree over N elements of type T.
///
/// Concrete trees supply the combining operation `cmp` and its identity
/// element `unit`. Point updates and inclusive range queries each touch
/// O(log N) nodes.
template <typename T>
class SegmentTree {
    // Identity element of cmp (cmp(unit, x) == x). It fills every node that
    // has not been written yet and seeds the accumulator in query().
    const T unit;

    // Combining operation supplied by the concrete tree. query() folds nodes
    // from both ends of the range into a single accumulator, so the result
    // is only correct if cmp is associative and commutative.
    virtual T cmp (T a, T b) = 0;

    // Implicit binary tree: node k has children 2k and 2k+1. The N leaves
    // (the logical array) occupy arr[N .. 2N-1]; arr[0] is unused.
    std::vector<T> arr;

public:
    // Number of logical elements.
    size_t N;

    virtual ~SegmentTree() = default;

    /// Writes arr[1 .. 2N-1] to std::cout. A line break follows indices
    /// 1, 3, 7, 15, ..., i.e. one tree level per line when N is a power of two.
    void print() const {
        size_t newline = 2;
        for (size_t i = 1; i < 2*N; i++) {
            std::cout << arr[i] << " ";
            if (i == newline-1) {
                std::cout << std::endl;
                newline *= 2;
            }
        }
    }

    /// Builds a tree of N elements, all equal to `unit`.
    /// @param N    number of logical elements
    /// @param unit identity element of cmp
    // Leaves sit at offset N, so 2N slots are exactly enough for any N.
    // Power-of-two padding is unnecessary for this layout. Sizing the vector
    // directly also avoids the `1 << (exp - 1)` shift, which is undefined
    // for N == 0 and overflows int for large N.
    SegmentTree(size_t N, T unit) : unit(unit), arr(2 * N, unit), N(N) {}

    /// Returns element i (0-based). Out-of-range indices are only reported
    /// when NDEBUG is not defined.
    T get(size_t i) const {
#ifndef NDEBUG
        if (i >= N)
            std::cerr << "Tried accessing index " << i << " out of " << N-1 << " (remember, 0-indexing)" << std::endl;
#endif
        return arr[i+N];
    }

    /// Read-only element access; equivalent to get(i).
    T operator[](size_t i) const { return get(i); }

    /// Replaces element i with val and refreshes its ancestors.
    virtual void set(size_t i, T val) {
        update_field(i, val);
    }

    /// Writes val into leaf i, then recomputes every ancestor up to the root.
    void update_field(size_t i, T val) {
#ifndef NDEBUG
        if (i >= N)
            std::cerr << "Tried updating index " << i << " out of " << N-1 << " (remember, 0-indexing)" << std::endl;
#endif
        i += N; // Put the index at the leaf nodes/original array
        arr[i] = val;
        for (i /= 2; i >= 1; i /= 2)
            arr[i] = cmp(arr[i*2], arr[i*2+1]);
    }

    /// Replaces element i with `val + get(i)`.
    virtual void add(size_t i, T val) {
        set(i, val + get(i));
    }

    /// Combines the elements in the inclusive range [a, b] with cmp.
    /// Returns `unit` when a > b (the loop body never runs).
    T query(size_t a, size_t b) {
#ifndef NDEBUG
        if (a >= N || b >= N || a > b)
            std::cerr << "Tried querying the range: [" << a << ", " << b << "]" << " of max index " << N-1 << std::endl;
#endif
        a += N; b += N;
        // The accumulator must have the element type. A size_t accumulator
        // truncates floating-point results and round-trips negatives
        // through unsigned.
        T val = unit;
        while (a <= b) {
            // A right child at the left edge (odd a) is not covered by its
            // parent's range: take it and step inward. Likewise for a left
            // child (even b) at the right edge.
            if (a % 2 == 1) val = cmp(val, arr[a++]);
            if (b % 2 == 0) val = cmp(val, arr[b--]);
            a /= 2; b /= 2;
        }
        return val;
    }
};
|