$\newcommand{\O}{\mathrm{O}}$ My Algorithm : kopricky アルゴリズムライブラリ

kopricky アルゴリズムライブラリ

RangeMaxUpdateQuery and RangeSumQuery

コードについての説明

区間 $\max$ への更新 $a_i \leftarrow \max \{a_i, x\} (i \in [l, r))$, 区間和 $\underset{l \le i < r}{\mathrm{sum}} a_i$ のクエリを効率的に処理できる(RangeMinUpdateQuery, RangaUpdateQuery, RangaAddQuery を付け加えたりなどは以下の計算量解析が回るので容易にできる).
コドフォのブログ に書いてあったテクニックで自分は最初知ったとき結構びっくりした.
実装もそこまでめんどくさくなく, 通常の遅延セグ木の更新の際の探索において
「$b \le l || r \le a$」 なら探索終了の部分を 「$b \le l || r \le a || min \_ val[k] \ge x$」 なら探索終了に変更し, ($min \_ val[k]$ とは対応する区間内の最小値を表す)
「$a \le l \& \& r \le b$」 なら遅延処理するという部分を 「$a \le l \& \& r \le b \& \& second \_ val[k] > x$」 なら遅延処理するに変更する. ($second \_ val[k]$ とは対応する区間内の $2$ 番目((注)strict に $2$ 番目) に小さい値を表す)
のように変更するだけである. もちろん $min \_ val$, $second \_ val$ は合わせて計算しておく.
前者の条件の変更は探索終了の条件をきつくしたので計算量的にはむしろプラスに働く. 問題は後者の条件の方で $second \_ val[k] \le x$ が成り立つ場合, 遅延セグ木では遅延処理にのっけて終了していたところを余分に探索することになる. 以下ではこの点について説明する.
セグ木が保持するノード($2n$ 個ぐらいある) に対応する区間 $[l, r)$ について $D(l, r)$ を区間 $[l, r)$ に含まれる数の種類数とする. またすべてのノードの $D(l, r)$ の和を $C$ とする. ここで $C = \O (n \log n)$ であることに注意する.
ここで遅延セグ木と比べて余分に探索しなければいけないとき, つまり $second \_ val[k] \le x$ のときは探索終了後にはこの区間内の $min \_ val[k] と second \_ val[k]$ が $x$ に更新されている. 言い換えると余分に一回探索するごとにその区間の範囲内の数の種類数が少なくとも $1$ 減ることになる. もちろん $C$ も少なくとも $1$ 減る.
さらに $C = \O (n \log n )$ であることと合わせると, クエリをすべて処理した時にかかる余分な探索は $\O (n \log n)$ であることがいえ, これより 1 回のクエリの処理にかかる計算量はならし $\O (\log n)$ となる.
このたぐいのならしで計算量が落ちる segtree はいろいろな種類があるらしいが基本的にはエントロピー的なもの(ポテンシャルという)の変化と各更新の計算回数を対応づけることが大事でこれはならし計算量でよく用いる解析である. 今回で言うと $1$ 回余分にたどる操作と種類数の変化(ポテンシャルの変化)が対応している.

時間計算量: 構築 $\O (n)$, 各クエリ ならし $\O(\log n)$

コード

template<typename T> class segtree {
private:
    int n,sz;
    vector<T> node, min_val, second_val, lazy;
    vector<int> count;
    vector<bool> lazyFlag;
    void update(int id){
        node[id] = node[2*id+1] + node[2*id+2];
        if(min_val[2*id+1] < min_val[2*id+2]){
            min_val[id] = min_val[2*id+1];
            second_val[id] = min(second_val[2*id+1], min_val[2*id+2]);
            count[id] = count[2*id+1];
        }else if(min_val[2*id+1] > min_val[2*id+2]){
            min_val[id] = min_val[2*id+2];
            second_val[id] = min(min_val[2*id+1], second_val[2*id+2]);
            count[id] = count[2*id+2];
        }else{
            min_val[id] = min_val[2*id+1];
            second_val[id] = min(second_val[2*id+1], second_val[2*id+2]);
            count[id] = count[2*id+1] + count[2*id+2];
        }
    }
public:
    segtree(const vector<T>& v) : n(1), sz((int)v.size()){
        while(n < sz) n *= 2;
        node.resize(2*n-1, 0);
        lazy.resize(2*n-1, 0);
        lazyFlag.resize(2*n-1, false);
        min_val.resize(2*n-1, numeric_limits<T>::max());
        second_val.resize(2*n-1, numeric_limits<T>::max());
        count.resize(2*n-1, 1);
        for(int i = 0; i < sz; i++){
            node[i+n-1] = min_val[i+n-1] = v[i];
        }
        for(int i=n-2; i>=0; i--){
            update(i);
        }
    }
    void eval(int k, int l, int r) {
        if(lazyFlag[k]){
            if(lazy[k] > min_val[k]){
                node[k] += (lazy[k] - min_val[k]) * count[k];
                min_val[k] = lazy[k];
                if(r - l > 1){
                    lazy[k*2+1] = lazy[k*2+2] = lazy[k];
                    lazyFlag[k*2+1] = lazyFlag[k*2+2] = true;
                }
            }
            lazyFlag[k] = false;
        }
    }
    void range(int a, int b, T x, int k=0, int l=0, int r=-1){
        if(r < 0) r = n;
        eval(k, l, r);
        if(b <= l || r <= a || min_val[k] >= x){
            return;
        }
        if(a <= l && r <= b && second_val[k] > x) {
            lazy[k] = x;
            lazyFlag[k] = true;
            eval(k, l, r);
        }else{
            range(a, b, x, 2*k+1, l, (l+r)/2);
            range(a, b, x, 2*k+2, (l+r)/2, r);
            update(k);
        }
    }
    T query(int a, int b, int k=0, int l=0, int r=-1) {
        if(r < 0) r = n;
        eval(k, l, r);
        if(b <= l || r <= a){
            return 0;
        }
        if(a <= l && r <= b){
            return node[k];
        }
        T vl = query(a, b, 2*k+1, l, (l+r)/2);
        T vr = query(a, b, 2*k+2, (l+r)/2, r);
        return vl + vr;
    }
    void print()
    {
        for(int i = 0; i < sz; i++){
            cout << query(i,i+1) << " ";
        }
        cout << endl;
    }
};

コード(RangeAddQuery などを追加したもの)

template<typename T> class segtree {
private:
    int n, sz, h;
    const T inf;
    vector<T> node, min_val, second_val, max_val, lazy, lazy_add;
    vector<int> count;
    void update(int id){
        node[id] = node[2*id] + node[2*id+1];
        max_val[id] = max(max_val[2*id], max_val[2*id+1]);
        if(min_val[2*id] < min_val[2*id+1]){
            min_val[id] = min_val[2*id];
            second_val[id] = min(second_val[2*id], min_val[2*id+1]);
            count[id] = count[2*id];
        }else if(min_val[2*id] > min_val[2*id+1]){
            min_val[id] = min_val[2*id+1];
            second_val[id] = min(min_val[2*id], second_val[2*id+1]);
            count[id] = count[2*id+1];
        }else{
            min_val[id] = min_val[2*id];
            second_val[id] = min(second_val[2*id], second_val[2*id+1]);
            count[id] = count[2*id] + count[2*id+1];
        }
    }
public:
    segtree(const vector<T>& v)
        : n(1), sz((int)v.size()), h(0), inf(numeric_limits<T>::max() / 2){
        while(n < sz) n *= 2, ++h;
        node.resize(2*n, 0), min_val.resize(2*n, inf), second_val.resize(2*n, inf);
        max_val.resize(2*n, -inf), lazy.resize(2*n, inf), lazy_add.resize(2*n, 0);
        count.resize(2*n, 1);
        for(int i = 0; i < sz; ++i){
            node[i+n] = min_val[i+n] = max_val[i+n] = v[i];
        }
        for(int i = n-1; i >= 1; --i) update(i);
    }
    void eval(int k, int length){
        if(lazy_add[k] != 0){
            node[k] += lazy_add[k] * length;
            min_val[k] += lazy_add[k], second_val[k] += lazy_add[k], max_val[k] += lazy_add[k];
            if(k < n){
                lazy[2*k] += lazy_add[k], lazy[2*k+1] += lazy_add[k];
                lazy_add[2*k] += lazy_add[k], lazy_add[2*k+1] += lazy_add[k];
            }
            lazy_add[k] = 0;
        }
        if(lazy[k] < (inf >> 1)){
            if(lazy[k] > min_val[k]){
                node[k] += (lazy[k] - min_val[k]) * count[k];
                min_val[k] = lazy[k], max_val[k] = max(max_val[k], lazy[k]);
                if(k < n){
                    lazy[2*k] = lazy[2*k+1] = lazy[k];
                }
            }
            lazy[k] = inf;
        }
    }
    void range_maxupdate(int a, int b, T x, int k=1, int l=0, int r=-1){
        if(r < 0) r = n;
        eval(k, r-l);
        if(b <= l || r <= a || min_val[k] >= x) return;
        if(a <= l && r <= b && second_val[k] > x){
            lazy[k] = x, eval(k, r-l);
        }else{
            range_maxupdate(a, b, x, 2*k, l, (l+r)>>1);
            range_maxupdate(a, b, x, 2*k+1, (l+r)>>1, r);
            update(k);
        }
    }
    void range_add(int a, int b, T x, int k=1, int l=0, int r=-1){
        if(r < 0) r = n;
        eval(k, r-l);
        if(b <= l || r <= a) return;
        if(a <= l && r <= b){
            lazy_add[k] += x, eval(k, r-l);
        }else{
            range_add(a, b, x, 2*k, l, (l+r)>>1);
            range_add(a, b, x, 2*k+1, (l+r)>>1, r);
            update(k);
        }
    }
    T query_min(int a, int b){
        int length = n;
        a += n, b += n - 1;
        for(int i = h; i > 0; i--, length >>= 1){
            eval(a >> i, length), eval(b >> i, length);
        }
        b++;
        T res1 = inf, res2 = inf;
        while(a < b){
            if(a & 1) eval(a, length), res1 = min(res1, min_val[a++]);
            if(b & 1) eval(--b, length), res2 = min(res2, min_val[b]);
            a >>= 1, b >>= 1, length <<= 1;
        }
        return min(res1, res2);
    }
    T query_max(int a, int b){
        int length = n;
        a += n, b += n - 1;
        for(int i = h; i > 0; i--, length >>= 1){
            eval(a >> i, length), eval(b >> i, length);
        }
        b++;
        T res1 = -inf, res2 = -inf;
        while(a < b) {
            if(a & 1) eval(a, length), res1 = max(res1, max_val[a++]);
            if(b & 1) eval(--b, length), res2 = max(res2, max_val[b]);
            a >>= 1, b >>= 1, length <<= 1;
        }
        return max(res1, res2);
    }
    T query_sum(int a, int b){
        int length = n;
        a += n, b += n - 1;
        for(int i = h; i > 0; i--, length >>= 1){
            eval(a >> i, length), eval(b >> i, length);
        }
        b++;
        T res1 = 0, res2 = 0;
        while(a < b) {
            if(a & 1) eval(a, length), res1 += node[a++];
            if(b & 1) eval(--b, length), res2 += node[b];
            a >>= 1, b >>= 1, length <<= 1;
        }
        return res1 + res2;
    }
    void print()
    {
        for(int i = 0; i < sz; i++){
            cout << query_sum(i,i+1) << " ";
        }
        cout << endl;
    }
};

verify 用の問題

Atcoder : Shortest Path on a Line 提出コード(別にこれを使わなくても解ける)