$\newcommand{\O}{\mathrm{O}}$
簡潔データ構造の基本となる bit vector の実装コード(らしき何か).
時間計算量: rank $\O(1)$, select $\O(\log \log n)$
(※ 実装途中でやめちゃった...)
#include <algorithm>
#include <iostream>
#include <vector>

namespace bit_vector_detail {

// ceil(log2(n)) for n >= 1 (0 for n <= 1); usable in constant expressions.
constexpr unsigned int ceil_log2(unsigned long long n, unsigned int k = 0) {
    return (k >= 63 || (1ULL << k) >= n) ? k : ceil_log2(n, k + 1);
}

// Smallest r in [lo, hi] with r * r >= v, i.e. ceil(sqrt(v)) for v <= 2^62.
constexpr unsigned long long ceil_sqrt(unsigned long long v,
                                       unsigned long long lo = 0,
                                       unsigned long long hi = 1ULL << 31) {
    return lo >= hi ? lo
         : ((lo + (hi - lo) / 2) * (lo + (hi - lo) / 2) >= v
                ? ceil_sqrt(v, lo, lo + (hi - lo) / 2)
                : ceil_sqrt(v, lo + (hi - lo) / 2 + 1, hi));
}

// b to the power e (no overflow check; only used with small arguments).
constexpr unsigned long long ipow(unsigned long long b, unsigned int e) {
    return e == 0 ? 1ULL : b * ipow(b, e - 1);
}

// max(v, lo) as unsigned int.
constexpr unsigned int at_least(unsigned long long v, unsigned int lo) {
    return v < lo ? lo : static_cast<unsigned int>(v);
}

}  // namespace bit_vector_detail

/// Succinct bit vector of SIZE bits with rank/select support.
///
/// Usage: call set() for every 1-bit, then build(); afterwards
///   rank(x)   = number of 1s in [0, x)            (0 <= x <= SIZE)
///   select(k) = position of the k-th 1, 0-indexed (0 <= k < number of 1s)
/// Calling set() after build() invalidates the index until build() runs again.
///
/// Rank uses the classic two-level directory (blocks of BLOCK_SIZE bits split
/// into small blocks of SMALL_BLOCK_SIZE bits) plus a popcount lookup table.
/// Select groups the 1s into runs of BLOCK_SIZE ones: a run spanning more than
/// BOUNDARY_LENGTH bits stores every position explicitly, a shorter run gets a
/// tree of fan-out DESCENDANT_COUNT over its small blocks that is searched with
/// a binary search per level, finishing with a select lookup table.
/// Memory usage is reported on std::cerr for analysis.
template <int SIZE> class BitVector {
    static_assert(SIZE > 0, "BitVector needs at least one bit");

private:
    // Select directory for one run of (up to) BLOCK_SIZE consecutive 1s.
    struct block {
        int ary_len;  // node count of the search tree; 0 => positions stored in one_pos
        int start;    // first bit position covered by this run
        std::vector<int> one_pos;                     // explicit positions (long runs)
        std::vector<short int> descendant_start_index;  // first child of each node, -1 for leaves
        std::vector<short int> subtree_count_sum;     // cumulative 1-count through the node's last leaf
        block() : ary_len(0), start(0) {}
        inline bool isSmall() const { return ary_len != 0; }
    };

    static constexpr unsigned int BIT_SIZE = SIZE;
    // ceil(log2(BIT_SIZE)); all size parameters are derived from it with integer
    // arithmetic so that they are portable constant expressions.
    static constexpr unsigned int LOG_N = bit_vector_detail::ceil_log2(SIZE);
    static constexpr unsigned int UNIT_SIZE = 32;
    // ceil(log2(n) / 2), at least 1.
    static constexpr unsigned int SMALL_BLOCK_SIZE = bit_vector_detail::at_least((LOG_N + 1) / 2, 1);
    static constexpr unsigned int BLOCK_SIZE = 4 * SMALL_BLOCK_SIZE * SMALL_BLOCK_SIZE;
    static constexpr unsigned int BIT_REST = BIT_SIZE % BLOCK_SIZE;
    static constexpr unsigned int BLOCK_COUNT = BIT_SIZE / BLOCK_SIZE + 1;
    static constexpr unsigned int SMALL_BLOCK_COUNT = 4 * SMALL_BLOCK_SIZE;  // small blocks per block
    static constexpr unsigned int REAL_BIT_SIZE = BLOCK_COUNT * BLOCK_SIZE;
    static constexpr unsigned int VECTOR_SIZE = REAL_BIT_SIZE / UNIT_SIZE + 1;
    static constexpr unsigned int ALL_SMALL_BLOCK_COUNT = BLOCK_COUNT * SMALL_BLOCK_COUNT;
    // Runs of 1s spanning more than log(n)^BOUNDARY_INDEX bits store explicit
    // positions. BOUNDARY_LENGTH = ceil(sqrt(LOG_N^7)) = ceil(LOG_N^3.5).
    static constexpr double BOUNDARY_INDEX = 3.5;
    static constexpr unsigned int BOUNDARY_LENGTH =
        static_cast<unsigned int>(bit_vector_detail::ceil_sqrt(bit_vector_detail::ipow(LOG_N, 7)));
    // ceil(sqrt(log2(n))), at least 2 so the select tree always shrinks per level.
    static constexpr unsigned int DESCENDANT_COUNT =
        bit_vector_detail::at_least(bit_vector_detail::ceil_sqrt(LOG_N), 2);

    // Tree node indices and per-run counts are stored as short int.
    static_assert(BLOCK_SIZE < 32768, "per-block counts must fit in short int");
    static_assert(3 * (BOUNDARY_LENGTH / SMALL_BLOCK_SIZE + 1) < 32768,
                  "select tree indices must fit in short int");

    std::vector<unsigned int> bit_vector;
    std::vector<int> block_rank;
    std::vector<short int> small_block_rank;
    std::vector<std::vector<char>> lookup_table_rank;
    std::vector<block> block_select;
    std::vector<std::vector<char>> lookup_table_select;
    int select_small_block_memory, select_large_block_memory;

public:
    /// (Re)allocates every directory and table and logs its size to std::cerr.
    void alloc_memory() {
        // Bits are ordered right-to-left: bit i is bit (i % 32) of word i / 32.
        bit_vector.assign(VECTOR_SIZE, 0u);
        std::cerr << "bit vector: " << VECTOR_SIZE * 4 << "bytes\n";
        // Theoretically log(n) bits per entry suffice; int is used for simplicity.
        block_rank.assign(BLOCK_COUNT + 1, 0);
        std::cerr << "block_rank: " << (BLOCK_COUNT + 1) * 4 << "bytes\n";
        // Theoretically log(log^2(n)) bits per entry suffice; short (2 bytes) is used.
        small_block_rank.assign(ALL_SMALL_BLOCK_COUNT + BLOCK_COUNT, static_cast<short int>(0));
        std::cerr << "small_block_rank: " << (ALL_SMALL_BLOCK_COUNT + BLOCK_COUNT) * 2 << "bytes\n";
        // Theoretically log(log(n) / 2) bits per entry suffice; char (8 bits) is used.
        const int patterns = 1 << SMALL_BLOCK_SIZE;
        lookup_table_rank.assign(patterns, std::vector<char>(SMALL_BLOCK_SIZE + 1, char(0)));
        std::cerr << "lookup_table_rank: " << patterns * (SMALL_BLOCK_SIZE + 1) << "bytes\n";
        lookup_table_select.assign(patterns, std::vector<char>(SMALL_BLOCK_SIZE, char(0)));
        std::cerr << "lookup_table_select: " << patterns * SMALL_BLOCK_SIZE << "bytes\n";
    }

    /// Clears every bit.
    void init_() { std::fill(bit_vector.begin(), bit_vector.end(), 0u); }

    BitVector() : select_small_block_memory(0), select_large_block_memory(0) {
        alloc_memory();
        init_();
    }

    /// Sets bit `index` to 1 (0 <= index < SIZE).
    inline void set(int index) { bit_vector[index / UNIT_SIZE] |= (1u << index % UNIT_SIZE); }

    /// Returns bit `index` (0 or 1).
    inline int isSet(int index) { return (bit_vector[index / UNIT_SIZE] >> index % UNIT_SIZE) & 1u; }

    /// Returns bits [start, end) right-aligned (bit `start` becomes bit 0).
    /// Requires end - start <= UNIT_SIZE; an empty range yields 0.
    inline int bit_mask(int start, int end) {
        if (end <= start) return 0;
        const int length = end - start;
        const int left_id = start / static_cast<int>(UNIT_SIZE);
        const int left_pos = start % static_cast<int>(UNIT_SIZE);
        // Read two adjacent words into a 64-bit window so no shift reaches 32.
        unsigned long long window = bit_vector[left_id];
        if (left_id + 1 < static_cast<int>(VECTOR_SIZE)) {
            window |= static_cast<unsigned long long>(bit_vector[left_id + 1]) << UNIT_SIZE;
        }
        window >>= left_pos;
        const unsigned long long mask = (length >= 64) ? ~0ULL : ((1ULL << length) - 1ULL);
        return static_cast<int>(window & mask);
    }

    /// Fills block_rank (1s before each block) and small_block_rank (1s before
    /// each small block, relative to its block; SMALL_BLOCK_COUNT+1 entries per block).
    void build_rank_block() {
        unsigned int bit_index = 0;
        int small_block_index = 0;
        block_rank[0] = 0;
        for (unsigned int i = 0; i < BLOCK_COUNT; ++i) {
            small_block_rank[small_block_index++] = 0;
            for (unsigned int j = 0; j < SMALL_BLOCK_COUNT; ++j) {
                short int count = small_block_rank[small_block_index - 1];
                for (unsigned int k = 0; k < SMALL_BLOCK_SIZE; ++k) count += isSet(bit_index++);
                small_block_rank[small_block_index++] = count;
            }
            block_rank[i + 1] = block_rank[i] + small_block_rank[small_block_index - 1];
        }
    }

    /// lookup_table_rank[p][j] = number of 1s among the lowest j bits of pattern p.
    void build_rank_lookup_table() {
        const int patterns = 1 << SMALL_BLOCK_SIZE;
        for (int i = 0; i < patterns; ++i) {
            lookup_table_rank[i][0] = 0;
            for (int j = 0; j < static_cast<int>(SMALL_BLOCK_SIZE); ++j) {
                lookup_table_rank[i][j + 1] = static_cast<char>(lookup_table_rank[i][j] + ((i >> j) & 1));
            }
        }
    }

    /// Builds all internal levels of a select tree whose lowest level occupies
    /// node indices [left, right). Each level is appended right after the one
    /// below it; a parent covers DESCENDANT_COUNT consecutive children (the last
    /// parent possibly fewer) and stores the cumulative count of its last child.
    void build_logn_ary_rec(short int* st_count, short int* des_index, int left, int right) {
        const int fan_out = DESCENDANT_COUNT;
        while (right - left > 1) {
            int parent = right;
            for (int child = left; child < right; child += fan_out) {
                const int last_child = (child + fan_out < right ? child + fan_out : right) - 1;
                des_index[parent] = static_cast<short int>(child);
                st_count[parent] = st_count[last_child];
                ++parent;
            }
            left = right;
            right = parent;
        }
    }

    /// Builds the select tree of `bl` for the bit range [left, right).
    /// Leaves are the small blocks of that range (aligned to `left`) holding
    /// cumulative 1-counts; internal levels are added by build_logn_ary_rec.
    void build_select_logn_ary(block* bl, int left, int right) {
        const int small_size = SMALL_BLOCK_SIZE;
        const int fan_out = DESCENDANT_COUNT;
        int leaf_count = (right - left + small_size - 1) / small_size;
        if (leaf_count < 1) leaf_count = 1;
        int node_len = leaf_count;
        for (int level = leaf_count; level > 1;) {
            level = (level + fan_out - 1) / fan_out;  // ceil division always shrinks
            node_len += level;
        }
        bl->ary_len = node_len;
        bl->start = left;
        bl->subtree_count_sum.assign(node_len, static_cast<short int>(0));
        bl->descendant_start_index.assign(node_len, static_cast<short int>(-1));
        select_small_block_memory += 4 * node_len;  // two shorts per node

        short int ones = 0;
        for (int i = 0; i < leaf_count; ++i) {
            const int leaf_begin = left + i * small_size;
            const int leaf_end = (leaf_begin + small_size < right) ? leaf_begin + small_size : right;
            for (int bit = leaf_begin; bit < leaf_end; ++bit) ones += isSet(bit);
            bl->subtree_count_sum[i] = ones;
        }
        build_logn_ary_rec(bl->subtree_count_sum.data(), bl->descendant_start_index.data(), 0, leaf_count);
    }

    /// Splits the 1s into runs of BLOCK_SIZE ones (the last run may be shorter)
    /// and builds a select directory for each run.
    void build_select_data_structure() {
        const int block_select_count = block_rank[BLOCK_COUNT] / BLOCK_SIZE + 1;
        block_select.assign(block_select_count, block());
        std::cerr << "block_pointer: " << block_select_count * 4 << "bytes\n";

        std::vector<int> positions;
        positions.reserve(BLOCK_SIZE);
        int group_start = 0, block_index = 0;
        for (int i = 0; i < static_cast<int>(BIT_SIZE); ++i) {
            if (!isSet(i)) continue;
            positions.push_back(i);
            if (positions.size() == BLOCK_SIZE) {
                store_select_group(block_select[block_index++], group_start, i + 1, positions);
                group_start = i + 1;
                positions.clear();
            }
        }
        // The trailing run with fewer than BLOCK_SIZE ones must be indexed too.
        if (!positions.empty()) {
            store_select_group(block_select[block_index], group_start, positions.back() + 1, positions);
        }
    }

    /// lookup_table_select[p][k] = bit offset of the k-th (0-indexed) 1 in pattern p.
    void build_select_lookup_table() {
        const int patterns = 1 << SMALL_BLOCK_SIZE;
        for (int i = 0; i < patterns; ++i) {
            int num = 0;
            for (int j = 0; j < static_cast<int>(SMALL_BLOCK_SIZE); ++j) {
                if ((i >> j) & 1) lookup_table_select[i][num++] = static_cast<char>(j);
            }
        }
    }

    /// Builds all rank/select structures; call after the last set().
    void build() {
        build_rank_block();
        build_rank_lookup_table();
        build_select_data_structure();
        build_select_lookup_table();
        std::cerr << "SELECT_SMALL_BLOCK: " << select_small_block_memory << "bytes\n";
        std::cerr << "SELECT_LARGE_BLOCK: " << select_large_block_memory << "bytes\n";
    }

    /// Number of 1s in [0, x), for 0 <= x <= SIZE.
    inline int rank(int x) {
        const int block_size = BLOCK_SIZE;
        const int small_size = SMALL_BLOCK_SIZE;
        const int block_id = x / block_size;
        const int offset = x % block_size;     // position inside the block
        const int rest = offset % small_size;  // position inside the small block
        // Skip the table for an empty tail: bit_mask on [x, x) contributes nothing.
        const int partial = rest ? lookup_table_rank[bit_mask(x - rest, x)][rest] : 0;
        return block_rank[block_id]
             + small_block_rank[block_id * (SMALL_BLOCK_COUNT + 1) + offset / small_size]
             + partial;
    }

    /// Descends the select tree of `bl` from node `cur_index` (the root is
    /// bl->ary_len - 1) to the leaf holding the `target_pos`-th (1-indexed) 1 of
    /// the run, whose first bit is `start`, and returns that 1's bit position.
    int select_search(block* bl, const int start, const int target_pos, int cur_index) {
        const short int* st_count = bl->subtree_count_sum.data();
        const short int* des_index = bl->descendant_start_index.data();
        while (des_index[cur_index] >= 0) {
            // Children of a node end where the next node's children begin;
            // the root's children end right before the root itself.
            int low = des_index[cur_index];
            int high = (cur_index + 1 < bl->ary_len) ? des_index[cur_index + 1] - 1 : cur_index - 1;
            // First child whose cumulative count reaches target_pos.
            while (low < high) {
                const int mid = low + (high - low) / 2;
                if (st_count[mid] < target_pos) {
                    low = mid + 1;
                } else {
                    high = mid;
                }
            }
            cur_index = low;
        }
        const int small_size = SMALL_BLOCK_SIZE;
        const int before = (cur_index == 0) ? 0 : st_count[cur_index - 1];
        const int leaf_start = start + cur_index * small_size;
        return leaf_start
             + lookup_table_select[bit_mask(leaf_start, leaf_start + small_size)][target_pos - before - 1];
    }

    /// Position of the x-th 1 (0-indexed), for 0 <= x < number of 1s.
    inline int select(int x) {
        block& bl = block_select[x / BLOCK_SIZE];
        const int pos_in_group = x % BLOCK_SIZE;
        return bl.isSmall() ? select_search(&bl, bl.start, pos_in_group + 1, bl.ary_len - 1)
                            : bl.one_pos[pos_in_group];
    }

private:
    // Stores one run of 1s covering [left, right): a tree for short runs,
    // explicit positions for runs longer than BOUNDARY_LENGTH bits.
    void store_select_group(block& bl, int left, int right, const std::vector<int>& positions) {
        bl.start = left;
        if (static_cast<unsigned int>(right - left) <= BOUNDARY_LENGTH) {
            build_select_logn_ary(&bl, left, right);
        } else {
            bl.one_pos = positions;
            select_large_block_memory += 4 * BLOCK_SIZE;
        }
    }
};
verify していません(verify 問題を知らない)