This documentation is automatically generated by online-judge-tools/verification-helper
#include "cpstl/ds/Segtree.hpp"
列に対して単一要素の更新・区間集約が高速に行えるデータ構造.非再帰・列のサイズを $2$ のべき乗に合わせる実装.以下,管理する列を A
,列の長さを N
とする.
template <typename S, auto op, auto e>
S
op
S op(S, S)
の形で定義する (lambda 式でも OK)e
S e()
の形で定義する.op
は結合法則を満たす必要がある.すなわち,任意の S
の元 a, b, c
に対し op(op(a, b), c) = op(a, op(b, c))
となる必要がある.
以下,op
で A[l], A[l + 1], ..., A[r]
を集約した値を op[l, r]
,A[l], A[l + 1], ..., A[r - 1]
を集約した値を op[l, r)
で表す.
(1) Segtree()
(2) Segtree(int N)
(3) Segtree(int N, const S &init)
(4) Segtree(const std::vector<S> &v)
(5) template <class Inputit> Segtree(Inputit first, Inputit last)
Segtree
を作成する.N
,初期値 e()
で初期化.N
,初期値 init
で初期化.v
で初期化.[first, last)
の値で初期化.void set(int pos, const S &x)
A[pos] = x
で更新.
void add(int pos, const S &x)
A[pos] += x
で更新.S
に operator+=
が定義されている必要がある.
template <typename F, auto mapping> void set(int pos, const F &f)
A[pos] = mapping(f, A[pos])
で更新.mapping
は S mapping(F, S)
の形で定義されている必要がある (lambda 式でも OK).
const S& get(int pos) const
A[pos]
を返す.
const S& get(int pos) const noexcept
A[pos]
を返す.
S fold(int l, int r) const
op[l, r)
を返す.
S all_fold() const
op[1, N]
を返す.
<typename F> int max_right(int l, const F& f)
r = l
または f(op[l, r)) = true
r = n
または f(op[l, r]) = false
これらを両方満たす r
を返す.f
が単調な場合,f(op[l, r)) = true
となる最大の r
が返ってくる.
<typename F> int min_left(int r, const F& f) const
l = r
または f(op[l, r)) = true
l = 0
または f(op[l - 1, r)) = false
これらを両方満たす r
を返す.f
が単調な場合,f(op[l, r)) = true
となる最小の l
が返ってくる.
#pragma once
#include "cpstl/other/Template.hpp"
namespace cpstd {
// @brief Segment Tree
template <
typename S,
auto op,
auto e
>
class Segtree {
private:
std::vector<S> dat;
int N, sz;
public:
Segtree() {}
explicit Segtree(int n) : Segtree(std::vector<S>(n, e())) {}
explicit Segtree(int n, const S &init) : Segtree(std::vector<S>(n, init)) {}
explicit Segtree(const std::vector<S> &v) : N((int)v.size()) {
sz = 1;
while (sz < N) sz <<= 1;
dat.assign(sz << 1, e());
for (int i = 0; i < N; ++i) dat[i + sz] = v[i];
for (int i = sz - 1; i >= 1; --i) dat[i] = op(dat[i << 1], dat[i << 1 | 1]);
}
template <class Inputit>
Segtree(Inputit first, Inputit last) : Segtree(std::vector<S>(first, last)) {}
// A[pos] ← x で更新
// O(logN) time
void set(int pos, const S &x) {
assert(0 <= pos && pos < N);
pos += sz;
dat[pos] = x;
while (pos > 1) {
pos >>= 1;
dat[pos] = op(dat[pos << 1], dat[pos << 1 | 1]);
}
}
// A[pos] ← A[pos] + x で更新
// O(logN) time
void add(int pos, const S &x) {
assert(0 <= pos && pos < N);
pos += sz;
dat[pos] += x;
while (pos > 1) {
pos >>= 1;
dat[pos] = op(dat[pos << 1], dat[pos << 1 | 1]);
}
}
// A[pos] ← mapping(f, A[pos]) で更新
// O(logN) time
template <
typename F,
auto mapping
>
void set(int pos, const F &f) {
assert(0 <= pos && pos < N);
pos += sz;
dat[pos] = mapping(f, dat[pos]);
while (pos > 1) {
pos >>= 1;
dat[pos] = op(dat[pos << 1], dat[pos << 1 | 1]);
}
}
// A[pos] を返す
// O(1) time
const S& get(int pos) const {
assert(0 <= pos && pos < N);
return dat[pos + sz];
}
// A[pos] を返す (assert なし)
// O(1) time
const S& operator[](int pos) const noexcept { return dat[pos + sz]; }
// op[l, r) を返す
// O(logN) time
S fold(int l, int r) const {
assert(0 <= l && l <= r && r <= N);
if (l == r) return e();
S resl = e(), resr = e();
for (l += sz, r += sz; l < r; l >>= 1, r >>= 1) {
if (l & 1) resl = op(resl, dat[l++]);
if (r & 1) resr = op(dat[--r], resr);
}
return op(resl, resr);
}
// op[1, N] を返す
// O(1) time
S all_fold() const { return dat[1]; }
// `r = l` または `f(op[l, r)) = true`
// `r = n` または `f(op[l, r]) = false`
// これらを両方満たす `r` を返す (`f` が単調なら `f(op[l, r)) = true` となる最大の `r`)
// O(logN) time
template <typename F>
int max_right(int l, const F& f) const {
assert(0 <= l && l <= N);
assert(f(e()));
if (l == N) return N;
l += sz;
S s = e();
do {
while (!(l & 1)) l >>= 1;
if (!f(op(s, dat[l]))) {
while (l < sz) {
l <<= 1;
if (f(op(s, dat[l]))) s = op(s, dat[l++]);
}
return l - sz;
}
s = op(s, dat[l++]);
} while ((l & -l) != l);
return N;
}
// `l = r` または `f(op[l, r)) = true`
// `l = 0` または `f(op[l - 1, r)) = false`
// これらを両方満たす `l` を返す (`f` が単調なら `f(op[l, r)) = true` となる最小の `l`)
// O(logN) time
template <typename F>
int min_left(int r, const F &f) const {
assert(0 <= r && r <= N);
assert(f(e()));
if (r == 0) return 0;
r += sz;
S s = e();
do {
--r;
while (r > 1 && (r & 1)) r >>= 1;
if (!f(op(dat[r], s))) {
while (r < sz) {
r = r << 1 | 1;
if (f(op(dat[r], s))) s = op(dat[r--], s);
}
return r + 1 - sz;
}
s = op(dat[r], s);
} while ((r & -r) != r);
return 0;
}
};
};
#line 2 "cpstl/ds/Segtree.hpp"
#line 2 "cpstl/other/Template.hpp"
#include <immintrin.h>
#include <algorithm>
#include <array>
#include <bit>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cfenv>
#include <charconv>
#include <chrono>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <ios>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <streambuf>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#line 2 "cpstl/math/StaticModint.hpp"
#line 4 "cpstl/math/StaticModint.hpp"
namespace cpstd {
// @brief Static Modint
// @see https://hackmd.io/@tatyam-prime/rkVCOcwQn
template <uint32_t m>
struct StaticModint {
private:
using mint = StaticModint;
uint32_t _v = 0;
static constexpr bool is_prime = []() -> bool {
if (m == 1) return false;
if (m == 2 || m == 7 || m == 61) return true;
if (!(m & 1)) return false;
uint32_t d = m - 1;
while (!(d & 1)) d >>= 1;
for (uint32_t a : {2, 7, 61}) {
uint32_t t = d;
mint y = mint(a).pow(t);
while (t != m - 1 && y != 1 && y != m - 1) {
y *= y;
t <<= 1;
}
if (y != m - 1 && !(t & 1)) return false;
}
return true;
}();
static constexpr std::pair<int32_t, int32_t> inv_gcd(int32_t a, int32_t b) {
if (a == 0) return {b, 0};
int32_t s = b, t = a, m0 = 0, m1 = 1;
while (t) {
const int32_t q = s / t;
s -= t * q, std::swap(s, t);
m0 -= m1 * q, std::swap(m0, m1);
}
if (m0 < 0) m0 += b / s;
return {s, m0};
}
public:
constexpr StaticModint() {}
template <typename T>
constexpr StaticModint(T v) {
static_assert(std::is_integral_v<T>, "T is not integral type.");
if constexpr (std::is_signed_v<T>) {
int64_t x = int64_t(v % int64_t(m));
if (x < 0) x += m;
_v = uint32_t(x);
}
else _v = uint32_t(v % m);
}
static constexpr mint raw(uint32_t v) { mint a; a._v = v; return a; }
static constexpr uint32_t mod() { return m; }
constexpr uint32_t val() const { return _v; }
constexpr mint& operator++() { return *this += 1; }
constexpr mint operator++(int) { mint res = *this; ++*this; return res; }
constexpr mint& operator--() { return *this -= 1; }
constexpr mint operator--(int) { mint res = *this; --*this; return res; }
constexpr mint& operator+=(mint rhs) {
if (_v >= m - rhs._v) _v -= m;
_v += rhs._v;
return *this;
}
constexpr mint& operator-=(mint rhs) {
if (_v < rhs._v) _v += m;
_v -= rhs._v;
return *this;
}
constexpr mint& operator*=(mint rhs) { return *this = *this * rhs; }
constexpr mint& operator/=(mint rhs) { return *this *= rhs.inv(); }
constexpr mint operator+() const { return *this; }
constexpr mint operator-() const { return mint{} - *this; }
constexpr mint pow(long long n) const {
assert(0 <= n);
if (n == 0) return 1;
mint x = *this, r = 1;
while (n > 0) {
if (n & 1) r *= x;
x *= x;
n >>= 1;
if (!n) return r;
}
return r;
}
constexpr mint inv() const {
if constexpr (is_prime) {
assert(_v);
return pow(m - 2);
}
else {
auto eg = inv_gcd(_v, m);
assert(eg.first == 1);
return eg.second;
}
}
friend constexpr mint operator+(mint lhs, mint rhs) { return lhs += rhs; }
friend constexpr mint operator-(mint lhs, mint rhs) { return lhs -= rhs; }
friend constexpr mint operator*(mint lhs, mint rhs) { return uint64_t(lhs._v) * rhs._v; }
friend constexpr mint operator/(mint lhs, mint rhs) { return lhs /= rhs; }
friend constexpr bool operator==(mint lhs, mint rhs) { return lhs._v == rhs._v; }
friend constexpr bool operator!=(mint lhs, mint rhs) { return lhs._v != rhs._v; }
};
using Modint998244353 = StaticModint<998244353>;
constexpr Modint998244353 operator""_M(unsigned long long x) { return x; }
};
#line 2 "cpstl/other/Fastio.hpp"
#line 4 "cpstl/other/Fastio.hpp"
namespace cpstd {
// @brief Fast I/O
// @see https://judge.yosupo.jp/submission/21623
// @see https://maspypy.com/library-checker-many-a-b
namespace Fastio {
static constexpr const uint32_t BUF_SIZE = 1 << 17;
char ibuf[BUF_SIZE], obuf[BUF_SIZE], out[100];
uint32_t pil = 0, pir = 0, por = 0;
struct Pre {
char num[10000][4];
constexpr Pre() : num() {
for (int i = 0; i < 10000; ++i) {
int n = i;
for (int j = 3; j >= 0; --j) {
num[i][j] = n % 10 | '0';
n /= 10;
}
}
}
} constexpr pre;
inline void load() {
std::memcpy(ibuf, ibuf + pil, pir - pil);
pir = pir - pil + std::fread(ibuf + pir - pil, 1, BUF_SIZE - pir + pil, stdin);
pil = 0;
if (pir < BUF_SIZE) ibuf[pir++] = '\n';
}
inline void flush() {
fwrite(obuf, 1, por, stdout);
por = 0;
}
void _input(char &dest) {
do {
if (pil + 1 > pir) load();
dest = ibuf[pil++];
} while (std::isspace(dest));
}
void _input(std::string &dest) {
dest.clear();
char c;
do {
if (pil + 1 > pir) load();
c = ibuf[pil++];
} while (std::isspace(c));
do {
dest += c;
if (pil == pir) load();
c = ibuf[pil++];
} while (!std::isspace(c));
}
void _input(float &dest) {
std::string s;
_input(s);
dest = std::stof(s);
}
void _input(double &dest) {
std::string s;
_input(s);
dest = std::stod(s);
}
void _input(long double &dest) {
std::string s;
_input(s);
dest = std::stold(s);
}
template <typename T>
void input_int(T &x) {
if (pil + 100 > pir) load();
char c;
do {
c = ibuf[pil++];
} while (c < '-');
bool minus = 0;
if constexpr (std::is_signed<T>::value || std::is_same_v<T, __int128_t>) {
if (c == '-') minus = 1, c = ibuf[pil++];
}
x = 0;
while (c >= '0') x = x * 10 + (c & 15), c = ibuf[pil++];
if constexpr (std::is_signed<T>::value || std::is_same_v<T, __int128_t>) {
if (minus) x = -x;
}
}
void _input(int &dest) { input_int(dest); }
void _input(unsigned int &dest) { input_int(dest); }
void _input(long long &dest) { input_int(dest); }
void _input(unsigned long long &dest) { input_int(dest); }
void _input(__int128 &dest) { input_int(dest); }
void _input(unsigned __int128 &dest) { input_int(dest); }
template <uint32_t m>
void _input(cpstd::StaticModint<m> &dest) { long long a; _input(a); dest = a; }
template <typename T, typename U>
void _input(std::pair<T, U> &dest) { _input(dest.first), _input(dest.second); }
template <std::size_t N = 0, typename T>
void input_tuple(T &t) {
if constexpr (N < std::tuple_size<T>::value) {
auto &x = std::get<N>(t);
input(x);
input_tuple<N + 1>(t);
}
}
template <typename... T>
void _input(std::tuple<T...> &dest) { input_tuple(dest); }
template <std::size_t N = 0, typename T>
void _input(std::array<T, N> &dest) { for (auto &e : dest) _input(e); }
template <typename T>
void _input(std::vector<T> &dest) { for (auto &e : dest) _input(e); }
void input() {}
// 各引数に入力
template <typename H, typename... T>
void input(H &desth, T &... destt) { _input(desth), input(destt...); }
void _print(const char tg) {
if (por == BUF_SIZE) flush();
obuf[por++] = tg;
}
void _print(const std::string tg) { for (char c : tg) _print(c); }
void _print(const char *tg) {
std::size_t len = std::strlen(tg);
for (std::size_t i = 0; i < len; ++i) _print(tg[i]);
}
template <typename T>
void print_int(T x) {
if (por > BUF_SIZE - 100) flush();
if (x < 0) obuf[por++] = '-', x = -x;
int outi;
for (outi = 96; x >= 10000; outi -= 4) {
std::memcpy(out + outi, pre.num[x % 10000], 4);
x /= 10000;
}
if (x >= 1000) {
std::memcpy(obuf + por, pre.num[x], 4);
por += 4;
}
else if (x >= 100) {
std::memcpy(obuf + por, pre.num[x] + 1, 3);
por += 3;
}
else if (x >= 10) {
int q = (x * 103) >> 10;
obuf[por] = q | '0';
obuf[por + 1] = (x - q * 10) | '0';
por += 2;
}
else obuf[por++] = x | '0';
std::memcpy(obuf + por, out + outi + 4, 96 - outi);
por += 96 - outi;
}
template <typename T>
void print_real(T tg) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(15) << double(tg);
std::string s = oss.str();
_print(s);
}
void _print(int tg) { print_int(tg); }
void _print(unsigned int tg) { print_int(tg); }
void _print(long long tg) { print_int(tg); }
void _print(unsigned long long tg) { print_int(tg); }
void _print(__int128 tg) { print_int(tg); }
void _print(unsigned __int128 tg) { print_int(tg); }
void _print(float tg) { print_real(tg); }
void _print(double tg) { print_real(tg); }
void _print(long double tg) { print_real(tg); }
template <uint32_t m>
void _print(cpstd::StaticModint<m> tg) { print_int(tg.val()); }
template <typename T, typename U>
void _print(const std::pair<T, U> tg) {
_print(tg.first);
_print(' ');
_print(tg.second);
}
template <std::size_t N = 0, typename T>
void print_tuple(const T tg) {
if constexpr (N < std::tuple_size<T>::value) {
if constexpr (N > 0) _print(' ');
const auto x = std::get<N>(tg);
_print(x);
print_tuple<N + 1>(tg);
}
}
template <typename... T>
void _print(std::tuple<T...> tg) { print_tuple(tg); }
template <typename T, std::size_t N>
void _print(const std::array<T, N> tg) {
auto len = tg.size();
for (std::size_t i = 0; i < len; ++i) {
if (i) _print(' ');
_print(tg[i]);
}
}
template <typename T>
void _print(const std::vector<T> tg) {
auto len = tg.size();
for (std::size_t i = 0; i < len; ++i) {
if (i) _print(' ');
_print(tg[i]);
}
}
void print() { _print('\n'); }
// 各引数を空白区切りで出力し改行
template <typename H, typename... T>
void print(H &&tgh, T &&... tgt) {
_print(tgh);
if (sizeof...(tgt)) _print(' ');
print(std::forward<T>(tgt)...);
}
void __attribute__((destructor)) _d() { flush(); }
};
using Fastio::input;
using Fastio::print;
using Fastio::flush;
};
#line 4 "cpstl/ds/Segtree.hpp"
namespace cpstd {
// @brief Segment Tree
template <
typename S,
auto op,
auto e
>
class Segtree {
private:
std::vector<S> dat;
int N, sz;
public:
Segtree() {}
explicit Segtree(int n) : Segtree(std::vector<S>(n, e())) {}
explicit Segtree(int n, const S &init) : Segtree(std::vector<S>(n, init)) {}
explicit Segtree(const std::vector<S> &v) : N((int)v.size()) {
sz = 1;
while (sz < N) sz <<= 1;
dat.assign(sz << 1, e());
for (int i = 0; i < N; ++i) dat[i + sz] = v[i];
for (int i = sz - 1; i >= 1; --i) dat[i] = op(dat[i << 1], dat[i << 1 | 1]);
}
template <class Inputit>
Segtree(Inputit first, Inputit last) : Segtree(std::vector<S>(first, last)) {}
// A[pos] ← x で更新
// O(logN) time
void set(int pos, const S &x) {
assert(0 <= pos && pos < N);
pos += sz;
dat[pos] = x;
while (pos > 1) {
pos >>= 1;
dat[pos] = op(dat[pos << 1], dat[pos << 1 | 1]);
}
}
// A[pos] ← A[pos] + x で更新
// O(logN) time
void add(int pos, const S &x) {
assert(0 <= pos && pos < N);
pos += sz;
dat[pos] += x;
while (pos > 1) {
pos >>= 1;
dat[pos] = op(dat[pos << 1], dat[pos << 1 | 1]);
}
}
// A[pos] ← mapping(f, A[pos]) で更新
// O(logN) time
template <
typename F,
auto mapping
>
void set(int pos, const F &f) {
assert(0 <= pos && pos < N);
pos += sz;
dat[pos] = mapping(f, dat[pos]);
while (pos > 1) {
pos >>= 1;
dat[pos] = op(dat[pos << 1], dat[pos << 1 | 1]);
}
}
// A[pos] を返す
// O(1) time
const S& get(int pos) const {
assert(0 <= pos && pos < N);
return dat[pos + sz];
}
// A[pos] を返す (assert なし)
// O(1) time
const S& operator[](int pos) const noexcept { return dat[pos + sz]; }
// op[l, r) を返す
// O(logN) time
S fold(int l, int r) const {
assert(0 <= l && l <= r && r <= N);
if (l == r) return e();
S resl = e(), resr = e();
for (l += sz, r += sz; l < r; l >>= 1, r >>= 1) {
if (l & 1) resl = op(resl, dat[l++]);
if (r & 1) resr = op(dat[--r], resr);
}
return op(resl, resr);
}
// op[1, N] を返す
// O(1) time
S all_fold() const { return dat[1]; }
// `r = l` または `f(op[l, r)) = true`
// `r = n` または `f(op[l, r]) = false`
// これらを両方満たす `r` を返す (`f` が単調なら `f(op[l, r)) = true` となる最大の `r`)
// O(logN) time
template <typename F>
int max_right(int l, const F& f) const {
assert(0 <= l && l <= N);
assert(f(e()));
if (l == N) return N;
l += sz;
S s = e();
do {
while (!(l & 1)) l >>= 1;
if (!f(op(s, dat[l]))) {
while (l < sz) {
l <<= 1;
if (f(op(s, dat[l]))) s = op(s, dat[l++]);
}
return l - sz;
}
s = op(s, dat[l++]);
} while ((l & -l) != l);
return N;
}
// `l = r` または `f(op[l, r)) = true`
// `l = 0` または `f(op[l - 1, r)) = false`
// これらを両方満たす `l` を返す (`f` が単調なら `f(op[l, r)) = true` となる最小の `l`)
// O(logN) time
template <typename F>
int min_left(int r, const F &f) const {
assert(0 <= r && r <= N);
assert(f(e()));
if (r == 0) return 0;
r += sz;
S s = e();
do {
--r;
while (r > 1 && (r & 1)) r >>= 1;
if (!f(op(dat[r], s))) {
while (r < sz) {
r = r << 1 | 1;
if (f(op(dat[r], s))) s = op(dat[r--], s);
}
return r + 1 - sz;
}
s = op(dat[r], s);
} while ((r & -r) != r);
return 0;
}
};
};