$\newcommand{\O}{\mathrm{O}}$ My Algorithm : kopricky アルゴリズムライブラリ

kopricky アルゴリズムライブラリ

Accuracy Multi Precision

コードについての説明

精度 (有効数字 acc 桁, テンプレート引数で指定) を固定した 10 進浮動小数点数型のコード. (任意精度)多倍長整数を応用して書いてみたが, 正直あまり速度は速くないと思うし, もっと上手く書けると思うのであまり納得がいっていない.
verify ももう少し行う必要がある気がしていて, おすすめはしない. AtCoder なら boost (boost::multiprecision) を使ったほうが良さそう.

コード

// AMP<acc>: decimal floating-point number that keeps `acc` significant digits.
//
// Representation (established by trim_digit below):
//   * the object itself is a deque of decimal digits, little-endian
//     ((*this)[0] is the least significant digit);
//   * a non-zero value is normalized to exactly `acc` digits whose most
//     significant digit is non-zero, and its value is
//         (sign ? -1 : 1) * sum_i (*this)[i] * 10^(ex + i);
//   * `zero == true` means the value is 0 and the digits are not meaningful.
// Results are rounded on their magnitude, half-up on the first discarded digit.
// NOTE(review): add/sub only use one guard digit of the operand with the
// smaller exponent, so their last digit may differ from the correctly rounded
// result; mul and operator/ round from the exact leading digits.
template<int acc> class AMP : public deque<int> {
private:
    static_assert(acc >= 1, "AMP needs at least one significant digit");

    // NTT parameters: MOD_ = 441 * 2^21 + 1, so transform lengths up to 2^21
    // divide MOD_ - 1; `root` is the generator used to build roots of unity.
    static constexpr int root = 5;
    static constexpr int MOD_ = 924844033;

    // Normalizes `num` in place: strips leading zeros, turns an all-zero (or
    // empty) mantissa into the canonical zero, otherwise brings the mantissa
    // to exactly `acc` digits (adjusting `ex`) and rounds.
    static void trim_digit(AMP& num){
        while((int)num.size() > 1 && num.back() == 0) num.pop_back();
        if(num.empty() || ((int)num.size() == 1 && num[0] == 0)){
            num.zero = true, num.sign = false;
            return;
        }
        num.zero = false;
        while((int)num.size() < acc) num.push_front(0), num.ex--;
        // Keep one extra low digit for rounding; digits below it are dropped.
        while((int)num.size() > acc + 1) num.pop_front(), num.ex++;
        rounding(num);
    }
    // Rounds an (acc + 1)-digit mantissa to acc digits, half-up on num[0].
    // Does nothing if num does not have exactly acc + 1 digits.
    static void rounding(AMP& num){
        if((int)num.size() != acc + 1) return;
        if(num[0] >= 5){
            int pos = 1;
            do{ num[pos]++;
                if(num[pos] != 10) break;
                num[pos] = 0;
            }while(++pos <= acc);
            // Carry ran off the top (e.g. 999|5 -> 1000): add the new leading 1
            // and drop one more low digit so acc + 1 digits remain here.
            if(pos == acc+1) num.push_back(1), num.pop_front(), num.ex++;
        }
        // Drop the rounding digit.
        num.pop_front(), num.ex++;
    }
    // Compares only the acc-digit mantissas of a and b, ignoring exponents.
    static bool mantissa_less(const AMP& a, const AMP& b){
        for(int index = acc - 1; index >= 0; index--){
            if(a[index] != b[index]) return a[index] < b[index];
        }
        return false;
    }
    // |a| < |b| for normalized non-zero operands: exponent first, then mantissa.
    static bool abs_less(const AMP& a, const AMP& b){
        if(a.ex != b.ex) return a.ex < b.ex;
        return mantissa_less(a, b);
    }
    // Copies the mantissa and exponent of b into a (a becomes non-zero;
    // a.sign is left untouched).
    static void num_sbst(AMP& a, const AMP& b){
        a.resize(acc), a.zero = false, a.ex = b.ex;
        for(int i = 0; i < acc; i++) a[i] = b[i];
    }
    // res = |a| + |b|; the sign of res is set by the caller.
    static void add(const AMP& a, const AMP& b, AMP& res){
        long long diff = a.ex - b.ex;
        int carry = 0;
        // The smaller operand lies entirely below the rounding position.
        if(diff > acc || diff < -acc) return (diff > 0) ? num_sbst(res, a) : num_sbst(res, b);
        if(diff >= 0){
            num_sbst(res, a);
            // With differing exponents, open one guard digit below a's mantissa.
            if(diff) res.push_front(0), res.ex--;
            for(int i = !diff; i <= acc; i++){
                int val = res[i-!diff] + (i <= acc-diff ? b[i+diff-1] : 0) + carry;
                carry = val/10;
                res[i-!diff] = val%10;
            }
            if(carry) res.push_back(1);
        }else{
            num_sbst(res, b);
            res.push_front(0), res.ex--;
            for(int i = 0; i <= acc; i++){
                int val = res[i] + (i <= acc+diff ? a[i-diff-1] : 0) + carry;
                carry = val/10;
                res[i] = val%10;
            }
            if(carry) res.push_back(1);
        }
        trim_digit(res);
    }
    // res = |a| - |b|; callers guarantee |a| >= |b| (so a.ex >= b.ex).
    // The sign of res is set by the caller.
    static void sub(const AMP& a, const AMP& b, AMP& res){
        long long diff = a.ex - b.ex;
        int carry = 0;
        num_sbst(res, a);
        if(diff > acc) return;
        if(diff){
            // One guard digit below a's mantissa; lower digits of b are dropped.
            res.push_front(0), res.ex--;
            for(int i = 0; i <= acc; i++){
                int val = res[i] - carry - (i <= acc-diff ? b[i+diff-1] : 0);
                if(val < 0){
                    carry = 1, val += 10;
                }else{
                    carry = 0;
                }
                res[i] = val;
            }
        }else{
            for(int i = 0; i < acc; i++){
                int val = res[i] - carry - b[i];
                if(val < 0){
                    carry = 1, val += 10;
                }else{
                    carry = 0;
                }
                res[i] = val;
            }
        }
        trim_digit(res);
    }
    // Modular helpers for the NTT (operands in [0, MOD_)).
    static int add_(const int x, const int y) { return (x + y < MOD_) ? x + y : x + y - MOD_; }
    static int mul_(const int x, const int y) { return (long long)x * y % MOD_; }
    static int power(int x, int n){
        int res = 1;
        while(n > 0){
            if(n & 1) res = mul_(res, x);
            x = mul_(x, x);
            n >>= 1;
        }
        return res;
    }
    // Modular inverse via Fermat's little theorem (MOD_ is prime).
    static int inverse(const int x) { return power(x, MOD_ - 2); }
    // In-place iterative number-theoretic transform; a.size() must be a power
    // of two. rev = true computes the inverse transform (including 1/size).
    static void ntt(vector<int>& a, const bool rev = false){
        int i,j,k,s,t,v,w,wn;
        const int size = (int)a.size();
        const int height = (int)log2(2 * size - 1);
        // Bit-reversal permutation.
        for(i = 0; i < size; i++){
            j = 0;
            for(k = 0; k < height; k++) j |= (i >> k & 1) << (height - 1 - k);
            if(i < j) std::swap(a[i], a[j]);
        }
        for(i = 1; i < size; i *= 2) {
            w = power(root, (MOD_ - 1) / (i * 2));
            if(rev) w = inverse(w);
            for(j = 0; j < size; j += i * 2){
                wn = 1;
                for(k = 0; k < i; k++){
                    s = a[j + k], t = mul_(a[j + k + i], wn);
                    a[j + k] = add_(s, t);
                    a[j + k + i] = add_(s, MOD_ - t);
                    wn = mul_(wn, w);
                }
            }
        }
        if(rev){
            v = inverse(size);
            for (i = 0; i < size; i++) a[i] = mul_(a[i], v);
        }
    }
    // res mantissa = a mantissa * b mantissa via NTT convolution, then rounded.
    // res.ex and res.sign are set by the caller.
    // NOTE(review): exact only while 81 * acc < MOD_ and the padded length
    // (>= 2 * acc - 1) does not exceed 2^21.
    static void mul(const AMP& a, const AMP& b, AMP& res){
        const int size = (int)a.size() + (int)b.size() - 1;
        int t = 1;
        while (t < size) t <<= 1;
        vector<int> A(t, 0), B(t, 0);
        for(int i = 0; i < (int)a.size(); i++) A[i] = a[i];
        for(int i = 0; i < (int)b.size(); i++) B[i] = b[i];
        ntt(A), ntt(B);
        for(int i = 0; i < t; i++) A[i] = mul_(A[i], B[i]);
        ntt(A, true);
        res.resize(size);
        // Propagate carries so every entry is a single decimal digit.
        int carry = 0;
        for(int i = 0; i < size; i++){
            int val = A[i] + carry;
            carry = val / 10;
            res[i] = val % 10;
        }
        if(carry) res.push_back(carry);
        trim_digit(res);
    }

    // Non-negative arbitrary-precision integer used only for division.
    // Little-endian decimal digits; zero is the single digit {0}.
    class MPI : public deque<int> {
    public:
        MPI(){ push_back(0); }
        // Strips leading (most significant) zeros, keeping at least one digit.
        inline static void trim_digit(MPI& num){
            while(num.back() == 0 && (int)num.size() >= 2) num.pop_back();
        }
        bool isZero() const {
            return (int)size() == 1 && (*this)[0] == 0;
        }
        // res = a + b.
        static void add(const MPI& a, const MPI& b, MPI& res){
            res = a;
            int mx = (int)max(a.size(), b.size());
            res.resize(mx, 0);
            int carry = 0;
            for(int i = 0; i < mx; i++){
                int val = res[i] + ((i < (int)b.size()) ? b[i] : 0) + carry;
                carry = val/10;
                res[i] = val%10;
            }
            if(carry) res.push_back(1);
        }
        // res = a - b; assumes a >= b (a final borrow is not handled).
        static void sub(const MPI& a, const MPI& b, MPI& res){
            res = a;
            int carry = 0;
            for(int i = 0; i < (int)res.size(); i++){
                int val = res[i] - carry - ((i < (int)b.size()) ? b[i] : 0);
                if(val < 0){
                    carry = 1, val += 10;
                }else{
                    carry = 0;
                }
                res[i] = val;
            }
            trim_digit(res);
        }
        // Numeric comparison for trimmed operands (more digits means larger).
        bool operator<(const MPI& another) const {
            if(size() != another.size()) return size() < another.size();
            for(int index = (int)size() - 1; index >= 0; index--){
                if((*this)[index] != another[index]) return (*this)[index] < another[index];
            }
            return false;
        }
        // mod >= num for trimmed operands.
        static bool div_geq(const MPI& mod, const MPI& num){
            if((int)mod.size() != (int)num.size()) return (int)mod.size() > (int)num.size();
            int n = (int)mod.size();
            for(int i = 0; i < n; i++){
                if(mod[n-1-i] != num[n-1-i]){
                    return mod[n-1-i] > num[n-1-i];
                }
            }
            return true;
        }
        // Schoolbook long division: pushes the digits of floor(a / b) onto the
        // front of res, so res should be a fresh MPI. b must be non-zero and trimmed.
        static void div_(const MPI& a, const MPI& b, MPI& res){
            // mult[k] = (k + 1) * b, used to pick each quotient digit.
            vector<MPI> mult(9);
            MPI mod;
            mult[0] = b;
            for(int i = 0; i < 8; i++) add(mult[i], b, mult[i+1]);
            for(int i = (int)a.size() - 1; i >= 0; i--){
                // Bring down the next digit of a.
                if(mod.isZero()){
                    mod.back() = a[i];
                }else{
                    mod.push_front(a[i]);
                }
                if(div_geq(mod, b)){
                    // Largest l with mult[l] <= mod; the quotient digit is l + 1.
                    int l = 0, r = 9;
                    while(r-l>1){
                        int mid = (l+r)/2;
                        if(mod < mult[mid]){
                            r = mid;
                        }else{
                            l = mid;
                        }
                    }
                    MPI mod_ = mod;
                    sub(mod_, mult[l], mod);
                    res.push_front(l+1);
                }else{
                    res.push_front(0);
                }
            }
            trim_digit(res);
        }
    };

    // Copies the mantissa digits of b into a (0 if b is zero).
    static void mpi_AMP(MPI& a, const AMP& b){
        if(b.zero){ a = MPI(); return; }
        int n = (int)b.size();
        a.resize(n);
        for(int i = 0; i < n; i++) a[i] = b[i];
    }

    // Copies the digits of b into a's mantissa; a.ex/a.sign are kept unless b is 0.
    static void AMP_mpi(AMP& a, const MPI& b){
        if(b.isZero()){ a = AMP(); return; }
        int n = (int)b.size();
        a.resize(n);
        for(int i = 0; i < n; i++) a[i] = b[i];
    }

public:

    // Scientific notation "d.ddd...e+N" with acc digits; zero prints as "0.00...e+0".
    friend ostream& operator<<(ostream& os, const AMP& num) {
        if(num.zero){ os << "0."; for(int i = 0; i < acc-1; i++) os << '0';
                    os << "e+0"; return os; }
        if(num.sign) os << '-';
        os << num.back() << '.';
        for(int i = 0; i < acc-1; i++) os << num[acc-2-i];
        os << 'e';
        if(num.ex+acc-1 >= 0) os << '+';
        os << num.ex+acc-1;
        return os;
    }

    // Reads one whitespace-delimited token and parses it with AMP(const string&).
    friend istream& operator>>(istream& is, AMP& num) {
        string s;
        is >> s;
        num = AMP(s);
        return is;
    }

    // Prints the value in fixed-point notation with `decimal` digits after the
    // point; digits beyond the stored mantissa print as 0 and lower ones are
    // cut off (no rounding). Writes to `os` (standard output by default).
    void print_decimal(int decimal, ostream& os = cout) const {
        if(zero){ os << "0."; for(int i = 0; i < decimal; i++) os << '0'; return; }
        if(sign) os << '-';
        for(long long i = max(ex+acc-1, 0LL); i >= -decimal; --i){
            os << ((i-ex >= 0 && i-ex < acc) ? ((*this)[i-ex]) : 0);
            if(i == 0) os << '.';
        }
    }

    // Approximate conversion to double.
    double to_double() const {
        if(zero){ return 0.0; }
        double res = 0.0, d = 1.0;
        for(int i = 0; i < acc; i++){
            res += (*this)[i] * d, d *= 10.0;
        }
        return sign ? -res * pow(10.0, ex) : res * pow(10.0, ex);
    }

    // Integer part of the value, truncated toward zero.
    // NOTE(review): no overflow check; the value must fit in long long.
    long long to_ll() const noexcept {
        if(zero) return 0;
        long long res = 0;
        // Digits at positions with exponent >= 0 form the integer part.
        for(int i = acc - 1; i >= 0 && i + ex >= 0; i--) res = res * 10 + (*this)[i];
        for(long long e = ex; e > 0; e--) res *= 10;
        return sign ? -res : res;
    }

    AMP& operator=(long long num) {
        *this = AMP(num);
        return *this;
    }

    // Comparison operators; zero compares equal to zero regardless of sign.
    bool operator<(const AMP& num) const {
        if(zero) return !num.zero && !num.sign;
        if(num.zero) return sign;
        if(sign ^ num.sign) return sign;
        if(ex != num.ex) return (ex < num.ex) ^ sign;
        for(int index = acc - 1; index >= 0; index--){
            if((*this)[index] != num[index]) return ((*this)[index] < num[index]) ^ sign;
        }
        return false;
    }

    bool operator<(const long long num) const {
        return *this < AMP(num);
    }

    friend bool operator<(const long long num, const AMP& another){
        return AMP(num) < another;
    }

    bool operator>(const AMP& num) const {
        return num < *this;
    }

    bool operator>(const long long num) const {
        return *this > AMP(num);
    }

    friend bool operator>(const long long num, const AMP& another){
        return AMP(num) > another;
    }

    bool operator<=(const AMP& num) const {
        return !(*this > num);
    }

    bool operator<=(const long long num) const {
        return *this <= AMP(num);
    }

    friend bool operator<=(const long long num, const AMP& another){
        return AMP(num) <= another;
    }

    bool operator>=(const AMP& num) const {
        return !(*this < num);
    }

    bool operator>=(const long long num) const {
        return *this >= AMP(num);
    }

    friend bool operator>=(const long long num, const AMP& another){
        return AMP(num) >= another;
    }

    bool operator==(const AMP& num) const {
        if(zero || num.zero) return zero && num.zero;
        if(sign ^ num.sign) return false;
        if(ex != num.ex) return false;
        for(int index = acc - 1; index >= 0; index--){
            if((*this)[index] != num[index]) return false;
        }
        return true;
    }

    bool operator==(const long long num) const {
        return *this == AMP(num);
    }

    friend bool operator==(const long long num, const AMP& another){
        return AMP(num) == another;
    }

    bool operator!=(const AMP& num) const {
        return !(*this == num);
    }

    bool operator!=(const long long num) const {
        return *this != AMP(num);
    }

    friend bool operator!=(const long long num, const AMP& another){
        return AMP(num) != another;
    }

    explicit operator bool() const noexcept { return !zero; }
    bool operator!() const noexcept { return !static_cast<bool>(*this); }

    // Truncating conversions (see to_ll).
    explicit operator int() const noexcept { return (int)this->to_ll(); }
    explicit operator long long() const noexcept { return this->to_ll(); }

    AMP operator+() const {
        return *this;
    }

    AMP operator-() const {
        AMP res = *this;
        res.sign = !sign;
        return res;
    }

    friend AMP abs(const AMP& num){
        AMP res = num;
        res.sign = false;
        return res;
    }

    // Signed addition dispatches to magnitude add/sub.
    AMP operator+(const AMP& num) const {
        if(zero){ AMP res = num; return res; }
        if(num.zero){ AMP res = *this; return res; }
        AMP res; res.sign = sign;
        if(sign != num.sign){
            if(abs_less(*this, num)){
                res.sign = num.sign;
                sub(num, *this, res);
                return res;
            }else{
                sub(*this, num, res);
                return res;
            }
        }
        add(*this, num, res);
        return res;
    }

    AMP operator+(long long num) const {
        return *this + AMP(num);
    }

    friend AMP operator+(long long a, const AMP& b){
        return b + a;
    }

    AMP& operator+=(const AMP& num){
        *this = *this + num;
        return *this;
    }

    AMP& operator+=(long long num){
        *this = *this + num;
        return *this;
    }

    AMP& operator++(){
        return *this += 1;
    }

    AMP operator++(int){
        AMP res = *this;
        *this += 1;
        return res;
    }

    // Signed subtraction dispatches to magnitude add/sub.
    AMP operator-(const AMP& num) const {
        if(zero){ AMP res = num; res.sign = !res.sign; return res; }
        if(num.zero){ AMP res = *this; return res; }
        AMP res; res.sign = sign;
        if(sign != num.sign){
            add(*this, num, res);
            return res;
        }
        if(abs_less(*this, num)){
            res.sign = !sign;
            sub(num, *this, res);
        }else{
            sub(*this, num, res);
        }
        return res;
    }

    AMP operator-(long long num) const {
        return *this - AMP(num);
    }

    friend AMP operator-(long long a, const AMP& b){
        return AMP(a) - b;
    }

    AMP& operator-=(const AMP& num){
        *this = *this - num;
        return *this;
    }

    AMP& operator-=(long long num){
        *this = *this - num;
        return *this;
    }

    AMP& operator--(){
        return *this -= 1;
    }

    AMP operator--(int){
        AMP res = *this;
        *this -= 1;
        return res;
    }

    AMP operator*(const AMP& num) const {
        if(zero || num.zero) return AMP();
        AMP res; res.zero = false; res.sign = (sign^num.sign);
        res.ex = ex + num.ex;
        mul(*this, num, res);
        return res;
    }

    AMP operator*(long long num) const {
        return *this * AMP(num);
    }

    friend AMP operator*(long long a, const AMP& b){
        return b * a;
    }

    AMP& operator*=(const AMP& num){
        *this = *this * num;
        return *this;
    }

    AMP& operator*=(long long num){
        *this = *this * num;
        return *this;
    }

    // Division via integer long division of the scaled mantissas.
    // Asserts that the divisor is non-zero.
    AMP operator/(const AMP& num) const {
        assert(!num.zero);
        if(zero){ return AMP(); }
        MPI kp, res_, num_;
        mpi_AMP(num_, num);
        AMP res; res.zero = false; res.sign = (sign^num.sign); res.ex = ex - num.ex;
        mpi_AMP(kp, *this);
        // Scale so the integer quotient always has acc + 1 digits (acc kept plus
        // one rounding digit): one extra shift when this mantissa is the smaller.
        if(mantissa_less(*this, num)) kp.push_front(0), res.ex--;
        for(int i = 0; i < acc; i++) kp.push_front(0), res.ex--;
        MPI::div_(kp, num_, res_);
        AMP_mpi(res, res_);
        trim_digit(res);
        return res;
    }

    AMP operator/(long long num) const {
        return *this / AMP(num);
    }

    friend AMP operator/(long long a, const AMP& b){
        return AMP(a) / b;
    }

    AMP& operator/=(const AMP& num){
        *this = *this / num;
        return *this;
    }

    AMP& operator/=(long long num){
        *this = *this / num;
        return *this;
    }

    // Exact halving of the mantissa (one extra digit 5 if odd), then rounded.
    AMP div2() const {
        if(zero) return AMP();
        AMP res = *this;
        int carry = 0;
        for(int i = acc-1; i >= 0; i--){
            int val = (*this)[i]+carry*10;
            carry = val%2;
            res[i] = val/2;
        }
        if(carry) res.push_front(5), res.ex--;
        trim_digit(res);
        return res;
    }

    // Square root by Newton's iteration started above the root; returns 0 for x <= 0.
    friend AMP sqrt(AMP x){
        if(x <= AMP(0)) return AMP();
        AMP s = 1, t = x;
        // Find a starting point s >= sqrt(x) by doubling s and halving t.
        while(s < t){
            s = s + s, t = t.div2();
        }
        do{ t = s, s = (x / s + s).div2();
        }while(s < t);
        trim_digit(t);
        return t;
    }

    // NOTE: returns the integer acc + ex = floor(log10(x)) + 1 (the number of
    // integer digits for x >= 1), not the real-valued logarithm. Requires x > 0.
    friend AMP log10(const AMP& x){
        assert(x > AMP());
        return AMP(acc + x.ex);
    }

    // a^b by binary exponentiation; b must be non-negative. pow(0, 0) == 1.
    friend AMP pow(AMP a, long long b){
        assert(b >= 0);
        if(a.zero) return (b == 0) ? AMP(1) : AMP();
        AMP res(1);
        while(b){
            if(b % 2) res *= a;
            a *= a;
            b = b / 2;
        }
        return res;
    }

    bool sign, zero;
	long long ex;

    // Default value is a well-defined zero.
    AMP() : sign(false), zero(true), ex(0){}

    AMP(long long val) : sign(false), zero(false), ex(0){
        if(val == 0){ zero = true; return; }
        // Work on the unsigned magnitude so that LLONG_MIN does not overflow.
        unsigned long long mag = (unsigned long long)val;
        if(val < 0) sign = true, mag = 0ULL - mag;
        while(mag) push_back((int)(mag % 10)), mag /= 10;
        trim_digit(*this);
    }

    // Parses an optionally signed decimal such as "-12.5" or "+3", optionally
    // followed by an exponent as written by operator<< ("1.23e+4", "5E-3").
    // Characters are not validated; malformed input yields meaningless digits.
    AMP(const string& s) : sign(false), zero(false), ex(0){
        if(s.empty() || s == "0"){ zero = true; return; }
        const int n = (int)s.size();
        int start = 0;
        if(s[0] == '-') sign = true, start = 1;
        else if(s[0] == '+') start = 1;
        // Split off an exponent part "e[+-]digits", if present.
        int end = n;
        long long exp10 = 0;
        for(int i = start; i < n; i++){
            if(s[i] != 'e' && s[i] != 'E') continue;
            end = i;
            int j = i + 1;
            bool neg = false;
            if(j < n && (s[j] == '+' || s[j] == '-')) neg = (s[j] == '-'), j++;
            for(; j < n && s[j] >= '0' && s[j] <= '9'; j++) exp10 = exp10 * 10 + (s[j] - '0');
            if(neg) exp10 = -exp10;
            break;
        }
        // Mantissa digits, least significant first; '.' fixes the exponent.
        for(int i = end - 1; i >= start; i--){
            if(s[i] == '.'){
                ex = i + 1 - end;
            }else{
                push_back(s[i]-'0');
            }
        }
        ex += exp10;
        trim_digit(*this);
    }
};

verify 用の問題

yukicoder : すべて足すだけの簡単なお仕事です。 提出コード
yukicoder : ax^2+bx+c=0 提出コード