$\newcommand{\O}{\mathrm{O}}$ My Algorithm : kopricky アルゴリズムライブラリ

kopricky アルゴリズムライブラリ

Bit Vector

コードについての説明(個人的メモ)

簡潔データ構造の基本となる bit vector の実装コード(らしき何か).

時間計算量: rank $\O (1)$, select $\O( \log \log n)$
(※ 実装途中でやめちゃった...)

コード

#include <algorithm>
#include <iostream>
#include <vector>

// Compile-time integer helpers used to size BitVector. They replace the
// floating-point ceil/log2/pow/sqrt calls, which are not constexpr in standard
// C++ (only GCC folds them as an extension).
namespace bit_vector_detail {

// Smallest k >= 0 with 2^k >= n, i.e. ceil(log2(n)) for n >= 1.
constexpr unsigned int ceil_log2(unsigned long long n, unsigned int k = 0) {
    return (1ull << k) >= n ? k : ceil_log2(n, k + 1);
}

// Smallest m >= 1 with 4^m >= n, i.e. max(1, ceil(log2(n) / 2)).
constexpr unsigned int ceil_half_log2(unsigned long long n, unsigned int m = 1) {
    return (1ull << (2 * m)) >= n ? m : ceil_half_log2(n, m + 1);
}

// Smallest d >= 2 with 2^(d*d) >= n, i.e. max(2, ceil(sqrt(log2(n)))).
// d >= 2 guarantees that every level of the select tree is smaller than the one below it.
constexpr unsigned int ceil_sqrt_log2(unsigned long long n, unsigned int d = 2) {
    return (1ull << (d * d)) >= n ? d : ceil_sqrt_log2(n, d + 1);
}

// Smallest r in [lo, hi] with r * r >= v (requires hi * hi >= v).
constexpr unsigned long long ceil_sqrt(unsigned long long v, unsigned long long lo = 0,
                                       unsigned long long hi = 1ull << 31) {
    return lo >= hi ? lo
         : ((lo + hi) / 2) * ((lo + hi) / 2) >= v ? ceil_sqrt(v, lo, (lo + hi) / 2)
                                                  : ceil_sqrt(v, (lo + hi) / 2 + 1, hi);
}

// ceil(x^3.5) for a small non-negative integer x (x^7 must fit in 64 bits).
constexpr unsigned long long ceil_pow_3_5(unsigned long long x) {
    return ceil_sqrt(x * x * x * x * x * x * x);
}

}  // namespace bit_vector_detail

// Succinct bit vector of SIZE bits.
//   rank(x)  : number of 1-bits in [0, x)            -- O(1)
//   select(x): position of the x-th 1-bit (0-indexed) -- O(log log n)
// Usage: call set(i) for every 1-bit (0 <= i < SIZE), then build() once.
// rank/select are valid only after build() and must be rebuilt if bits change.
// Memory usage of each table is reported on std::cerr.
template <int SIZE> class BitVector
{
    static_assert(SIZE >= 1, "BitVector needs at least one bit");

private:
    // A select block: the run of bits that holds BLOCK_SIZE consecutive
    // 1-bits (the final block may hold fewer).
    struct block {
        // Number of tree nodes; 0 means the positions are stored explicitly.
        int ary_len = 0;
        // Covered bit range [bit_begin, bit_end); bit_end is one past the block's last 1-bit.
        int bit_begin = 0, bit_end = 0;
        // Sparse block (long span): absolute position of every 1-bit.
        std::vector<int> one_pos;
        // Dense block (short span): a DESCENDANT_COUNT-ary tree over
        // SMALL_BLOCK_SIZE-bit leaves, stored level by level (leaves first, root last).
        //   descendant_start_index[v]: index of v's first child, -1 for a leaf.
        //   subtree_count_sum[v]: 1-bits from bit_begin through the end of v's last leaf.
        std::vector<short int> descendant_start_index, subtree_count_sum;
        inline bool isSmall() const {
            return ary_len != 0;
        }
    };

    static constexpr unsigned int BIT_SIZE = SIZE;
    // Bits per storage word.
    static constexpr unsigned int UNIT_SIZE = 32;
    // Bits per small block (granularity of the rank lookup table).
    static constexpr unsigned int SMALL_BLOCK_SIZE = bit_vector_detail::ceil_half_log2(BIT_SIZE);
    // Bits per rank block, and also 1-bits per select block.
    static constexpr unsigned int BLOCK_SIZE = 4*SMALL_BLOCK_SIZE*SMALL_BLOCK_SIZE;
    static constexpr unsigned int BIT_REST = BIT_SIZE%BLOCK_SIZE;
    static constexpr unsigned int BLOCK_COUNT = BIT_SIZE/BLOCK_SIZE+1;
    static constexpr unsigned int SMALL_BLOCK_COUNT = 4*SMALL_BLOCK_SIZE;
    static constexpr unsigned int REAL_BIT_SIZE = BLOCK_COUNT*BLOCK_SIZE;
    static constexpr unsigned int VECTOR_SIZE = REAL_BIT_SIZE/UNIT_SIZE+1;
    static constexpr unsigned int ALL_SMALL_BLOCK_COUNT = BLOCK_COUNT*SMALL_BLOCK_COUNT;
    // A select block spanning more than ceil(L^3.5) bits stores its positions
    // explicitly; shorter ones get a search tree. L is ceil(log2(n)) (an
    // integer, so this stays a portable constant expression); for
    // non-power-of-two SIZE the threshold is slightly larger than with a real log2.
    static constexpr unsigned int BOUNDARY_LENGTH = static_cast<unsigned int>(
        bit_vector_detail::ceil_pow_3_5(bit_vector_detail::ceil_log2(BIT_SIZE)));
    // Fan-out of the select search tree: max(2, ceil(sqrt(log2(n)))).
    static constexpr unsigned int DESCENDANT_COUNT = bit_vector_detail::ceil_sqrt_log2(BIT_SIZE);

    static_assert(BLOCK_SIZE <= 32767, "per-block counts must fit in short int");
    static_assert(2 * (BOUNDARY_LENGTH / SMALL_BLOCK_SIZE + 1) + 32 <= 32767,
                  "select tree node indices must fit in short int");

    // Raw bits, LSB first: bit i is bit (i % UNIT_SIZE) of word i / UNIT_SIZE.
    std::vector<unsigned int> bit_vector;
    // block_rank[b] = number of 1-bits in [0, b * BLOCK_SIZE).
    // (Theoretically log(n) bits wide instead of 32.)
    std::vector<int> block_rank;
    // SMALL_BLOCK_COUNT + 1 entries per rank block; entry j holds the 1-bits
    // in the block's first j small blocks. (Theoretically log(log^2 n) bits wide.)
    std::vector<short int> small_block_rank;
    // Flattened [1 << SMALL_BLOCK_SIZE][SMALL_BLOCK_SIZE + 1]: 1-bits among the lowest len bits.
    std::vector<unsigned char> lookup_table_rank;
    std::vector<block> block_select;
    // Flattened [1 << SMALL_BLOCK_SIZE][SMALL_BLOCK_SIZE]: offset of the k-th 1-bit of a pattern.
    std::vector<unsigned char> lookup_table_select;
    int select_small_block_memory, select_large_block_memory;
    // Number of 1-bits in [0, SIZE) seen by the last build().
    int select_one_count;

    // Number of 1-bits among the lowest `len` bits of `pattern`.
    int rank_table(int pattern, int len) const {
        return lookup_table_rank[pattern * (SMALL_BLOCK_SIZE + 1) + len];
    }

    // Bit offset of the k-th (0-indexed) 1-bit inside `pattern`.
    int select_table(int pattern, int k) const {
        return lookup_table_select[pattern * SMALL_BLOCK_SIZE + k];
    }

    // Stores one completed select block covering [left, right) whose 1-bit
    // positions are in `positions`, then clears `positions` for the next block.
    void finish_select_block(block& bl, int left, int right, std::vector<int>& positions) {
        if (static_cast<unsigned int>(right - left) <= BOUNDARY_LENGTH) {
            build_select_logn_ary(&bl, left, right);
        } else {
            bl = block();
            bl.bit_begin = left;
            bl.bit_end = right;
            bl.one_pos = positions;
            select_large_block_memory += 4 * static_cast<int>(positions.size());
        }
        positions.clear();
    }

public:

    // (Re)allocates every rank table and zeroes the bits; reports sizes on std::cerr.
    void alloc_memory() {
        bit_vector.assign(VECTOR_SIZE, 0u);
        std::cerr << "bit vector: " << VECTOR_SIZE*4 << "bytes\n";
        block_rank.assign(BLOCK_COUNT+1, 0);
        std::cerr << "block_rank: " << (BLOCK_COUNT+1)*4 << "bytes\n";
        small_block_rank.assign(ALL_SMALL_BLOCK_COUNT+BLOCK_COUNT, 0);
        std::cerr << "small_block_rank: " << (ALL_SMALL_BLOCK_COUNT+BLOCK_COUNT)*2 << "bytes\n";
        lookup_table_rank.assign((1u << SMALL_BLOCK_SIZE)*(SMALL_BLOCK_SIZE+1), 0);
        std::cerr << "lookup_table_rank: " << (1u << SMALL_BLOCK_SIZE)*(SMALL_BLOCK_SIZE+1) << "bytes\n";
        lookup_table_select.assign((1u << SMALL_BLOCK_SIZE)*SMALL_BLOCK_SIZE, 0);
        std::cerr << "lookup_table_select: " << (1u << SMALL_BLOCK_SIZE)*SMALL_BLOCK_SIZE << "bytes\n";
    }

    // Clears every bit.
    void init_() {
        std::fill(bit_vector.begin(), bit_vector.end(), 0u);
    }

    BitVector() : select_small_block_memory(0), select_large_block_memory(0), select_one_count(0){
        alloc_memory();
        init_();
    }

    // Sets bit `index` (expected 0 <= index < SIZE).
    inline void set(int index) {
        bit_vector[index/UNIT_SIZE] |= (1u << index % UNIT_SIZE);
    }

    // Returns 1 if bit `index` is set, else 0.
    inline int isSet(int index) {
        return (bit_vector[index/UNIT_SIZE] >> index % UNIT_SIZE) & 1u;
    }

    // Bits [start, end) packed LSB-first into an int (bit `start` becomes bit 0).
    // Returns 0 for an empty range; the range must be at most UNIT_SIZE bits long.
    inline int bit_mask(int start, int end) {
        if (end <= start) return 0;
        const unsigned int width = static_cast<unsigned int>(end - start);
        const unsigned int word = static_cast<unsigned int>(start) / UNIT_SIZE;
        const unsigned int shift = static_cast<unsigned int>(start) % UNIT_SIZE;
        unsigned long long window = bit_vector[word] >> shift;
        if (shift + width > UNIT_SIZE) {
            // The range continues into the next word.
            window |= static_cast<unsigned long long>(bit_vector[word + 1]) << (UNIT_SIZE - shift);
        }
        // 64-bit arithmetic avoids the undefined 32-bit shift when width == 32.
        return static_cast<int>(window & ((1ull << width) - 1ull));
    }

    // Fills block_rank and small_block_rank from the current bits.
    void build_rank_block() {
        int bit_index = 0, entry = 0;
        block_rank[0] = 0;
        for (unsigned int b = 0; b < BLOCK_COUNT; b++) {
            small_block_rank[entry++] = 0;
            for (unsigned int j = 0; j < SMALL_BLOCK_COUNT; j++, entry++) {
                int ones = 0;
                for (unsigned int k = 0; k < SMALL_BLOCK_SIZE; k++) ones += isSet(bit_index++);
                small_block_rank[entry] = static_cast<short int>(small_block_rank[entry - 1] + ones);
            }
            block_rank[b + 1] = block_rank[b] + small_block_rank[entry - 1];
        }
    }

    // lookup_table_rank[p][len] = 1-bits among the lowest len bits of p.
    void build_rank_lookup_table() {
        const unsigned int row = SMALL_BLOCK_SIZE + 1;
        for (unsigned int p = 0; p < (1u << SMALL_BLOCK_SIZE); p++) {
            lookup_table_rank[p * row] = 0;
            for (unsigned int j = 0; j < SMALL_BLOCK_SIZE; j++) {
                lookup_table_rank[p * row + j + 1] =
                    static_cast<unsigned char>(lookup_table_rank[p * row + j] + ((p >> j) & 1u));
            }
        }
    }

    // Builds the tree levels above nodes [left, right): parents are written
    // starting at index `right`, each covering up to DESCENDANT_COUNT
    // consecutive children, until a single root remains.
    void build_logn_ary_rec(short int* st_count, short int* des_index, int left, int right) {
        if (right - left <= 1) return;
        const int fan_out = static_cast<int>(DESCENDANT_COUNT);
        int parent = right;
        for (int child = left; child < right; child += fan_out, parent++) {
            const int last_child = (child + fan_out < right ? child + fan_out : right) - 1;
            des_index[parent] = static_cast<short int>(child);
            // Counts are cumulative, so a parent's count equals that of its last child.
            st_count[parent] = st_count[last_child];
        }
        build_logn_ary_rec(st_count, des_index, right, parent);
    }

    // Builds the search tree of a dense select block covering bits [left, right).
    void build_select_logn_ary(block* bl, int left, int right) {
        *bl = block();
        bl->bit_begin = left;
        bl->bit_end = right > left ? right : left;
        if (right <= left) return;
        const int small = static_cast<int>(SMALL_BLOCK_SIZE);
        const int fan_out = static_cast<int>(DESCENDANT_COUNT);
        const int leaf_count = (right - left + small - 1) / small;
        int node_len = leaf_count;
        for (int level = leaf_count; level > 1; ) {
            level = (level + fan_out - 1) / fan_out;
            node_len += level;
        }
        bl->ary_len = node_len;
        bl->subtree_count_sum.assign(node_len, 0);
        bl->descendant_start_index.assign(node_len, -1);
        select_small_block_memory += 4 * node_len;
        short int* st_count = bl->subtree_count_sum.data();
        // Leaves: cumulative 1-bit counts; the last leaf may be shorter than SMALL_BLOCK_SIZE.
        int running = 0, bit_index = left;
        for (int leaf = 0; leaf < leaf_count; leaf++) {
            const int leaf_end = (bit_index + small < right) ? bit_index + small : right;
            for (; bit_index < leaf_end; bit_index++) running += isSet(bit_index);
            st_count[leaf] = static_cast<short int>(running);
        }
        build_logn_ary_rec(st_count, bl->descendant_start_index.data(), 0, leaf_count);
    }

    // Splits the 1-bits of [0, SIZE) into select blocks of BLOCK_SIZE 1-bits
    // (the last may be partial) and builds each one.
    void build_select_data_structure() {
        select_small_block_memory = select_large_block_memory = 0;
        select_one_count = 0;
        const int block_select_count = block_rank[BLOCK_COUNT] / static_cast<int>(BLOCK_SIZE) + 1;
        block_select.assign(block_select_count, block());
        std::cerr << "block_pointer: " << block_select_count * sizeof(block) << "bytes\n";
        std::vector<int> positions;
        positions.reserve(BLOCK_SIZE);
        int block_index = 0, block_begin = 0;
        for (int i = 0; i < static_cast<int>(BIT_SIZE); i++) {
            if (!isSet(i)) continue;
            positions.push_back(i);
            select_one_count++;
            if (positions.size() == BLOCK_SIZE) {
                finish_select_block(block_select[block_index++], block_begin, i + 1, positions);
                block_begin = i + 1;
            }
        }
        if (!positions.empty()) {
            finish_select_block(block_select[block_index], block_begin, positions.back() + 1, positions);
        }
    }

    // lookup_table_select[p][k] = offset of the k-th (0-indexed) 1-bit in p.
    void build_select_lookup_table() {
        for (unsigned int p = 0; p < (1u << SMALL_BLOCK_SIZE); p++) {
            unsigned int found = 0;
            for (unsigned int j = 0; j < SMALL_BLOCK_SIZE; j++) {
                if ((p >> j) & 1u) {
                    lookup_table_select[p * SMALL_BLOCK_SIZE + found++] = static_cast<unsigned char>(j);
                }
            }
        }
    }

    // Builds all rank/select structures; call after the last set().
    void build() {
        build_rank_block();
        build_rank_lookup_table();
        build_select_data_structure();
        build_select_lookup_table();
        std::cerr << "SELECT_SMALL_BLOCK: " << select_small_block_memory << "bytes\n";
        std::cerr << "SELECT_LARGE_BLOCK: " << select_large_block_memory << "bytes\n";
    }

    // Number of 1-bits in [0, x); requires 0 <= x <= SIZE.
    inline int rank(int x) {
        const unsigned int pos = static_cast<unsigned int>(x);
        const unsigned int block_id = pos / BLOCK_SIZE;
        const unsigned int in_block = pos % BLOCK_SIZE;
        const unsigned int in_small = in_block % SMALL_BLOCK_SIZE;
        return block_rank[block_id]
             + small_block_rank[block_id * (SMALL_BLOCK_COUNT + 1) + in_block / SMALL_BLOCK_SIZE]
             + rank_table(bit_mask(x - static_cast<int>(in_small), x), static_cast<int>(in_small));
    }

    // Absolute position of the target_pos-th (1-indexed) 1-bit of dense
    // block `bl`, searching from node `cur_index` (the root is ary_len - 1).
    // `start` must be bl->bit_begin.
    int select_search(block* bl, const int start, const int target_pos, int cur_index) {
        const short int* st_count = bl->subtree_count_sum.data();
        const short int* des_index = bl->descendant_start_index.data();
        while (des_index[cur_index] >= 0) {
            const int first_child = des_index[cur_index];
            // Children end where the next node's children begin; for the last
            // node of a level, that next node's first child starts this level.
            // The root's children end at the root itself.
            const int child_end = (cur_index + 1 < bl->ary_len) ? des_index[cur_index + 1] : cur_index;
            cur_index = static_cast<int>(
                std::lower_bound(st_count + first_child, st_count + child_end, target_pos) - st_count);
        }
        const int small = static_cast<int>(SMALL_BLOCK_SIZE);
        const int ones_before = (cur_index > 0) ? st_count[cur_index - 1] : 0;
        const int leaf_begin = start + cur_index * small;
        const int leaf_end = (leaf_begin + small < bl->bit_end) ? leaf_begin + small : bl->bit_end;
        return leaf_begin + select_table(bit_mask(leaf_begin, leaf_end), target_pos - ones_before - 1);
    }

    // Position of the x-th 1-bit (0-indexed), or -1 if x is out of range.
    inline int select(int x) {
        if (x < 0 || x >= select_one_count) return -1;
        block& bl = block_select[x / BLOCK_SIZE];
        const int r = static_cast<int>(x % BLOCK_SIZE);
        return bl.isSmall() ? select_search(&bl, bl.bit_begin, r + 1, bl.ary_len - 1)
                            : bl.one_pos[r];
    }
};

verify 用の問題

verify していません(verify 問題を知らない)