cp-STL

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ras-cp/cp-STL

:heavy_check_mark: Wavelet Matrix
(cpstl/ds/WaveletMatrix.hpp)

Depends on

Verified with

Code

#pragma once

#include "cpstl/other/Template.hpp"
#include "cpstl/ds/BitVector.hpp"

namespace cpstd {

// @brief Wavelet Matrix
// @see https://miti-7.hatenablog.com/entry/2018/04/28/152259
// @see https://takeda25.hatenablog.jp/entry/20130303/1362301095
// @see https://drken1215.hatenablog.com/entry/2023/10/19/220215
//
// Static structure over a sequence of non-negative integers that answers range
// counting / order-statistic queries in O(log M) time (M = maximum value).
// Level `pos` (processed from lgm - 1 down to 0) stably partitions the current
// order by bit `pos`: elements whose bit is 0 come first, then those whose bit is 1.
// After constructing from a size or calling set(), call build() before querying.

template <typename T>
class WaveletMatrix {
	private:
	static_assert(std::is_integral_v<T>, "template parameter T must be integral type");
	std::vector<cpstd::BitVector> bv;   // bv[pos][i]: bit `pos` of the i-th element in the order entering level pos
	std::vector<std::vector<T>> sum;    // sum[pos][i]: sum of the first i values in the order produced by level pos
	std::vector<T> dat;                 // raw values, indexed by original position
	int N = 0, lgm = 0;                 // sequence length; number of levels (bit width of the maximum)

	// True iff x >= 2^lgm, i.e. x has a set bit above every level and so exceeds every
	// stored value. Assumes x >= 0.
	bool exceeds_levels(T x) const noexcept {
		return lgm < std::numeric_limits<T>::digits && (x >> lgm) != 0;
	}

	public:
	WaveletMatrix() {}
	// N zero-valued elements; assign them with set() and then call build().
	explicit WaveletMatrix(int _N) : dat(_N, 0), N(_N) {}
	// Builds the structure over a copy of v (all values must be >= 0).
	explicit WaveletMatrix(const std::vector<T> &v) : dat(v), N((int)v.size()) { build(); }

	// [pos] <- x. Queries see the change only after the next build().
	// O(1) time
	void set(int pos, const T &x) {
		assert(0 <= pos && pos < N && x >= 0);
		dat[pos] = x;
	}

	// Builds the structure from the current values.
	// O(N log M) time (M = max [i])
	void build() {
		T maxi = 1;
		for (auto val : dat) {
			assert(val >= 0);
			maxi = std::max(maxi, val);
		}
		// lgm = bit width of maxi, computed in T itself so that 64-bit values are not truncated.
		lgm = 0;
		while (lgm < std::numeric_limits<T>::digits && (maxi >> lgm) != 0) ++lgm;
		std::vector<int> left(N), right(N), ord(N);
		std::iota(ord.begin(), ord.end(), 0);
		bv.assign(lgm, BitVector(N));
		sum.assign(lgm + 1, std::vector<T>(N + 1, 0));
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l = 0, r = 0;
			for (int i = 0; i < N; ++i) {
				if ((dat[ord[i]] >> pos) & 1) {
					bv[pos].set(i);
					right[r++] = ord[i];
				}
				else left[l++] = ord[i];
			}
			bv[pos].build();
			// Stable partition: elements with a 0 bit (collected in left) first, then the 1s.
			ord.swap(left);
			for (int i = 0; i < r; ++i) ord[i + l] = right[i];
			for (int i = 0; i < N; ++i) sum[pos][i + 1] = sum[pos][i] + dat[ord[i]];
		}
	}

	// Returns [pos].
	// O(log M) time
	T get(int pos) const {
		assert(0 <= pos && pos < N);
		return operator[](pos);
	}

	// Returns [pos] (no bounds check).
	// O(log M) time
	T operator[](int pos) const noexcept {
		T res = 0;
		for (int b = lgm - 1; b >= 0; --b) {
			int z = bv[b].rank_0(pos);
			if (bv[b].get(pos)) {
				pos += bv[b].rank_0() - z;
				res |= T(1) << b;
			}
			else pos = z;
		}
		return res;
	}

	// Returns the number of occurrences of x in [l, r).
	// O(log M) time
	int freq(int l, int r, T x) const {
		assert(0 <= l && l <= r && r <= N);
		// Bits of x above the top level are never inspected by the descent below, so values
		// that cannot be stored must be rejected explicitly.
		if (x < 0 || exceeds_levels(x)) return 0;
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if ((x >> pos) & 1) {
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
			}
			else l = l0, r = r0;
		}
		return r - l;
	}

	// Returns the number of values < ub in [l, r).
	// O(log M) time
	int range_freq(int l, int r, T ub) const {
		assert(0 <= l && l <= r && r <= N);
		if (ub <= 0) return 0;                  // every stored value is >= 0
		if (exceeds_levels(ub)) return r - l;   // every stored value is < 2^lgm <= ub
		int res = 0;
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if ((ub >> pos) & 1) {
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
				res += r0 - l0;
			}
			else l = l0, r = r0;
		}
		return res;
	}

	// Returns the number of values v with lb <= v < ub in [l, r) (0 when ub <= lb).
	// O(log M) time
	int range_freq(int l, int r, T lb, T ub) const {
		assert(0 <= l && l <= r && r <= N);
		if (ub <= lb) return 0;
		return range_freq(l, r, ub) - range_freq(l, r, lb);
	}

	// Returns the k-th smallest value (0-indexed) of [l, r). Requires 0 <= k < r - l.
	// O(log M) time
	T quantile(int l, int r, int k) const {
		assert(0 <= l && l <= r && r <= N);
		assert(0 <= k && k < r - l);
		T res = 0;
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if (r0 - l0 <= k) {
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
				k -= r0 - l0;
				res |= T(1) << pos;
			}
			else l = l0, r = r0;
		}
		return res;
	}

	// Returns the sum of the k smallest values of [l, r). Requires 0 <= k <= r - l.
	// NOTE(review): the previous comment described this as the product of the top k values
	// in descending order, but the traversal (same as quantile) selects the k smallest and
	// adds them; confirm the intended semantics with callers.
	// O(log M) time
	T top_k_fold(int l, int r, int k) const {
		assert(0 <= l && l <= r && r <= N);
		assert(0 <= k && k <= r - l);
		if (l == r) return 0;
		T res = 0, val = 0;  // val: value of the node reached so far (bits fixed by processed levels)
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if (r0 - l0 <= k) {
				// Every value with a 0 at this bit is among the k smallest: add them all.
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
				k -= r0 - l0;
				val |= T(1) << pos;
				res += sum[pos][r0] - sum[pos][l0];
			}
			else l = l0, r = r0;
		}
		// The remaining k selected elements all equal val.
		res += val * k;
		return res;
	}

	// Returns the largest value < x in [l, r), or T(-1) if there is none.
	// O(log M) time
	T pred(int l, int r, T x) const {
		assert(0 <= l && l <= r && r <= N);
		int f = range_freq(l, r, x);
		return f ? quantile(l, r, f - 1) : T(-1);
	}

	// Returns the smallest value >= x in [l, r), or T(-1) if there is none.
	// O(log M) time
	T succ(int l, int r, T x) const {
		assert(0 <= l && l <= r && r <= N);
		int f = range_freq(l, r, x);
		return f != r - l ? quantile(l, r, f) : T(-1);
	}
};
};
#line 2 "cpstl/ds/WaveletMatrix.hpp"

#line 2 "cpstl/other/Template.hpp"

#include <immintrin.h>
#include <algorithm>
#include <array>
#include <bit>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cfenv>
#include <charconv>
#include <chrono>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <ios>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <streambuf>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#line 2 "cpstl/math/StaticModint.hpp"

#line 4 "cpstl/math/StaticModint.hpp"

namespace cpstd {

// @brief Static Modint
//
// Integer modulo the compile-time constant m, stored as a residue in [0, m).

// @see https://hackmd.io/@tatyam-prime/rkVCOcwQn

template <uint32_t m>
struct StaticModint {
	private:
	using mint = StaticModint;
	uint32_t _v = 0;  // residue in [0, m)

	// Whether m is prime, by a Miller-Rabin test with bases 2, 7 and 61.
	// inv() uses it to choose Fermat inversion over the extended Euclidean algorithm.
	static constexpr bool is_prime = []() -> bool {
		if (m == 1) return false;
		if (m == 2 || m == 7 || m == 61) return true;
		if (!(m & 1)) return false;
		uint32_t d = m - 1;
		while (!(d & 1)) d >>= 1;
		for (uint32_t a : {2, 7, 61}) {
			uint32_t t = d;
			mint y = mint(a).pow(t);
			while (t != m - 1 && y != 1 && y != m - 1) {
				y *= y;
				t <<= 1;
			}
			if (y != m - 1 && !(t & 1)) return false;
		}
		return true;
	}();

	// Extended Euclid for 0 <= a < b: returns {g, x} with g = gcd(a, b), a * x ≡ g (mod b)
	// and 0 <= x < b / g. Uses 64-bit arithmetic because m may be as large as 2^32 - 1,
	// which does not fit in int32_t.
	static constexpr std::pair<int64_t, int64_t> inv_gcd(int64_t a, int64_t b) {
		if (a == 0) return {b, 0};
		int64_t s = b, t = a, m0 = 0, m1 = 1;
		while (t) {
			const int64_t q = s / t;
			const int64_t next_t = s - t * q;
			s = t;
			t = next_t;
			const int64_t next_m1 = m0 - m1 * q;
			m0 = m1;
			m1 = next_m1;
		}
		if (m0 < 0) m0 += b / s;
		return {s, m0};
	}

	public:
	constexpr StaticModint() {}
	// Converts any integer (negative values included) to its residue modulo m.
	template <typename T>
	constexpr StaticModint(T v) {
		static_assert(std::is_integral_v<T>, "T is not integral type.");
		if constexpr (std::is_signed_v<T>) {
			int64_t x = int64_t(v % int64_t(m));
			if (x < 0) x += m;
			_v = uint32_t(x);
		}
		else _v = uint32_t(v % m);
	}

	// Wraps v without reduction; the caller must guarantee v < m.
	static constexpr mint raw(uint32_t v) { mint a; a._v = v; return a; }

	static constexpr uint32_t mod() { return m; }

	constexpr uint32_t val() const { return _v; }

	constexpr mint& operator++() { return *this += 1; }

	constexpr mint operator++(int) { mint res = *this; ++*this; return res; }

	constexpr mint& operator--() { return *this -= 1; }

	constexpr mint operator--(int) { mint res = *this; --*this; return res; }

	constexpr mint& operator+=(mint rhs) {
		// Subtract m first (unsigned wrap-around) so the sum never overflows uint32_t.
		if (_v >= m - rhs._v) _v -= m;
		_v += rhs._v;
		return *this;
	}

	constexpr mint& operator-=(mint rhs) {
		if (_v < rhs._v) _v += m;
		_v -= rhs._v;
		return *this;
	}

	constexpr mint& operator*=(mint rhs) { return *this = *this * rhs; }

	constexpr mint& operator/=(mint rhs) { return *this *= rhs.inv(); }

	constexpr mint operator+() const { return *this; }

	constexpr mint operator-() const { return mint{} - *this; }

	// Returns this^n by binary exponentiation. Requires n >= 0.
	constexpr mint pow(long long n) const {
		assert(0 <= n);
		mint x = *this, r = 1;
		while (n > 0) {
			if (n & 1) r *= x;
			x *= x;
			n >>= 1;
		}
		return r;
	}

	// Returns the multiplicative inverse. Asserts that it exists (gcd(val, m) == 1).
	constexpr mint inv() const {
		if constexpr (is_prime) {
			assert(_v);
			return pow(m - 2);
		}
		else {
			const auto eg = inv_gcd(_v, m);
			assert(eg.first == 1);
			return eg.second;
		}
	}

	friend constexpr mint operator+(mint lhs, mint rhs) { return lhs += rhs; }

	friend constexpr mint operator-(mint lhs, mint rhs) { return lhs -= rhs; }

	friend constexpr mint operator*(mint lhs, mint rhs) { return uint64_t(lhs._v) * rhs._v; }
	
	friend constexpr mint operator/(mint lhs, mint rhs) { return lhs /= rhs; }

	friend constexpr bool operator==(mint lhs, mint rhs) { return lhs._v == rhs._v; }

	friend constexpr bool operator!=(mint lhs, mint rhs) { return lhs._v != rhs._v; }
};

using Modint998244353 = StaticModint<998244353>;

constexpr Modint998244353 operator""_M(unsigned long long x) { return x; }
};
#line 2 "cpstl/other/Fastio.hpp"

#line 4 "cpstl/other/Fastio.hpp"

namespace cpstd {

// @brief Fast I/O
// @see https://judge.yosupo.jp/submission/21623
// @see https://maspypy.com/library-checker-many-a-b
//
// Buffered stdin/stdout helpers. Input is parsed from ibuf[pil, pir); output is
// accumulated in obuf[0, por) and written by flush(), which also runs at program exit.

namespace Fastio {

static constexpr const uint32_t BUF_SIZE = 1 << 17;
char ibuf[BUF_SIZE], obuf[BUF_SIZE], out[100];  // out: scratch space for low digit groups in print_int
uint32_t pil = 0, pir = 0, por = 0;             // input read position, input end, output end

// Table of zero-padded 4-digit decimal strings: num[i] holds the digits of i (num[42] = "0042").
struct Pre {
	char num[10000][4];

	constexpr Pre() : num() {
		for (int i = 0; i < 10000; ++i) {
			int n = i;
			for (int j = 3; j >= 0; --j) {
				num[i][j] = n % 10 | '0';
				n /= 10;
			}
		}
	}
} constexpr pre;

// Moves the unread input [pil, pir) to the front of ibuf and refills the rest from stdin.
// If the buffer is not full afterwards, a '\n' sentinel is appended.
inline void load() {
	// The unread range may overlap its destination, so memmove is required (memcpy would be UB).
	std::memmove(ibuf, ibuf + pil, pir - pil);
	pir = pir - pil + std::fread(ibuf + pir - pil, 1, BUF_SIZE - pir + pil, stdin);
	pil = 0;
	if (pir < BUF_SIZE) ibuf[pir++] = '\n';
}

// Writes the pending output to stdout and empties the output buffer.
inline void flush() {
	std::fwrite(obuf, 1, por, stdout);
	por = 0;
}

// Reads the next non-whitespace character.
void _input(char &dest) {
	do {
		if (pil + 1 > pir) load();
		dest = ibuf[pil++];
	} while (std::isspace(static_cast<unsigned char>(dest)));
}

// Reads the next whitespace-delimited token.
void _input(std::string &dest) {
	dest.clear();
	char c;
	do {
		if (pil + 1 > pir) load();
		c = ibuf[pil++];
	} while (std::isspace(static_cast<unsigned char>(c)));
	do {
		dest += c;
		if (pil == pir) load();
		c = ibuf[pil++];
	} while (!std::isspace(static_cast<unsigned char>(c)));
}

void _input(float &dest) {
	std::string s;
	_input(s);
	dest = std::stof(s);
}

void _input(double &dest) {
	std::string s;
	_input(s);
	dest = std::stod(s);
}

void _input(long double &dest) {
	std::string s;
	_input(s);
	dest = std::stold(s);
}

// Parses an optionally signed decimal integer. Characters below '-' (whitespace and
// control characters) before the number are skipped. Only one refill is done, so the
// number is assumed to lie within the next 100 buffered bytes.
template <typename T>
void input_int(T &x) {
	if (pil + 100 > pir) load();
	char c;
	do {
		c = ibuf[pil++];
	} while (c < '-');
	bool minus = 0;
	if constexpr (std::is_signed<T>::value || std::is_same_v<T, __int128_t>) {
		if (c == '-') minus = 1, c = ibuf[pil++];
	}
	x = 0;
	while (c >= '0') x = x * 10 + (c & 15), c = ibuf[pil++];
	if constexpr (std::is_signed<T>::value || std::is_same_v<T, __int128_t>) {
		if (minus) x = -x;
	}
}

void _input(int &dest) { input_int(dest); }
void _input(unsigned int &dest) { input_int(dest); }
void _input(long &dest) { input_int(dest); }
void _input(unsigned long &dest) { input_int(dest); }
void _input(long long &dest) { input_int(dest); }
void _input(unsigned long long &dest) { input_int(dest); }
void _input(__int128 &dest) { input_int(dest); }
void _input(unsigned __int128 &dest) { input_int(dest); }

// Forward declarations so that compound overloads can call each other regardless of
// definition order (e.g. a vector inside a pair or a tuple).
template <uint32_t m> void _input(cpstd::StaticModint<m> &dest);
template <typename T, typename U> void _input(std::pair<T, U> &dest);
template <typename... T> void _input(std::tuple<T...> &dest);
template <std::size_t N, typename T> void _input(std::array<T, N> &dest);
template <typename T> void _input(std::vector<T> &dest);

template <uint32_t m>
void _input(cpstd::StaticModint<m> &dest) { long long a; _input(a); dest = a; }

template <typename T, typename U>
void _input(std::pair<T, U> &dest) {
	_input(dest.first);
	_input(dest.second);
}

// Reads tuple elements N, N + 1, ... in order.
template <std::size_t N = 0, typename T>
void input_tuple(T &t) {
	if constexpr (N < std::tuple_size<T>::value) {
		_input(std::get<N>(t));
		input_tuple<N + 1>(t);
	}
}

template <typename... T>
void _input(std::tuple<T...> &dest) { input_tuple(dest); }

template <std::size_t N, typename T>
void _input(std::array<T, N> &dest) { for (auto &e : dest) _input(e); }

template <typename T>
void _input(std::vector<T> &dest) { for (auto &e : dest) _input(e); }

void input() {}

// Reads into each argument in order.
template <typename H, typename... T>
void input(H &desth, T &... destt) { _input(desth), input(destt...); }

void _print(const char tg) {
	if (por == BUF_SIZE) flush();
	obuf[por++] = tg;
}

void _print(const std::string &tg) { for (char c : tg) _print(c); }

void _print(const char *tg) {
	std::size_t len = std::strlen(tg);
	for (std::size_t i = 0; i < len; ++i) _print(tg[i]);
}

// Unsigned counterpart of an integer type. The 128-bit types are spelled out because
// std::make_unsigned does not accept them in strict (non-GNU) language modes.
template <typename T> struct unsigned_of { using type = std::make_unsigned_t<T>; };
template <> struct unsigned_of<__int128> { using type = unsigned __int128; };
template <> struct unsigned_of<unsigned __int128> { using type = unsigned __int128; };

// Writes x in decimal. The magnitude is handled as an unsigned value so that the minimum
// signed value (whose negation overflows) is printed correctly.
template <typename T>
void print_int(T x) {
	using U = typename unsigned_of<T>::type;
	if (por > BUF_SIZE - 100) flush();
	U u = static_cast<U>(x);
	if constexpr (std::is_signed<T>::value || std::is_same_v<T, __int128_t>) {
		if (x < 0) {
			obuf[por++] = '-';
			u = U(0) - u;  // |x| in modular unsigned arithmetic: well-defined for every x
		}
	}
	// Low 4-digit groups are staged right-to-left in `out`, the leading group goes straight to obuf.
	int outi;
	for (outi = 96; u >= 10000; outi -= 4) {
		std::memcpy(out + outi, pre.num[u % 10000], 4);
		u /= 10000;
	}
	if (u >= 1000) {
		std::memcpy(obuf + por, pre.num[u], 4);
		por += 4;
	}
	else if (u >= 100) {
		std::memcpy(obuf + por, pre.num[u] + 1, 3);
		por += 3;
	}
	else if (u >= 10) {
		int q = (u * 103) >> 10;  // u / 10 for 10 <= u < 100
		obuf[por] = q | '0';
		obuf[por + 1] = (u - q * 10) | '0';
		por += 2;
	}
	else obuf[por++] = u | '0';
	std::memcpy(obuf + por, out + outi + 4, 96 - outi);
	por += 96 - outi;
}

// Writes a floating-point value with 15 digits after the decimal point. Streaming tg
// directly keeps long double precision (float/double output is unchanged).
template <typename T>
void print_real(T tg) {
	std::ostringstream oss;
	oss << std::fixed << std::setprecision(15) << tg;
	_print(oss.str());
}

void _print(int tg) { print_int(tg); }
void _print(unsigned int tg) { print_int(tg); }
void _print(long tg) { print_int(tg); }
void _print(unsigned long tg) { print_int(tg); }
void _print(long long tg) { print_int(tg); }
void _print(unsigned long long tg) { print_int(tg); }
void _print(__int128 tg) { print_int(tg); }
void _print(unsigned __int128 tg) { print_int(tg); }
void _print(float tg) { print_real(tg); }
void _print(double tg) { print_real(tg); }
void _print(long double tg) { print_real(tg); }

// Forward declarations so that nested containers resolve regardless of definition order.
template <uint32_t m> void _print(cpstd::StaticModint<m> tg);
template <typename T, typename U> void _print(const std::pair<T, U> &tg);
template <typename... T> void _print(const std::tuple<T...> &tg);
template <typename T, std::size_t N> void _print(const std::array<T, N> &tg);
template <typename T> void _print(const std::vector<T> &tg);

template <uint32_t m>
void _print(cpstd::StaticModint<m> tg) { print_int(tg.val()); }

template <typename T, typename U>
void _print(const std::pair<T, U> &tg) {
	_print(tg.first);
	_print(' ');
	_print(tg.second);
}

// Writes tuple elements N, N + 1, ... separated by spaces.
template <std::size_t N = 0, typename T>
void print_tuple(const T &tg) {
	if constexpr (N < std::tuple_size<T>::value) {
		if constexpr (N > 0) _print(' ');
		_print(std::get<N>(tg));
		print_tuple<N + 1>(tg);
	}
}

template <typename... T>
void _print(const std::tuple<T...> &tg) { print_tuple(tg); }

template <typename T, std::size_t N>
void _print(const std::array<T, N> &tg) {
	for (std::size_t i = 0; i < N; ++i) {
		if (i) _print(' ');
		_print(tg[i]);
	}
}

template <typename T>
void _print(const std::vector<T> &tg) {
	const std::size_t len = tg.size();
	for (std::size_t i = 0; i < len; ++i) {
		if (i) _print(' ');
		_print(tg[i]);
	}
}

void print() { _print('\n'); }

// Writes the arguments separated by spaces, followed by a newline.
template <typename H, typename... T>
void print(H &&tgh, T &&... tgt) {
	_print(tgh);
	if (sizeof...(tgt)) _print(' ');
	print(std::forward<T>(tgt)...);
}

// Flushes remaining output at program exit.
void __attribute__((destructor)) _d() { flush(); }

};

using Fastio::input;
using Fastio::print;
using Fastio::flush;

};
#line 2 "cpstl/ds/BitVector.hpp"

#line 4 "cpstl/ds/BitVector.hpp"

namespace cpstd {

// @brief Succinct Bit Vector
//
// Fixed-length bit sequence with O(1) rank queries. Bits are packed into 64-bit words;
// after build(), cnt[k] holds the number of 1 bits in words [0, k).

struct BitVector {
	private:
	using u32 = unsigned int;
	using u64 = unsigned long long;
	std::vector<u64> block;       // packed bits: bit i lives in block[i >> 6] at position i & 63
	std::vector<u32> cnt;         // cnt[k] = number of 1 bits in block[0 .. k), filled by build()
	u32 N = 0, zero = 0, sz = 0;  // length in bits; number of 0 bits in [0, N); number of words

	// Number of 1 bits in a 64-bit word. Uses std::popcount when <bit> provides it (C++20),
	// otherwise std::bitset::count.
	static u32 popcount64(u64 x) noexcept {
#if defined(__cpp_lib_bitops)
		return static_cast<u32>(std::popcount(x));
#else
		return static_cast<u32>(std::bitset<64>(x).count());
#endif
	}

	public:
	BitVector() {}
	// All-zero vector of _N bits.
	explicit BitVector(u32 _N) : N(_N), sz(((_N + 1) >> 6) + 1) {
		block.resize(sz);
		cnt.resize(sz);
	}

	// Sets bit i to the lowest bit of x (1 by default). Call build() before querying ranks.
	// O(1) time
	void set(u32 i, u64 x = 1ULL) {
		assert(i < N);
		const u64 mask = 1ULL << (i & 63);
		if (x & 1ULL) block[i >> 6] |= mask;
		else block[i >> 6] &= ~mask;
	}

	// Builds the per-word rank table and the total count of 0 bits.
	// O(N) time
	void build() {
		for (u32 k = 1; k < sz; ++k) cnt[k] = cnt[k - 1] + popcount64(block[k - 1]);
		zero = rank_0(N);
	}

	// Returns the number of 1 bits in [0, i).
	// O(1) time
	u32 rank_1(u32 i) const {
		assert(i <= N && (i >> 6) < sz);
		return cnt[i >> 6] + popcount64(block[i >> 6] & ((1ULL << (i & 63)) - 1ULL));
	}

	// Returns the number of 1 bits in [l, r).
	// O(1) time
	u32 rank_1(u32 l, u32 r) const {
		assert(l <= r);
		return rank_1(r) - rank_1(l);
	}

	// Returns the number of 0 bits in [0, i).
	// O(1) time
	u32 rank_0(u32 i) const { return i - rank_1(i); }

	// Returns the number of 0 bits in [l, r).
	// O(1) time
	u32 rank_0(u32 l, u32 r) const {
		assert(l <= r);
		return rank_0(r) - rank_0(l);
	}

	// Returns the number of 0 bits in [0, N) (valid after build()).
	// O(1) time
	u32 rank_0() const { return zero; }

	// Returns bit i.
	// O(1) time
	u32 get(u32 i) const {
		assert(i < N);
		return (u32)(block[i >> 6] >> (i & 63)) & 1;
	}

	// Returns bit i (no bounds check).
	// O(1) time
	u32 operator[](u32 i) const noexcept { return (u32)(block[i >> 6] >> (i & 63)) & 1; }
};

};
#line 5 "cpstl/ds/WaveletMatrix.hpp"

namespace cpstd {

// @brief Wavelet Matrix
// @see https://miti-7.hatenablog.com/entry/2018/04/28/152259
// @see https://takeda25.hatenablog.jp/entry/20130303/1362301095
// @see https://drken1215.hatenablog.com/entry/2023/10/19/220215
//
// Static structure over a sequence of non-negative integers that answers range
// counting / order-statistic queries in O(log M) time (M = maximum value).
// Level `pos` (processed from lgm - 1 down to 0) stably partitions the current
// order by bit `pos`: elements whose bit is 0 come first, then those whose bit is 1.
// After constructing from a size or calling set(), call build() before querying.

template <typename T>
class WaveletMatrix {
	private:
	static_assert(std::is_integral_v<T>, "template parameter T must be integral type");
	std::vector<cpstd::BitVector> bv;   // bv[pos][i]: bit `pos` of the i-th element in the order entering level pos
	std::vector<std::vector<T>> sum;    // sum[pos][i]: sum of the first i values in the order produced by level pos
	std::vector<T> dat;                 // raw values, indexed by original position
	int N = 0, lgm = 0;                 // sequence length; number of levels (bit width of the maximum)

	// True iff x >= 2^lgm, i.e. x has a set bit above every level and so exceeds every
	// stored value. Assumes x >= 0.
	bool exceeds_levels(T x) const noexcept {
		return lgm < std::numeric_limits<T>::digits && (x >> lgm) != 0;
	}

	public:
	WaveletMatrix() {}
	// N zero-valued elements; assign them with set() and then call build().
	explicit WaveletMatrix(int _N) : dat(_N, 0), N(_N) {}
	// Builds the structure over a copy of v (all values must be >= 0).
	explicit WaveletMatrix(const std::vector<T> &v) : dat(v), N((int)v.size()) { build(); }

	// [pos] <- x. Queries see the change only after the next build().
	// O(1) time
	void set(int pos, const T &x) {
		assert(0 <= pos && pos < N && x >= 0);
		dat[pos] = x;
	}

	// Builds the structure from the current values.
	// O(N log M) time (M = max [i])
	void build() {
		T maxi = 1;
		for (auto val : dat) {
			assert(val >= 0);
			maxi = std::max(maxi, val);
		}
		// lgm = bit width of maxi, computed in T itself so that 64-bit values are not truncated.
		lgm = 0;
		while (lgm < std::numeric_limits<T>::digits && (maxi >> lgm) != 0) ++lgm;
		std::vector<int> left(N), right(N), ord(N);
		std::iota(ord.begin(), ord.end(), 0);
		bv.assign(lgm, BitVector(N));
		sum.assign(lgm + 1, std::vector<T>(N + 1, 0));
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l = 0, r = 0;
			for (int i = 0; i < N; ++i) {
				if ((dat[ord[i]] >> pos) & 1) {
					bv[pos].set(i);
					right[r++] = ord[i];
				}
				else left[l++] = ord[i];
			}
			bv[pos].build();
			// Stable partition: elements with a 0 bit (collected in left) first, then the 1s.
			ord.swap(left);
			for (int i = 0; i < r; ++i) ord[i + l] = right[i];
			for (int i = 0; i < N; ++i) sum[pos][i + 1] = sum[pos][i] + dat[ord[i]];
		}
	}

	// Returns [pos].
	// O(log M) time
	T get(int pos) const {
		assert(0 <= pos && pos < N);
		return operator[](pos);
	}

	// Returns [pos] (no bounds check).
	// O(log M) time
	T operator[](int pos) const noexcept {
		T res = 0;
		for (int b = lgm - 1; b >= 0; --b) {
			int z = bv[b].rank_0(pos);
			if (bv[b].get(pos)) {
				pos += bv[b].rank_0() - z;
				res |= T(1) << b;
			}
			else pos = z;
		}
		return res;
	}

	// Returns the number of occurrences of x in [l, r).
	// O(log M) time
	int freq(int l, int r, T x) const {
		assert(0 <= l && l <= r && r <= N);
		// Bits of x above the top level are never inspected by the descent below, so values
		// that cannot be stored must be rejected explicitly.
		if (x < 0 || exceeds_levels(x)) return 0;
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if ((x >> pos) & 1) {
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
			}
			else l = l0, r = r0;
		}
		return r - l;
	}

	// Returns the number of values < ub in [l, r).
	// O(log M) time
	int range_freq(int l, int r, T ub) const {
		assert(0 <= l && l <= r && r <= N);
		if (ub <= 0) return 0;                  // every stored value is >= 0
		if (exceeds_levels(ub)) return r - l;   // every stored value is < 2^lgm <= ub
		int res = 0;
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if ((ub >> pos) & 1) {
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
				res += r0 - l0;
			}
			else l = l0, r = r0;
		}
		return res;
	}

	// Returns the number of values v with lb <= v < ub in [l, r) (0 when ub <= lb).
	// O(log M) time
	int range_freq(int l, int r, T lb, T ub) const {
		assert(0 <= l && l <= r && r <= N);
		if (ub <= lb) return 0;
		return range_freq(l, r, ub) - range_freq(l, r, lb);
	}

	// Returns the k-th smallest value (0-indexed) of [l, r). Requires 0 <= k < r - l.
	// O(log M) time
	T quantile(int l, int r, int k) const {
		assert(0 <= l && l <= r && r <= N);
		assert(0 <= k && k < r - l);
		T res = 0;
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if (r0 - l0 <= k) {
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
				k -= r0 - l0;
				res |= T(1) << pos;
			}
			else l = l0, r = r0;
		}
		return res;
	}

	// Returns the sum of the k smallest values of [l, r). Requires 0 <= k <= r - l.
	// NOTE(review): the previous comment described this as the product of the top k values
	// in descending order, but the traversal (same as quantile) selects the k smallest and
	// adds them; confirm the intended semantics with callers.
	// O(log M) time
	T top_k_fold(int l, int r, int k) const {
		assert(0 <= l && l <= r && r <= N);
		assert(0 <= k && k <= r - l);
		if (l == r) return 0;
		T res = 0, val = 0;  // val: value of the node reached so far (bits fixed by processed levels)
		for (int pos = lgm - 1; pos >= 0; --pos) {
			int l0 = bv[pos].rank_0(l), r0 = bv[pos].rank_0(r);
			if (r0 - l0 <= k) {
				// Every value with a 0 at this bit is among the k smallest: add them all.
				l += bv[pos].rank_0() - l0;
				r += bv[pos].rank_0() - r0;
				k -= r0 - l0;
				val |= T(1) << pos;
				res += sum[pos][r0] - sum[pos][l0];
			}
			else l = l0, r = r0;
		}
		// The remaining k selected elements all equal val.
		res += val * k;
		return res;
	}

	// Returns the largest value < x in [l, r), or T(-1) if there is none.
	// O(log M) time
	T pred(int l, int r, T x) const {
		assert(0 <= l && l <= r && r <= N);
		int f = range_freq(l, r, x);
		return f ? quantile(l, r, f - 1) : T(-1);
	}

	// Returns the smallest value >= x in [l, r), or T(-1) if there is none.
	// O(log M) time
	T succ(int l, int r, T x) const {
		assert(0 <= l && l <= r && r <= N);
		int f = range_freq(l, r, x);
		return f != r - l ? quantile(l, r, f) : T(-1);
	}
};
};
Back to top page