$\renewcommand{\O}{\mathrm{O}}$

簡潔データ構造の基本となる bit vector の実装コード(らしき何か).
時間計算量: rank $\O (1)$, select $\O( \log \log n)$
(※ 実装途中でやめちゃった...)
namespace bit_vector_detail {

// Smallest k >= 1 with 4^k >= n, i.e. ceil(log2(n) / 2) computed exactly with integers.
// (std::ceil / std::log2 are not constant expressions in standard C++; only GCC folds them.)
constexpr unsigned int ceil_half_log2(unsigned long long n) {
    unsigned int k = 1;
    while ((1ULL << (2 * k)) < n) k++;
    return k;
}

// Smallest d >= 2 with 2^(d*d) >= n, i.e. ceil(sqrt(log2(n))) clamped to at least 2 so that
// every level of the select tree is strictly smaller than the level below it.
constexpr unsigned int ceil_sqrt_log2(unsigned long long n) {
    unsigned int d = 2;
    while (d * d < 64 && (1ULL << (d * d)) < n) d++;
    return d;
}

// log2(x) for x >= 1, evaluated one binary digit at a time so it works in constant expressions.
constexpr double real_log2(double x) {
    double result = 0.0;
    while (x >= 2.0) {
        x /= 2.0;
        result += 1.0;
    }
    double digit = 0.5;
    for (int i = 0; i < 60; i++) {  // invariant here: 1 <= x < 2
        x *= x;
        if (x >= 2.0) {
            x /= 2.0;
            result += digit;
        }
        digit /= 2.0;
    }
    return result;
}

// sqrt(x) by Newton's iteration, usable in constant expressions.
constexpr double real_sqrt(double x) {
    if (x <= 0.0) return 0.0;
    double r = x < 1.0 ? 1.0 : x;
    for (int i = 0; i < 100; i++) r = 0.5 * (r + x / r);
    return r;
}

// ceil(v) for 0 <= v < 2^32.
constexpr unsigned int ceil_to_uint(double v) {
    const unsigned int t = static_cast<unsigned int>(v);
    return static_cast<double>(t) < v ? t + 1 : t;
}

}  // namespace bit_vector_detail

// Succinct bit vector over SIZE bits: rank in O(1), select in O(log log n).
// Usage: set() the one-bits, call build(), then query rank()/select().
// Bits set after build() are not visible to rank()/select() until build() is called again.
// The object owns raw heap arrays, so it is non-copyable.
template <int SIZE> class BitVector
{
private:
    static_assert(SIZE > 0, "BitVector needs at least one bit");
    // Select directory entry for one group of BLOCK_SIZE consecutive ones (the last group may
    // hold fewer). A group whose bit span is at most BOUNDARY_LENGTH bits is stored as a tree
    // over SMALL_BLOCK_SIZE-bit leaves; a longer span stores every one position explicitly.
    struct block {
        int ary_len;    // number of tree nodes; 0 when one_pos is used instead
        int start;      // first bit position of the group's span
        int *one_pos;   // explicit one positions (long spans only)
        // Tree arrays: first child index (-1 for leaves) and cumulative one count up to the
        // node's last leaf. Leaves come first, then each parent level; the root is last.
        short int *descendant_start_index, *subtree_count_sum;
        block() : ary_len(0), start(0), one_pos(nullptr), descendant_start_index(nullptr), subtree_count_sum(nullptr){}
        inline bool isSmall() const {
            return ary_len > 0;
        }
    };
    static constexpr unsigned int BIT_SIZE = SIZE;
    static constexpr unsigned int UNIT_SIZE = 32;
    // ceil(log2(n) / 2)
    static constexpr unsigned int SMALL_BLOCK_SIZE = bit_vector_detail::ceil_half_log2(BIT_SIZE);
    static constexpr unsigned int BLOCK_SIZE = 4*SMALL_BLOCK_SIZE*SMALL_BLOCK_SIZE;
    static constexpr unsigned int BIT_REST = BIT_SIZE%BLOCK_SIZE;
    static constexpr unsigned int BLOCK_COUNT = BIT_SIZE/BLOCK_SIZE+1;
    static constexpr unsigned int SMALL_BLOCK_COUNT = 4*SMALL_BLOCK_SIZE;
    static constexpr unsigned int REAL_BIT_SIZE = BLOCK_COUNT*BLOCK_SIZE;
    static constexpr unsigned int VECTOR_SIZE = REAL_BIT_SIZE/UNIT_SIZE+1;
    static constexpr unsigned int ALL_SMALL_BLOCK_COUNT = BLOCK_COUNT*SMALL_BLOCK_COUNT;
    // Spans longer than log2(n)^BOUNDARY_INDEX bits store explicit positions.
    static constexpr double BOUNDARY_INDEX = 3.5;
    static_assert(BOUNDARY_INDEX == 3.5, "BOUNDARY_LENGTH below is written as L^3 * sqrt(L)");
    static constexpr double LOG2_BIT_SIZE = bit_vector_detail::real_log2(BIT_SIZE);
    static constexpr unsigned int BOUNDARY_LENGTH = bit_vector_detail::ceil_to_uint(
        LOG2_BIT_SIZE * LOG2_BIT_SIZE * LOG2_BIT_SIZE * bit_vector_detail::real_sqrt(LOG2_BIT_SIZE));
    // Tree fan-out: ceil(sqrt(log2(n))), at least 2.
    static constexpr unsigned int DESCENDANT_COUNT = bit_vector_detail::ceil_sqrt_log2(BIT_SIZE);
    unsigned int* bit_vector;
    int* block_rank;
    short int* small_block_rank;
    char** lookup_table_rank;
    block* block_select;
    int block_select_count;
    char** lookup_table_select;
    int select_small_block_memory, select_large_block_memory;
public:
    void alloc_memory() {
        // Bits inside a word run right to left: bit i is (1u << (i % UNIT_SIZE)).
        bit_vector = new unsigned int[VECTOR_SIZE];
        std::cerr << "bit vector: " << VECTOR_SIZE*4 << "bytes\n";
        // Ideally each entry would be log(n) bits wide rather than a 32-bit int.
        block_rank = new int[BLOCK_COUNT+1];
        std::cerr << "block_rank: " << (BLOCK_COUNT+1)*4 << "bytes\n";
        // Ideally each entry would be log(log^2(n)) bits wide rather than a 16-bit short.
        // Every large block stores SMALL_BLOCK_COUNT+1 prefix counts (the first one is 0).
        small_block_rank = new short int[ALL_SMALL_BLOCK_COUNT+BLOCK_COUNT];
        std::cerr << "small_block_rank: " << (ALL_SMALL_BLOCK_COUNT+BLOCK_COUNT)*2 << "bytes\n";
        // Ideally each entry would be log((1/2)log(n)) bits wide rather than an 8-bit char.
        lookup_table_rank = new char*[1 << SMALL_BLOCK_SIZE];
        for(int i = 0; i < (1 << SMALL_BLOCK_SIZE); i++){
            lookup_table_rank[i] = new char[SMALL_BLOCK_SIZE+1];
        }
        std::cerr << "lookup_table_rank: " << (1 << SMALL_BLOCK_SIZE)*(SMALL_BLOCK_SIZE+1) << "bytes\n";
        // Ideally each entry would be log((1/2)log(n)) bits wide rather than an 8-bit char.
        lookup_table_select = new char*[1 << SMALL_BLOCK_SIZE];
        for(int i = 0; i < (1 << SMALL_BLOCK_SIZE); i++){
            lookup_table_select[i] = new char[SMALL_BLOCK_SIZE];
        }
        std::cerr << "lookup_table_select: " << (1 << SMALL_BLOCK_SIZE)*SMALL_BLOCK_SIZE << "bytes\n";
    }
    void init_() {
        memset(bit_vector, 0, sizeof(unsigned int)*VECTOR_SIZE);
    }
    BitVector() : block_select(nullptr), block_select_count(0), select_small_block_memory(0), select_large_block_memory(0){
        alloc_memory();
        init_();
    }
    ~BitVector() {
        release_select_blocks();
        delete[] bit_vector;
        delete[] block_rank;
        delete[] small_block_rank;
        for(int i = 0; i < (1 << SMALL_BLOCK_SIZE); i++){
            delete[] lookup_table_rank[i];
            delete[] lookup_table_select[i];
        }
        delete[] lookup_table_rank;
        delete[] lookup_table_select;
    }
    // Owns raw arrays: copying would double free them.
    BitVector(const BitVector&) = delete;
    BitVector& operator=(const BitVector&) = delete;
    // Sets bit `index` (0 <= index < SIZE).
    inline void set(int index) {
        bit_vector[index/UNIT_SIZE] |= (1u << index % UNIT_SIZE);
    }
    // Returns 1 if bit `index` is set, else 0.
    inline int isSet(int index) {
        return (bit_vector[index/UNIT_SIZE] >> index % UNIT_SIZE) & 1u;
    }
    // Returns bits [start, end) packed into the low bits of an int (bit `start` becomes bit 0).
    // Requires 0 <= start and end - start < UNIT_SIZE; an empty range yields 0.
    inline int bit_mask(int start, int end) {
        // Guard: for end == start, (end - 1) would be converted to a huge unsigned index below.
        if(end <= start) return 0;
        int left_id = start / UNIT_SIZE, right_id = (end - 1) / UNIT_SIZE;
        int left_pos = start % UNIT_SIZE, right_pos = (end - 1) % UNIT_SIZE + 1;
        return (left_id == right_id)
            ? ((bit_vector[left_id] >> left_pos) & ((1u << (right_pos-left_pos))-1u))
            : ((bit_vector[left_id] >> left_pos) | ((bit_vector[right_id] & ((1u << right_pos)-1u)) << (UNIT_SIZE - left_pos)));
    }
    // block_rank[i]: ones before large block i.
    // small_block_rank[i*(SMALL_BLOCK_COUNT+1) + j]: ones in the first j small blocks of large block i.
    void build_rank_block() {
        int bit_index = 0, small_block_index = 0;
        block_rank[0] = 0;
        for(int i = 0; i < static_cast<int>(BLOCK_COUNT); i++){
            block_rank[i+1] = block_rank[i];
            small_block_rank[small_block_index++] = 0;
            for(int j = 0; j < static_cast<int>(SMALL_BLOCK_COUNT); j++){
                small_block_rank[small_block_index] = small_block_rank[small_block_index-1];
                for(int k = 0; k < static_cast<int>(SMALL_BLOCK_SIZE); k++){
                    small_block_rank[small_block_index] += isSet(bit_index++);
                }
                small_block_index++;
            }
            block_rank[i+1] += small_block_rank[small_block_index-1];
        }
    }
    // lookup_table_rank[p][j]: number of ones among the lowest j bits of pattern p.
    void build_rank_lookup_table() {
        for(int i = 0; i < (1 << SMALL_BLOCK_SIZE); i++){
            lookup_table_rank[i][0] = 0;
            for(int j = 0; j < static_cast<int>(SMALL_BLOCK_SIZE); j++){
                lookup_table_rank[i][j+1] = lookup_table_rank[i][j] + ((i >> j) & 1);
            }
        }
    }
    // Builds the tree level above nodes [left, right). Node right+g becomes the parent of
    // children [left + g*D, min(left + (g+1)*D, right)). It stores its first child and the
    // cumulative count of its last child. Recurses until a single root remains.
    void build_logn_ary_rec(short int* st_count, short int* des_index, int left, int right) {
        if(right - left <= 1) return;
        const int fanout = static_cast<int>(DESCENDANT_COUNT);
        int parent = right;
        for(int first = left; first < right; first += fanout){
            const int last = (first + fanout < right) ? first + fanout : right;  // exclusive
            des_index[parent] = static_cast<short int>(first);
            st_count[parent] = st_count[last-1];
            parent++;
        }
        build_logn_ary_rec(st_count, des_index, right, parent);
    }
    // Builds the tree-shaped select entry `bl` for the bit span [left, right) (right > left).
    void build_select_logn_ary(block* bl, int left, int right) {
        const int leaf_width = static_cast<int>(SMALL_BLOCK_SIZE);
        const int fanout = static_cast<int>(DESCENDANT_COUNT);
        const int leaf_count = (right - left + leaf_width - 1) / leaf_width;
        // Each level holds ceil(previous / fanout) nodes, down to a single root.
        int node_len = leaf_count;
        for(int level_len = leaf_count; level_len > 1; ){
            level_len = (level_len + fanout - 1) / fanout;
            node_len += level_len;
        }
        bl->start = left;
        bl->ary_len = node_len;
        // Allocate directly into the block (the old code only updated local pointer copies).
        bl->subtree_count_sum = new short int[node_len];
        bl->descendant_start_index = new short int[node_len];
        short int* st_count = bl->subtree_count_sum;
        short int* des_index = bl->descendant_start_index;
        select_small_block_memory += 4 * node_len;
        memset(st_count, 0, sizeof(short int)*node_len);
        memset(des_index, -1, sizeof(short int)*node_len);  // -1 marks a leaf
        // Leaves: cumulative number of ones through the end of each leaf (last one may be short).
        int ones = 0;
        for(int i = 0; i < leaf_count; i++){
            const int leaf_start = left + i * leaf_width;
            const int leaf_end = (leaf_start + leaf_width < right) ? leaf_start + leaf_width : right;
            for(int b = leaf_start; b < leaf_end; b++) ones += isSet(b);
            st_count[i] = static_cast<short int>(ones);
        }
        build_logn_ary_rec(st_count, des_index, 0, leaf_count);
    }
    // Splits the ones into groups of BLOCK_SIZE (the last may be smaller) and builds one
    // select entry per group.
    void build_select_data_structure() {
        release_select_blocks();  // allows build() to run again without leaking
        block_select_count = block_rank[BLOCK_COUNT] / BLOCK_SIZE + 1;
        block_select = new block[block_select_count];
        std::cerr << "block_pointer: " << block_select_count*4 << "bytes\n";
        const int group_size = static_cast<int>(BLOCK_SIZE);
        int* temp_pos = new int[BLOCK_SIZE];
        select_large_block_memory += 4 * BLOCK_SIZE;
        int block_start = 0, bit_sum = 0, block_index = 0;
        for(int i = 0; i < static_cast<int>(BIT_SIZE); i++){
            if(!isSet(i)) continue;
            temp_pos[bit_sum++] = i;
            if(bit_sum == group_size){
                if(store_select_block(&block_select[block_index], temp_pos, block_start, i+1)){
                    temp_pos = new int[BLOCK_SIZE];  // previous buffer now owned by the block
                    select_large_block_memory += 4 * BLOCK_SIZE;
                }
                block_index++;
                block_start = i + 1;
                bit_sum = 0;
            }
        }
        // Trailing group with fewer than BLOCK_SIZE ones; it ends at its last one.
        if(bit_sum > 0 && store_select_block(&block_select[block_index], temp_pos, block_start, temp_pos[bit_sum-1]+1)){
            temp_pos = nullptr;
        }
        delete[] temp_pos;
    }
    // lookup_table_select[p][k]: bit offset of the k-th (0-indexed) one in pattern p.
    void build_select_lookup_table() {
        for(int i = 0; i < (1 << SMALL_BLOCK_SIZE); i++){
            int num = 0;
            for(int j = 0; j < static_cast<int>(SMALL_BLOCK_SIZE); j++){
                if((i >> j)&1) lookup_table_select[i][num++] = j;
            }
        }
    }
    void build() {
        build_rank_block();
        build_rank_lookup_table();
        build_select_data_structure();
        build_select_lookup_table();
        std::cerr << "SELECT_SMALL_BLOCK: " << select_small_block_memory << "bytes\n";
        std::cerr << "SELECT_LARGE_BLOCK: " << select_large_block_memory << "bytes\n";
    }
    // Number of ones in [0, x), for 0 <= x <= SIZE. Requires build().
    inline int rank(int x) {
        const int small_start = x - x % SMALL_BLOCK_SIZE;
        // Read the whole small block; the lookup table counts only its lowest x % SMALL_BLOCK_SIZE bits.
        return block_rank[x / BLOCK_SIZE]
            + small_block_rank[x / BLOCK_SIZE * (SMALL_BLOCK_COUNT+1) + (x % BLOCK_SIZE) / SMALL_BLOCK_SIZE]
            + lookup_table_rank[bit_mask(small_start, small_start + SMALL_BLOCK_SIZE)][x % SMALL_BLOCK_SIZE];
    }
    // Position of the target_pos-th (1-indexed) one inside tree entry `bl`, whose span starts
    // at bit `start`, searching from node `cur_index` (the root is bl->ary_len - 1).
    int select_search(block* bl, const int start, const int target_pos, int cur_index) {
        const short int* st_count = bl->subtree_count_sum;
        const short int* des_index = bl->descendant_start_index;
        while(des_index[cur_index] >= 0){
            // Children of a node end where the next node's children begin. The root's children
            // end at the root itself, since its level starts right after theirs.
            int low = des_index[cur_index];
            int high = ((cur_index + 1 < bl->ary_len) ? des_index[cur_index + 1] : cur_index) - 1;
            // First child whose cumulative count reaches target_pos (the last child always does).
            while(low < high){
                const int mid = (low + high) / 2;
                if(st_count[mid] < target_pos){
                    low = mid + 1;
                }else{
                    high = mid;
                }
            }
            cur_index = low;
        }
        const int before = (cur_index > 0) ? st_count[cur_index - 1] : 0;  // ones in earlier leaves
        const int leaf_start = start + cur_index * static_cast<int>(SMALL_BLOCK_SIZE);
        int leaf_end = leaf_start + static_cast<int>(SMALL_BLOCK_SIZE);
        if(leaf_end > static_cast<int>(BIT_SIZE)) leaf_end = static_cast<int>(BIT_SIZE);
        return leaf_start + lookup_table_select[bit_mask(leaf_start, leaf_end)][target_pos - before - 1];
    }
    // Position of the x-th one (0-indexed). Requires build() and 0 <= x < number of ones.
    inline int select(int x) {
        block* bl = &block_select[x / BLOCK_SIZE];
        const int rank_in_group = x % BLOCK_SIZE;
        return bl->isSmall()
            ? select_search(bl, bl->start, rank_in_group + 1, bl->ary_len - 1)
            : bl->one_pos[rank_in_group];
    }
private:
    // Fills entry `bl` for a group with bit span [left, right) and one positions `positions`.
    // Short spans get a tree; long spans keep `positions`. Returns true if `positions` was taken.
    bool store_select_block(block* bl, int* positions, int left, int right) {
        bl->start = left;
        if(right - left <= static_cast<int>(BOUNDARY_LENGTH)){
            build_select_logn_ary(bl, left, right);
            return false;
        }
        bl->one_pos = positions;
        return true;
    }
    // Frees the select directory (no-op if it was never built).
    void release_select_blocks() {
        for(int i = 0; i < block_select_count; i++){
            delete[] block_select[i].one_pos;
            delete[] block_select[i].descendant_start_index;
            delete[] block_select[i].subtree_count_sum;
        }
        delete[] block_select;
        block_select = nullptr;
        block_select_count = 0;
    }
};
verify していません(verify 問題を知らない)