提交记录 19321


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004a. 【模板题】高精度乘法2 Accepted 100 1.635 ms 4316 KB C++14 34.33 KB
提交时间 评测时间
2023-04-18 19:33:37 2023-04-18 19:33:42
#include <algorithm>
#include <atomic>
#include <complex>
#include <future>
#include <iostream>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

namespace hint
{
    // Signed integer aliases.
    using INT_32 = int32_t;
    using INT_64 = int64_t;
    using LONG = long;
    // Unsigned fixed-width integer aliases.
    using UINT_8 = uint8_t;
    using UINT_16 = uint16_t;
    using UINT_32 = uint32_t;
    using UINT_64 = uint64_t;
    // Complex number type shared by every FFT routine below.
    using Complex = std::complex<double>;
    // pi (rounded to double precision) and 2*pi.
    constexpr double HINT_PI = 3.1415926535897932384626433832795;
    constexpr double HINT_2PI = 2 * HINT_PI;

    // Smallest power of two that is >= n (returns 1 for any n <= 1).
    // No overflow guard: n above the largest representable power of two never terminates.
    template <typename T>
    constexpr T min_2pow(T n)
    {
        T power = 1;
        for (; power < n; power *= 2)
        {
        }
        return power;
    }
    // Largest power of two that is <= n; returns 0 when n < 1.
    // Works for signed and unsigned integer types.
    template <typename T>
    constexpr T max_2pow(T n)
    {
        if (n < T(1))
        {
            // No positive power of two fits. Bail out before the halving loop:
            // it would spin forever at res == 0 for a negative n.
            return T(0);
        }
        // Start from the highest power of two T can represent: bit (width - 1) for
        // unsigned types, bit (width - 2) for signed ones. Shifting into the sign
        // bit made every signed T start from a negative value and return it.
        const bool is_signed_type = T(-1) < T(0);
        const int top_bit = static_cast<int>(sizeof(T) * 8) - (is_signed_type ? 2 : 1);
        T res = static_cast<T>(T(1) << top_bit);
        while (res > n)
        {
            res /= 2;
        }
        return res;
    }
    // True when x compares below zero.
    template <typename T>
    constexpr bool is_neg(T x)
    {
        const bool negative = x < 0;
        return negative;
    }
    // True when the lowest bit of x is set.
    template <typename T>
    constexpr bool is_odd(T x)
    {
        return (x & 1) != 0;
    }

    // Quotient and remainder of a / b, using the built-in / and % semantics of T.
    template <typename T>
    constexpr std::pair<T, T> div_mod(T a, T b)
    {
        const T quotient = a / b;
        const T remainder = a % b;
        return std::make_pair(quotient, remainder);
    }
    // Copies len elements of T from source to target (raw byte copy, so T is assumed
    // trivially copyable). Does nothing for len == 0 or when both pointers are equal.
    // Throws const char* ("Ary too long\n") when the byte count would reach INT64_MAX.
    template <typename T>
    void ary_copy(T *target, const T *source, size_t len)
    {
        if (len == 0 || target == source)
        {
            return;
        }
        // Bound the *byte* count rather than the element count, so len * sizeof(T)
        // below can never overflow. For 1-byte T this is the same limit as before.
        if (len >= static_cast<size_t>(INT64_MAX) / sizeof(T))
        {
            throw("Ary too long\n");
        }
        // memmove instead of memcpy: overlapping ranges are well-defined.
        std::memmove(target, source, len * sizeof(T));
    }

    // Resizes a malloc/realloc-owned array to hold len elements of T and returns the
    // (possibly moved) pointer; ptr may be nullptr to allocate fresh storage.
    // Throws const char* ("realloc error") when the byte count would reach INT64_MAX or
    // when realloc fails; in both cases the original block is left untouched and still
    // owned by the caller.
    // NOTE(review): len == 0 is passed straight to realloc, whose result is
    // implementation-defined (it may free ptr and return nullptr, which then throws).
    template <typename T>
    inline T *ary_realloc(T *ptr, size_t len)
    {
        // Previously an oversize request skipped the realloc and silently returned the
        // old, too-small pointer; it also allowed len * sizeof(T) to overflow.
        if (len >= static_cast<size_t>(INT64_MAX) / sizeof(T))
        {
            throw("realloc error");
        }
        T *resized = static_cast<T *>(std::realloc(ptr, len * sizeof(T)));
        if (resized == nullptr)
        {
            throw("realloc error");
        }
        return resized;
    }
    // Packs two real sequences into one complex array: source1[i] goes to
    // target[i].real() for i < len1 and source2[i] goes to target[i].imag() for i < len2.
    // Indices below min(len1, len2) get a freshly constructed Complex. Past that point
    // only the component of the longer source is written; the other component keeps
    // whatever target already held.
    // T is any indexable type (array, pointer, container); its elements are assumed
    // to convert to double.
    template <typename T>
    inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
    {
        const size_t shared = std::min(len1, len2);
        for (size_t idx = 0; idx < shared; ++idx)
        {
            target[idx] = Complex(source1[idx], source2[idx]);
        }
        for (size_t idx = shared; idx < len1; ++idx)
        {
            target[idx].real(source1[idx]);
        }
        for (size_t idx = shared; idx < len2; ++idx)
        {
            target[idx].imag(source2[idx]);
        }
    }
    // floor(log2(n)) for n >= 1, computed by repeated halving; 0 for any n <= 1.
    template <typename T>
    constexpr size_t hint_log2(T n)
    {
        size_t exponent = 0;
        for (; n > 1; n /= 2)
        {
            ++exponent;
        }
        return exponent;
    }
    // FFT与类FFT变换的命名空间
    namespace hint_transform
    {
        // Applies the bit-reversal permutation to ary[0 .. len) in place (element j is
        // swapped with the element at the index whose bits are j's bits reversed).
        // Assumes len is a power of two, as every caller here passes a compile-time LEN.
        // T is anything indexable with [] (pointer, array, container).
        template <typename T>
        void binary_inverse_swap(T &ary, size_t len)
        {
            // Nothing to permute for 0 or 1 elements. Without this guard, len == 0 made
            // `len - 1` wrap to SIZE_MAX and the loop run for ~2^64 iterations.
            if (len < 2)
            {
                return;
            }
            size_t i = 0; // bit-reversed counterpart of j, updated incrementally
            for (size_t j = 1; j < len - 1; j++)
            {
                // Add 1 to i in "reversed" binary: flip bits from the top down until a
                // 0 bit is turned into a 1.
                size_t k = len >> 1;
                i ^= k;
                while (k > i)
                {
                    k >>= 1;
                    i ^= k;
                }
                // Swap each pair only once.
                if (j < i)
                {
                    std::swap(ary[i], ary[j]);
                }
            }
        }
        // Lookup table of complex unit roots used by the split-radix FFT below.
        //
        // Layout, as built by expend(): for each level s with 3 <= s <= cur_log_size and
        // len = 1 << s, the entries table[len/4 + p] for 0 < p < len/4 hold
        // e^{i*2*pi*p/len}, i.e. the open first quadrant of that level's roots.
        // The slot table[len/4] itself is NOT a valid root: the copy loop in expend()
        // overwrites it. get_complex_conj() never reads it, because indices that are
        // multiples of a quarter turn are answered from a constant array instead.
        //
        // Not copyable. expend() writes to the table without any synchronization.
        class ComplexTable
        {
        private:
            std::vector<Complex> table;
            INT_32 max_log_size = 2; // largest level the allocated table can hold
            INT_32 cur_log_size = 2; // largest level filled in so far

            static constexpr double PI = HINT_PI;

            ComplexTable(const ComplexTable &) = delete;
            ComplexTable &operator=(const ComplexTable &) = delete;

        public:
            ~ComplexTable() {}
            // Builds a table able to produce the unit roots that divide the circle into
            // up to 1 << max_shift parts (max_shift is raised to at least 2). Storage for
            // 1 << (max_shift - 1) entries is allocated and every level up to max_shift is
            // filled immediately, so later expend() calls with shift <= max_shift are no-ops.
            ComplexTable(UINT_32 max_shift)
            {
                max_shift = std::max<size_t>(max_shift, 2);
                max_log_size = max_shift;
                size_t ary_size = 1ull << (max_shift - 1);
                table.resize(ary_size);
                table[0] = Complex(1);
                expend(max_shift);
            }
            // (sic: "expand") Fills levels cur_log_size + 1 .. shift. Each level reuses the
            // previous level for its even-indexed roots. Throws const char* when shift
            // exceeds the size fixed at construction.
            void expend(INT_32 shift)
            {
                if (shift > max_log_size)
                {
                    throw("FFT length too long for lut\n");
                }
                for (INT_32 i = cur_log_size + 1; i <= shift; i++)
                {
                    size_t len = 1ull << i, vec_size = len / 4;
                    // NOTE(review): this store is immediately overwritten by the pos == 0
                    // iteration of the copy loop below; the slot is never read.
                    table[vec_size] = Complex(1, 0);
                    // Even-indexed roots of this level equal the previous level's roots:
                    // e^{i*2*pi*(2*pos)/len} == e^{i*2*pi*pos/(len/2)}.
                    for (size_t pos = 0; pos < vec_size / 2; pos++)
                    {
                        table[vec_size + pos * 2] = table[vec_size / 2 + pos];
                    }
                    // Odd-indexed roots below the octant are computed directly. The mirrored
                    // slot uses e^{i*(pi/2 - theta)} == (sin theta, cos theta).
                    for (size_t pos = 1; pos < vec_size / 2; pos += 2)
                    {
                        double cos_theta = std::cos(HINT_2PI * pos / len);
                        double sin_theta = std::sin(HINT_2PI * pos / len);
                        table[vec_size + pos] = Complex(cos_theta, sin_theta);
                        table[vec_size * 2 - pos] = Complex(sin_theta, cos_theta);
                    }
                    // The root at exactly pi/4.
                    table[vec_size + vec_size / 2] = unit_root(len, len / 8);
                }
                cur_log_size = std::max(cur_log_size, shift);
            }
            // Returns the point on the unit circle with argument theta (radians).
            static Complex unit_root(double theta)
            {
                return std::polar<double>(1.0, theta);
            }
            // Returns the n-th of m equally spaced points on the unit circle: e^{i*2*pi*n/m}.
            static Complex unit_root(size_t m, size_t n)
            {
                return unit_root((2.0 * PI * n) / m);
            }
            // Returns conj(e^{i*2*pi*n/2^shift}) == e^{-i*2*pi*n/2^shift}, where the circle
            // is divided into 1 << shift parts and n selects the root. The matching
            // first-quadrant entry is read from the table and reflected into n's quadrant.
            // Requires shift <= cur_log_size and n < (1 << shift): n is not reduced modulo
            // the circle, so a larger n indexes ONES_CONJ or the table out of range.
            Complex get_complex_conj(UINT_32 shift, size_t n) const
            {
                size_t rank = 1ull << shift;
                const size_t rank_ff = rank - 1, quad_n = n << 2;
                // n &= rank_ff;   (disabled: callers must pass n < rank)
                size_t zone = quad_n >> shift; // quadrant index 0..3
                // Exact multiples of a quarter turn: 1, -i, -1, i.
                if ((quad_n & rank_ff) == 0)
                {
                    static constexpr Complex ONES_CONJ[4] = {Complex(1, 0), Complex(0, -1), Complex(-1, 0), Complex(0, 1)};
                    return ONES_CONJ[zone];
                }
                Complex tmp;
                if ((zone & 2) == 0)
                {
                    if ((zone & 1) == 0)
                    {
                        // Quadrant 0: conj of the stored root.
                        tmp = table[rank / 4 + n];
                        tmp.imag(-tmp.imag());
                    }
                    else
                    {
                        // Quadrant 1: -(root at angle pi - theta).
                        tmp = -table[rank - rank / 4 - n];
                    }
                }
                else
                {
                    if ((zone & 1) == 0)
                    {
                        // Quadrant 2: root at theta - pi with its real part negated.
                        tmp = table[n - (rank / 4)];
                        tmp.real(-tmp.real());
                    }
                    else
                    {
                        // Quadrant 3: root at 2*pi - theta, which is already the conjugate.
                        tmp = table[rank + rank / 4 - n];
                    }
                }
                return tmp;
            }
        };
        // log2 of the largest transform the twiddle table supports (2^19 points);
        // ComplexTable::expend throws for anything larger.
        constexpr size_t lut_max_rank = 19;
        // Shared twiddle table. Its constructor fills every level up to lut_max_rank
        // during static initialization.
        static ComplexTable TABLE{lut_max_rank};
        // 2-point butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
        inline void fft_2point(Complex& sum, Complex& diff)
        {
            const Complex first = sum;
            const Complex second = diff;
            sum = first + second;
            diff = first - second;
        }
        // 4-point forward FFT (e^{-i} convention) on input[0], input[rank], input[2*rank],
        // input[3*rank]: natural-order input, natural-order output.
        inline void fft_4point(Complex* input, size_t rank = 1)
        {
            const Complex x0 = input[0];
            const Complex x1 = input[rank];
            const Complex x2 = input[rank * 2];
            const Complex x3 = input[rank * 3];

            const Complex even_sum = x0 + x2;
            const Complex even_diff = x0 - x2;
            const Complex odd_sum = x1 + x3;
            const Complex odd_diff = x1 - x3;
            const Complex odd_diff_rot(odd_diff.imag(), -odd_diff.real()); // odd_diff * -i

            input[0] = even_sum + odd_sum;
            input[rank] = even_diff + odd_diff_rot;
            input[rank * 2] = even_sum - odd_sum;
            input[rank * 3] = even_diff - odd_diff_rot;
        }
        // 4-point decimation-in-time FFT on input[0], input[rank], input[2*rank],
        // input[3*rank]. Expects bit-reversed input (x0, x2, x1, x3) and writes the
        // spectrum in natural order.
        inline void fft_dit_4point(Complex* input, size_t rank = 1)
        {
            const Complex a = input[0];
            const Complex b = input[rank];
            const Complex c = input[rank * 2];
            const Complex d = input[rank * 3];

            const Complex ab_sum = a + b;
            const Complex ab_diff = a - b;
            const Complex cd_sum = c + d;
            const Complex cd_diff = c - d;
            const Complex cd_diff_rot(cd_diff.imag(), -cd_diff.real()); // cd_diff * -i

            input[0] = ab_sum + cd_sum;
            input[rank] = ab_diff + cd_diff_rot;
            input[rank * 2] = ab_sum - cd_sum;
            input[rank * 3] = ab_diff - cd_diff_rot;
        }
        // 8-point forward FFT (decimation in time), in place on input[k * rank], k = 0..7.
        // Expects the samples in bit-reversed order and writes the spectrum in natural
        // order. Twiddles follow the e^{-i*2*pi*k/8} convention.
        inline void fft_dit_8point(Complex* input, size_t rank = 1)
        {
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 2];
            Complex tmp3 = input[rank * 3];
            Complex tmp4 = input[rank * 4];
            Complex tmp5 = input[rank * 5];
            Complex tmp6 = input[rank * 6];
            Complex tmp7 = input[rank * 7];
            // Stage 1: 2-point butterflies on adjacent pairs.
            fft_2point(tmp0, tmp1);
            fft_2point(tmp2, tmp3);
            fft_2point(tmp4, tmp5);
            fft_2point(tmp6, tmp7);
            // Twiddle -i for the second element of each 4-point group.
            tmp3 = Complex(tmp3.imag(), -tmp3.real());
            tmp7 = Complex(tmp7.imag(), -tmp7.real());

            // Stage 2: completes two 4-point transforms (elements 0..3 and 4..7).
            fft_2point(tmp0, tmp2);
            fft_2point(tmp1, tmp3);
            fft_2point(tmp4, tmp6);
            fft_2point(tmp5, tmp7);
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            // Final-stage twiddles: tmp5 *= e^{-i*pi/4}, tmp6 *= -i, tmp7 *= e^{-i*3*pi/4}.
            tmp5 = cos_1_8 * Complex(tmp5.imag() + tmp5.real(), tmp5.imag() - tmp5.real());
            tmp6 = Complex(tmp6.imag(), -tmp6.real());
            tmp7 = -cos_1_8 * Complex(tmp7.real() - tmp7.imag(), tmp7.real() + tmp7.imag());

            // Stage 3: combine the two halves into X0..X7.
            input[0] = tmp0 + tmp4;
            input[rank] = tmp1 + tmp5;
            input[rank * 2] = tmp2 + tmp6;
            input[rank * 3] = tmp3 + tmp7;
            input[rank * 4] = tmp0 - tmp4;
            input[rank * 5] = tmp1 - tmp5;
            input[rank * 6] = tmp2 - tmp6;
            input[rank * 7] = tmp3 - tmp7;
        }
        // 16-point forward FFT (decimation in time), in place on input[k * rank], k = 0..15.
        // Bit-reversed input, natural-order output. Runs two 8-point DIT FFTs on the halves,
        // then a radix-2 combine in which element 8 + k is multiplied by w16^k = e^{-i*pi*k/8}.
        // w16^0, w16^2, w16^4 and w16^6 are applied as cheap rotations; the odd powers
        // use the precomputed w1, w3, w5, w7.
        inline void fft_dit_16point(Complex* input, size_t rank = 1)
        {
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            static constexpr double cos_1_16 = 0.92387953251128675612818318939679;
            static constexpr double sin_1_16 = 0.3826834323650897717284599840304;
            static constexpr Complex w1(cos_1_16, -sin_1_16), w3(sin_1_16, -cos_1_16);
            static constexpr Complex w5(-sin_1_16, -cos_1_16), w7(-cos_1_16, -sin_1_16);

            fft_dit_8point(input, rank);
            fft_dit_8point(input + rank * 8, rank);

            // k = 0, 1: no twiddle / w1.
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 8];
            Complex tmp3 = input[rank * 9] * w1;
            input[0] = tmp0 + tmp2;
            input[rank] = tmp1 + tmp3;
            input[rank * 8] = tmp0 - tmp2;
            input[rank * 9] = tmp1 - tmp3;

            // k = 2, 3: e^{-i*pi/4} rotation / w3.
            tmp0 = input[rank * 2];
            tmp1 = input[rank * 3];
            tmp2 = input[rank * 10];
            tmp3 = input[rank * 11] * w3;
            tmp2 = cos_1_8 * Complex(tmp2.imag() + tmp2.real(), tmp2.imag() - tmp2.real());
            input[rank * 2] = tmp0 + tmp2;
            input[rank * 3] = tmp1 + tmp3;
            input[rank * 10] = tmp0 - tmp2;
            input[rank * 11] = tmp1 - tmp3;

            // k = 4, 5: multiply by -i / w5.
            tmp0 = input[rank * 4];
            tmp1 = input[rank * 5];
            tmp2 = input[rank * 12];
            tmp3 = input[rank * 13] * w5;
            tmp2 = Complex(tmp2.imag(), -tmp2.real());
            input[rank * 4] = tmp0 + tmp2;
            input[rank * 5] = tmp1 + tmp3;
            input[rank * 12] = tmp0 - tmp2;
            input[rank * 13] = tmp1 - tmp3;

            // k = 6, 7: e^{-i*3*pi/4} rotation / w7.
            tmp0 = input[rank * 6];
            tmp1 = input[rank * 7];
            tmp2 = input[rank * 14];
            tmp3 = input[rank * 15] * w7;
            tmp2 = -cos_1_8 * Complex(tmp2.real() - tmp2.imag(), tmp2.real() + tmp2.imag());
            input[rank * 6] = tmp0 + tmp2;
            input[rank * 7] = tmp1 + tmp3;
            input[rank * 14] = tmp0 - tmp2;
            input[rank * 15] = tmp1 - tmp3;
        }
        // 32-point forward FFT (decimation in time), in place on input[k * rank], k = 0..31.
        // Bit-reversed input, natural-order output. Runs two 16-point DIT FFTs on the halves,
        // then a radix-2 combine in which element 16 + k is multiplied by
        // wk = w32^k = e^{-i*pi*k/16}. w32^0, w32^4, w32^8 and w32^12 are applied as cheap
        // rotations; the other powers are precomputed constants. The constant names encode
        // fractions of the full circle, e.g. cos_3_32 == cos(2*pi*3/32).
        inline void fft_dit_32point(Complex* input, size_t rank = 1)
        {
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            static constexpr double cos_1_16 = 0.92387953251128675612818318939679;
            static constexpr double sin_1_16 = 0.3826834323650897717284599840304;
            static constexpr double cos_1_32 = 0.98078528040323044912618223613424;
            static constexpr double sin_1_32 = 0.19509032201612826784828486847702;
            static constexpr double cos_3_32 = 0.83146961230254523707878837761791;
            static constexpr double sin_3_32 = 0.55557023301960222474283081394853;
            static constexpr Complex w1(cos_1_32, -sin_1_32), w2(cos_1_16, -sin_1_16), w3(cos_3_32, -sin_3_32);
            static constexpr Complex w5(sin_3_32, -cos_3_32), w6(sin_1_16, -cos_1_16), w7(sin_1_32, -cos_1_32);
            static constexpr Complex w9(-sin_1_32, -cos_1_32), w10(-sin_1_16, -cos_1_16), w11(-sin_3_32, -cos_3_32);
            static constexpr Complex w13(-cos_3_32, -sin_3_32), w14(-cos_1_16, -sin_1_16), w15(-cos_1_32, -sin_1_32);

            fft_dit_16point(input, rank);
            fft_dit_16point(input + rank * 16, rank);

            // k = 0, 1: no twiddle / w1.
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 16];
            Complex tmp3 = input[rank * 17] * w1;
            input[0] = tmp0 + tmp2;
            input[rank] = tmp1 + tmp3;
            input[rank * 16] = tmp0 - tmp2;
            input[rank * 17] = tmp1 - tmp3;

            // k = 2, 3: w2 / w3.
            tmp0 = input[rank * 2];
            tmp1 = input[rank * 3];
            tmp2 = input[rank * 18] * w2;
            tmp3 = input[rank * 19] * w3;
            input[rank * 2] = tmp0 + tmp2;
            input[rank * 3] = tmp1 + tmp3;
            input[rank * 18] = tmp0 - tmp2;
            input[rank * 19] = tmp1 - tmp3;

            // k = 4, 5: e^{-i*pi/4} rotation / w5.
            tmp0 = input[rank * 4];
            tmp1 = input[rank * 5];
            tmp2 = input[rank * 20];
            tmp3 = input[rank * 21] * w5;
            tmp2 = cos_1_8 * Complex(tmp2.imag() + tmp2.real(), tmp2.imag() - tmp2.real());
            input[rank * 4] = tmp0 + tmp2;
            input[rank * 5] = tmp1 + tmp3;
            input[rank * 20] = tmp0 - tmp2;
            input[rank * 21] = tmp1 - tmp3;

            // k = 6, 7: w6 / w7.
            tmp0 = input[rank * 6];
            tmp1 = input[rank * 7];
            tmp2 = input[rank * 22] * w6;
            tmp3 = input[rank * 23] * w7;
            input[rank * 6] = tmp0 + tmp2;
            input[rank * 7] = tmp1 + tmp3;
            input[rank * 22] = tmp0 - tmp2;
            input[rank * 23] = tmp1 - tmp3;

            // k = 8, 9: multiply by -i / w9.
            tmp0 = input[rank * 8];
            tmp1 = input[rank * 9];
            tmp2 = input[rank * 24];
            tmp3 = input[rank * 25] * w9;
            tmp2 = Complex(tmp2.imag(), -tmp2.real());
            input[rank * 8] = tmp0 + tmp2;
            input[rank * 9] = tmp1 + tmp3;
            input[rank * 24] = tmp0 - tmp2;
            input[rank * 25] = tmp1 - tmp3;

            // k = 10, 11: w10 / w11.
            tmp0 = input[rank * 10];
            tmp1 = input[rank * 11];
            tmp2 = input[rank * 26] * w10;
            tmp3 = input[rank * 27] * w11;
            input[rank * 10] = tmp0 + tmp2;
            input[rank * 11] = tmp1 + tmp3;
            input[rank * 26] = tmp0 - tmp2;
            input[rank * 27] = tmp1 - tmp3;

            // k = 12, 13: e^{-i*3*pi/4} rotation / w13.
            tmp0 = input[rank * 12];
            tmp1 = input[rank * 13];
            tmp2 = input[rank * 28];
            tmp3 = input[rank * 29] * w13;
            tmp2 = -cos_1_8 * Complex(tmp2.real() - tmp2.imag(), tmp2.real() + tmp2.imag());
            input[rank * 12] = tmp0 + tmp2;
            input[rank * 13] = tmp1 + tmp3;
            input[rank * 28] = tmp0 - tmp2;
            input[rank * 29] = tmp1 - tmp3;

            // k = 14, 15: w14 / w15.
            tmp0 = input[rank * 14];
            tmp1 = input[rank * 15];
            tmp2 = input[rank * 30] * w14;
            tmp3 = input[rank * 31] * w15;

            input[rank * 14] = tmp0 + tmp2;
            input[rank * 15] = tmp1 + tmp3;
            input[rank * 30] = tmp0 - tmp2;
            input[rank * 31] = tmp1 - tmp3;
        }
        // 4-point decimation-in-frequency FFT on input[0], input[rank], input[2*rank],
        // input[3*rank]. Natural-order input; writes the spectrum in bit-reversed order
        // (X0, X2, X1, X3).
        inline void fft_dif_4point(Complex* input, size_t rank = 1)
        {
            const Complex x0 = input[0];
            const Complex x1 = input[rank];
            const Complex x2 = input[rank * 2];
            const Complex x3 = input[rank * 3];

            const Complex near_sum = x0 + x2;
            const Complex near_diff = x0 - x2;
            const Complex far_sum = x1 + x3;
            const Complex far_diff = x1 - x3;
            const Complex far_diff_rot(far_diff.imag(), -far_diff.real()); // far_diff * -i

            input[0] = near_sum + far_sum;
            input[rank] = near_sum - far_sum;
            input[rank * 2] = near_diff + far_diff_rot;
            input[rank * 3] = near_diff - far_diff_rot;
        }
        // 8-point forward FFT (decimation in frequency), in place on input[k * rank],
        // k = 0..7. Natural-order input; writes the spectrum in bit-reversed order.
        // Twiddles follow the e^{-i*2*pi*k/8} convention.
        inline void fft_dif_8point(Complex* input, size_t rank = 1)
        {
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 2];
            Complex tmp3 = input[rank * 3];
            Complex tmp4 = input[rank * 4];
            Complex tmp5 = input[rank * 5];
            Complex tmp6 = input[rank * 6];
            Complex tmp7 = input[rank * 7];

            // Stage 1: butterflies between elements k and k + 4.
            fft_2point(tmp0, tmp4);
            fft_2point(tmp1, tmp5);
            fft_2point(tmp2, tmp6);
            fft_2point(tmp3, tmp7);
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            // Twiddles on the difference half: tmp5 *= e^{-i*pi/4}, tmp6 *= -i, tmp7 *= e^{-i*3*pi/4}.
            tmp5 = cos_1_8 * Complex(tmp5.imag() + tmp5.real(), tmp5.imag() - tmp5.real());
            tmp6 = Complex(tmp6.imag(), -tmp6.real());
            tmp7 = -cos_1_8 * Complex(tmp7.real() - tmp7.imag(), tmp7.real() + tmp7.imag());

            // Stage 2: butterflies between k and k + 2 within each half, then -i twiddles.
            fft_2point(tmp0, tmp2);
            fft_2point(tmp1, tmp3);
            fft_2point(tmp4, tmp6);
            fft_2point(tmp5, tmp7);
            tmp3 = Complex(tmp3.imag(), -tmp3.real());
            tmp7 = Complex(tmp7.imag(), -tmp7.real());

            // Stage 3: adjacent butterflies, stored in bit-reversed order.
            input[0] = tmp0 + tmp1;
            input[rank] = tmp0 - tmp1;
            input[rank * 2] = tmp2 + tmp3;
            input[rank * 3] = tmp2 - tmp3;
            input[rank * 4] = tmp4 + tmp5;
            input[rank * 5] = tmp4 - tmp5;
            input[rank * 6] = tmp6 + tmp7;
            input[rank * 7] = tmp6 - tmp7;
        }
        // 16-point forward FFT (decimation in frequency), in place on input[k * rank],
        // k = 0..15. Natural-order input, bit-reversed output. A radix-2 split first forms
        // input[k] = a + b and input[8 + k] = (a - b) * w16^k with w16^k = e^{-i*pi*k/8};
        // then each half is finished with an 8-point DIF FFT.
        inline void fft_dif_16point(Complex* input, size_t rank = 1)
        {
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            static constexpr double cos_1_16 = 0.92387953251128675612818318939679;
            static constexpr double sin_1_16 = 0.3826834323650897717284599840304;
            static constexpr Complex w1(cos_1_16, -sin_1_16), w3(sin_1_16, -cos_1_16);
            static constexpr Complex w5(-sin_1_16, -cos_1_16), w7(-cos_1_16, -sin_1_16);

            // k = 0, 1: no twiddle / w1.
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 8];
            Complex tmp3 = input[rank * 9];
            input[0] = tmp0 + tmp2;
            input[rank] = tmp1 + tmp3;
            input[rank * 8] = tmp0 - tmp2;
            input[rank * 9] = (tmp1 - tmp3) * w1;

            // k = 2, 3: e^{-i*pi/4} rotation / w3.
            tmp0 = input[rank * 2];
            tmp1 = input[rank * 3];
            tmp2 = input[rank * 10];
            tmp3 = input[rank * 11];
            fft_2point(tmp0, tmp2);
            tmp2 = cos_1_8 * Complex(tmp2.imag() + tmp2.real(), tmp2.imag() - tmp2.real());
            input[rank * 2] = tmp0;
            input[rank * 3] = tmp1 + tmp3;
            input[rank * 10] = tmp2;
            input[rank * 11] = (tmp1 - tmp3) * w3;

            // k = 4, 5: multiply by -i / w5.
            tmp0 = input[rank * 4];
            tmp1 = input[rank * 5];
            tmp2 = input[rank * 12];
            tmp3 = input[rank * 13];
            fft_2point(tmp0, tmp2);
            tmp2 = Complex(tmp2.imag(), -tmp2.real());
            input[rank * 4] = tmp0;
            input[rank * 5] = tmp1 + tmp3;
            input[rank * 12] = tmp2;
            input[rank * 13] = (tmp1 - tmp3) * w5;

            // k = 6, 7: e^{-i*3*pi/4} rotation / w7.
            tmp0 = input[rank * 6];
            tmp1 = input[rank * 7];
            tmp2 = input[rank * 14];
            tmp3 = input[rank * 15];
            fft_2point(tmp0, tmp2);
            tmp2 = -cos_1_8 * Complex(tmp2.real() - tmp2.imag(), tmp2.real() + tmp2.imag());
            input[rank * 6] = tmp0;
            input[rank * 7] = tmp1 + tmp3;
            input[rank * 14] = tmp2;
            input[rank * 15] = (tmp1 - tmp3) * w7;

            fft_dif_8point(input, rank);
            fft_dif_8point(input + rank * 8, rank);
        }
        // 32-point forward FFT (decimation in frequency), in place on input[k * rank],
        // k = 0..31. Natural-order input, bit-reversed output. A radix-2 split first forms
        // input[k] = a + b and input[16 + k] = (a - b) * wk with wk = w32^k = e^{-i*pi*k/16};
        // then each half is finished with a 16-point DIF FFT. The constant names encode
        // fractions of the full circle, e.g. cos_3_32 == cos(2*pi*3/32).
        inline void fft_dif_32point(Complex* input, size_t rank = 1)
        {
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            static constexpr double cos_1_16 = 0.92387953251128675612818318939679;
            static constexpr double sin_1_16 = 0.3826834323650897717284599840304;
            static constexpr double cos_1_32 = 0.98078528040323044912618223613424;
            static constexpr double sin_1_32 = 0.19509032201612826784828486847702;
            static constexpr double cos_3_32 = 0.83146961230254523707878837761791;
            static constexpr double sin_3_32 = 0.55557023301960222474283081394853;
            static constexpr Complex w1(cos_1_32, -sin_1_32), w2(cos_1_16, -sin_1_16), w3(cos_3_32, -sin_3_32);
            static constexpr Complex w5(sin_3_32, -cos_3_32), w6(sin_1_16, -cos_1_16), w7(sin_1_32, -cos_1_32);
            static constexpr Complex w9(-sin_1_32, -cos_1_32), w10(-sin_1_16, -cos_1_16), w11(-sin_3_32, -cos_3_32);
            static constexpr Complex w13(-cos_3_32, -sin_3_32), w14(-cos_1_16, -sin_1_16), w15(-cos_1_32, -sin_1_32);

            // k = 0, 1: no twiddle / w1.
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 16];
            Complex tmp3 = input[rank * 17];
            input[0] = tmp0 + tmp2;
            input[rank] = tmp1 + tmp3;
            input[rank * 16] = tmp0 - tmp2;
            input[rank * 17] = (tmp1 - tmp3) * w1;

            // k = 2, 3: w2 / w3.
            tmp0 = input[rank * 2];
            tmp1 = input[rank * 3];
            tmp2 = input[rank * 18];
            tmp3 = input[rank * 19];
            input[rank * 2] = tmp0 + tmp2;
            input[rank * 3] = tmp1 + tmp3;
            input[rank * 18] = (tmp0 - tmp2) * w2;
            input[rank * 19] = (tmp1 - tmp3) * w3;

            // k = 4, 5: e^{-i*pi/4} rotation / w5.
            tmp0 = input[rank * 4];
            tmp1 = input[rank * 5];
            tmp2 = input[rank * 20];
            tmp3 = input[rank * 21];
            fft_2point(tmp0, tmp2);
            tmp2 = cos_1_8 * Complex(tmp2.imag() + tmp2.real(), tmp2.imag() - tmp2.real());
            input[rank * 4] = tmp0;
            input[rank * 5] = tmp1 + tmp3;
            input[rank * 20] = tmp2;
            input[rank * 21] = (tmp1 - tmp3) * w5;

            // k = 6, 7: w6 / w7.
            tmp0 = input[rank * 6];
            tmp1 = input[rank * 7];
            tmp2 = input[rank * 22];
            tmp3 = input[rank * 23];
            input[rank * 6] = tmp0 + tmp2;
            input[rank * 7] = tmp1 + tmp3;
            input[rank * 22] = (tmp0 - tmp2) * w6;
            input[rank * 23] = (tmp1 - tmp3) * w7;

            // k = 8, 9: multiply by -i / w9.
            tmp0 = input[rank * 8];
            tmp1 = input[rank * 9];
            tmp2 = input[rank * 24];
            tmp3 = input[rank * 25];
            fft_2point(tmp0, tmp2);
            tmp2 = Complex(tmp2.imag(), -tmp2.real());
            input[rank * 8] = tmp0;
            input[rank * 9] = tmp1 + tmp3;
            input[rank * 24] = tmp2;
            input[rank * 25] = (tmp1 - tmp3) * w9;

            // k = 10, 11: w10 / w11.
            tmp0 = input[rank * 10];
            tmp1 = input[rank * 11];
            tmp2 = input[rank * 26];
            tmp3 = input[rank * 27];
            input[rank * 10] = tmp0 + tmp2;
            input[rank * 11] = tmp1 + tmp3;
            input[rank * 26] = (tmp0 - tmp2) * w10;
            input[rank * 27] = (tmp1 - tmp3) * w11;

            // k = 12, 13: e^{-i*3*pi/4} rotation / w13.
            tmp0 = input[rank * 12];
            tmp1 = input[rank * 13];
            tmp2 = input[rank * 28];
            tmp3 = input[rank * 29];
            fft_2point(tmp0, tmp2);
            tmp2 = -cos_1_8 * Complex(tmp2.real() - tmp2.imag(), tmp2.real() + tmp2.imag());
            input[rank * 12] = tmp0;
            input[rank * 13] = tmp1 + tmp3;
            input[rank * 28] = tmp2;
            input[rank * 29] = (tmp1 - tmp3) * w13;

            // k = 14, 15: w14 / w15.
            tmp0 = input[rank * 14];
            tmp1 = input[rank * 15];
            tmp2 = input[rank * 30];
            tmp3 = input[rank * 31];

            input[rank * 14] = tmp0 + tmp2;
            input[rank * 15] = tmp1 + tmp3;
            input[rank * 30] = (tmp0 - tmp2) * w14;
            input[rank * 31] = (tmp1 - tmp3) * w15;

            fft_dif_16point(input, rank);
            fft_dif_16point(input + rank * 16, rank);
        }
    
        // Split-radix decimation-in-time butterfly on input[0], input[rank], input[2*rank]
        // and input[3*rank]. The last two values are multiplied by omega and omega_cube,
        // then their sum and (-i * difference) are combined with the first two.
        inline void fft_split_radix_dit_butterfly(Complex omega, Complex omega_cube,
                                                  Complex *input, size_t rank)
        {
            Complex *const slot0 = input;
            Complex *const slot1 = input + rank;
            Complex *const slot2 = input + rank * 2;
            Complex *const slot3 = input + rank * 3;

            const Complex half_a = *slot0;
            const Complex half_b = *slot1;
            const Complex quarter_a = *slot2 * omega;
            const Complex quarter_b = *slot3 * omega_cube;

            const Complex quarter_sum = quarter_a + quarter_b;
            const Complex quarter_diff = quarter_a - quarter_b;
            const Complex rotated(quarter_diff.imag(), -quarter_diff.real()); // quarter_diff * -i

            *slot0 = half_a + quarter_sum;
            *slot1 = half_b + rotated;
            *slot2 = half_a - quarter_sum;
            *slot3 = half_b - rotated;
        }
        // Split-radix decimation-in-frequency butterfly on input[0], input[rank],
        // input[2*rank] and input[3*rank]. The first two slots receive the pairwise sums.
        // The last two receive (d02 - i*d13) * omega and (d02 + i*d13) * omega_cube,
        // where d02 = x0 - x2 and d13 = x1 - x3.
        inline void fft_split_radix_dif_butterfly(Complex omega, Complex omega_cube,
                                                  Complex *input, size_t rank)
        {
            const Complex x0 = input[0];
            const Complex x1 = input[rank];
            const Complex x2 = input[rank * 2];
            const Complex x3 = input[rank * 3];

            const Complex s02 = x0 + x2;
            const Complex d02 = x0 - x2;
            const Complex s13 = x1 + x3;
            const Complex d13 = x1 - x3;
            const Complex i_d13(-d13.imag(), d13.real()); // d13 * i

            input[0] = s02;
            input[rank] = s13;
            input[rank * 2] = (d02 - i_d13) * omega;
            input[rank * 3] = (d02 + i_d13) * omega_cube;
        }

        // Compile-time-length split-radix DIT FFT. The first half is transformed as a
        // length LEN/2 FFT and the last two quarters as length LEN/4 FFTs. The results
        // are then merged with twiddles conj(w^k) and conj(w^{3k}) from TABLE, w being
        // the LEN-th root of unity. Input is expected in bit-reversed order.
        template <size_t LEN>
        void fft_split_radix_dit_template(Complex *input)
        {
            constexpr size_t half_len = LEN / 2;
            constexpr size_t quarter_len = LEN / 4;
            constexpr size_t log_len = hint_log2(LEN);

            Complex *const third_quarter = input + half_len;
            Complex *const fourth_quarter = third_quarter + quarter_len;
            fft_split_radix_dit_template<half_len>(input);
            fft_split_radix_dit_template<quarter_len>(third_quarter);
            fft_split_radix_dit_template<quarter_len>(fourth_quarter);

            for (size_t k = 0; k != quarter_len; ++k)
            {
                const Complex twiddle = TABLE.get_complex_conj(log_len, k);
                const Complex twiddle_cube = TABLE.get_complex_conj(log_len, 3 * k);
                fft_split_radix_dit_butterfly(twiddle, twiddle_cube, input + k, quarter_len);
            }
        }
        // Base cases of the DIT recursion. Lengths 0 and 1 need no work, length 2 is a
        // single butterfly, and lengths 4..32 dispatch to the hand-unrolled kernels.
        // NOTE(review): explicit specializations of a function template are not implicitly
        // inline; that is fine here only because this is a single translation unit.
        template <>
        void fft_split_radix_dit_template<0>(Complex *) {}
        template <>
        void fft_split_radix_dit_template<1>(Complex *) {}
        template <>
        void fft_split_radix_dit_template<2>(Complex *input)
        {
            Complex &lo = input[0];
            Complex &hi = input[1];
            fft_2point(lo, hi);
        }
        template <>
        void fft_split_radix_dit_template<4>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dit_4point(input, unit_stride);
        }
        template <>
        void fft_split_radix_dit_template<8>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dit_8point(input, unit_stride);
        }
        template <>
        void fft_split_radix_dit_template<16>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dit_16point(input, unit_stride);
        }
        template <>
        void fft_split_radix_dit_template<32>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dit_32point(input, unit_stride);
        }

        // Compile-time-length split-radix DIF FFT, the mirror image of the DIT version.
        // Split-radix butterflies with twiddles conj(w^k) and conj(w^{3k}) from TABLE run
        // first. Then the first half is transformed as a length LEN/2 FFT and the last two
        // quarters as length LEN/4 FFTs. Natural-order input, bit-reversed output.
        template <size_t LEN>
        void fft_split_radix_dif_template(Complex *input)
        {
            constexpr size_t half_len = LEN / 2;
            constexpr size_t quarter_len = LEN / 4;
            constexpr size_t log_len = hint_log2(LEN);

            for (size_t k = 0; k != quarter_len; ++k)
            {
                const Complex twiddle = TABLE.get_complex_conj(log_len, k);
                const Complex twiddle_cube = TABLE.get_complex_conj(log_len, 3 * k);
                fft_split_radix_dif_butterfly(twiddle, twiddle_cube, input + k, quarter_len);
            }

            Complex *const third_quarter = input + half_len;
            Complex *const fourth_quarter = third_quarter + quarter_len;
            fft_split_radix_dif_template<half_len>(input);
            fft_split_radix_dif_template<quarter_len>(third_quarter);
            fft_split_radix_dif_template<quarter_len>(fourth_quarter);
        }
        // Base cases of the DIF recursion. Lengths 0 and 1 need no work, length 2 is a
        // single butterfly, and lengths 4..32 dispatch to the hand-unrolled kernels.
        // NOTE(review): explicit specializations of a function template are not implicitly
        // inline; that is fine here only because this is a single translation unit.
        template <>
        void fft_split_radix_dif_template<0>(Complex *) {}
        template <>
        void fft_split_radix_dif_template<1>(Complex *) {}
        template <>
        void fft_split_radix_dif_template<2>(Complex *input)
        {
            Complex &lo = input[0];
            Complex &hi = input[1];
            fft_2point(lo, hi);
        }
        template <>
        void fft_split_radix_dif_template<4>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dif_4point(input, unit_stride);
        }
        template <>
        void fft_split_radix_dif_template<8>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dif_8point(input, unit_stride);
        }
        template <>
        void fft_split_radix_dif_template<16>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dif_16point(input, unit_stride);
        }
        template <>
        void fft_split_radix_dif_template<32>(Complex *input)
        {
            const size_t unit_stride = 1;
            fft_dif_32point(input, unit_stride);
        }

        // Maps a runtime length onto a compile-time one. It recurses with LEN doubling until
        // LEN >= fft_len, then (optionally) bit-reverses input[0 .. LEN) and runs the
        // split-radix DIT FFT of that length. fft_len is expected to be a power of two.
        template <size_t LEN = 1>
        void fft_dit_template(Complex *input, size_t fft_len, bool bit_inv = true)
        {
            if (fft_len <= LEN)
            {
                TABLE.expend(hint_log2(LEN));
                if (bit_inv)
                {
                    binary_inverse_swap(input, LEN);
                }
                fft_split_radix_dit_template<LEN>(input);
                return;
            }
            fft_dit_template<LEN * 2>(input, fft_len, bit_inv);
        }
        // Recursion terminator for fft_dit_template, reached only when fft_len > 2^23.
        // It used to return silently, leaving the data untransformed and the caller with a
        // wrong result. It now throws a const char*, the same exception type that
        // ComplexTable::expend uses for oversize transforms.
        template <>
        void fft_dit_template<1 << 24>(Complex *, size_t, bool)
        {
            throw("FFT length too long\n");
        }

        // Maps a runtime length onto a compile-time one. It recurses with LEN doubling until
        // LEN >= fft_len, runs the split-radix DIF FFT of that length, then (optionally)
        // bit-reverses input[0 .. LEN). fft_len is expected to be a power of two.
        template <size_t LEN = 1>
        void fft_dif_template(Complex *input, size_t fft_len, bool bit_inv = true)
        {
            if (fft_len <= LEN)
            {
                TABLE.expend(hint_log2(LEN));
                fft_split_radix_dif_template<LEN>(input);
                if (bit_inv)
                {
                    binary_inverse_swap(input, LEN);
                }
                return;
            }
            fft_dif_template<LEN * 2>(input, fft_len, bit_inv);
        }
        // Recursion terminator for fft_dif_template, reached only when fft_len > 2^23.
        // It used to return silently, leaving the data untransformed; it now throws a
        // const char*, like ComplexTable::expend. The third parameter is the bit-reversal
        // flag (previously misnamed `is_ifft`); it is unused here.
        template <>
        void fft_dif_template<1 << 24>(Complex *, size_t, bool)
        {
            throw("FFT length too long\n");
        }

        /// @brief Forward FFT by decimation in time (split-radix), in place.
        /// @param input complex array to transform
        /// @param fft_len array length. It is rounded DOWN to a power of two (max_2pow), so
        ///        only the first max_2pow(fft_len) elements are transformed.
        /// @param bit_inv whether to bit-reverse the input first. Pass false when the data
        ///        is already in bit-reversed order, e.g. straight from fft_dif(..., false).
        inline void fft_dit(Complex *input, size_t fft_len, bool bit_inv = true)
        {
            const size_t pow2_len = max_2pow(fft_len);
            fft_dit_template<1>(input, pow2_len, bit_inv);
        }

        /// @brief Forward FFT by decimation in frequency (split-radix), in place.
        /// @param input complex array to transform
        /// @param fft_len array length. It is rounded DOWN to a power of two (max_2pow), so
        ///        only the first max_2pow(fft_len) elements are transformed.
        /// @param bit_inv whether to bit-reverse the output into natural order. With false,
        ///        the spectrum is left in bit-reversed order.
        inline void fft_dif(Complex *input, size_t fft_len, bool bit_inv = true)
        {
            const size_t pow2_len = max_2pow(fft_len);
            fft_dif_template<1>(input, pow2_len, bit_inv);
        }
    }
    // Formats input as exactly `digits` decimal characters: zero-padded on the left,
    // with any digits above that width dropped.
    std::string ui64to_string(UINT_64 input, UINT_8 digits)
    {
        std::string text(digits, '0');
        for (auto it = text.rbegin(); it != text.rend(); ++it)
        {
            *it = static_cast<char>('0' + input % 10);
            input /= 10;
        }
        return text;
    }
    // Parses the first `dig` characters of s (default 4) as an unsigned decimal number.
    // No validation: the characters are assumed to be '0'..'9'.
    // Takes const char* because s is only read; existing char* callers still convert
    // implicitly, and string literals or const buffers are now accepted too.
    constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
    {
        UINT_64 result = 0;
        for (size_t i = 0; i < dig; i++)
        {
            result = result * 10 + static_cast<UINT_64>(s[i] - '0');
        }
        return result;
    }
    // Number of decimal digits packed into each FFT coefficient (limb).
    static constexpr INT_64 DIGIT = 4;
    // Splits the decimal string buffer[0, str_len) into base-10^DIGIT limbs, least
    // significant limb first, and stores them in the real parts of comary[0 .. result).
    // The most significant limb takes the 1..DIGIT leftover leading digits. Imaginary
    // parts are left untouched. Returns ceil(str_len / DIGIT), which is 0 for an empty
    // string. Assumes buffer[0, str_len) holds only '0'..'9' (not validated).
    size_t char_to_real(char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        // Full-width limbs, taken from the end of the string backwards. The parse width
        // is DIGIT (not a hard-coded 4) so it cannot drift from the chunk size.
        while (pos - DIGIT > 0)
        {
            hint::UINT_64 tmp = hint::stoui64(buffer + pos - DIGIT, DIGIT);
            comary[i].real(tmp);
            i++;
            pos -= DIGIT;
        }
        // Remaining leading digits form the top limb.
        if (pos > 0)
        {
            hint::UINT_64 tmp = hint::stoui64(buffer, pos);
            comary[i].real(tmp);
        }
        return len;
    }
    // Splits the decimal string buffer[0, str_len) into base-10^DIGIT limbs, least
    // significant limb first, and stores them in the imaginary parts of
    // comary[0 .. result). The most significant limb takes the 1..DIGIT leftover leading
    // digits. Real parts are left untouched. Returns ceil(str_len / DIGIT), which is 0
    // for an empty string. Assumes buffer[0, str_len) holds only '0'..'9' (not validated).
    size_t char_to_imag(char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        // Full-width limbs, taken from the end of the string backwards. The parse width
        // is DIGIT (not a hard-coded 4) so it cannot drift from the chunk size.
        while (pos - DIGIT > 0)
        {
            hint::UINT_64 tmp = hint::stoui64(buffer + pos - DIGIT, DIGIT);
            comary[i].imag(tmp);
            i++;
            pos -= DIGIT;
        }
        // Remaining leading digits form the top limb.
        if (pos > 0)
        {
            hint::UINT_64 tmp = hint::stoui64(buffer, pos);
            comary[i].imag(tmp);
        }
        return len;
    }
    // Writes the lowest four decimal digits of num into s[0..3], zero-padded and most
    // significant first. Higher digits are dropped and no terminator is written.
    void num_to_s(char *s, UINT_64 num)
    {
        for (int idx = 3; idx >= 0; --idx)
        {
            s[idx] = static_cast<char>('0' + num % 10);
            num /= 10;
        }
    }
    // Reads one token from stdin into s and NUL-terminates it. Leading spaces, '\n' and
    // '\r' are skipped; the token ends at the next space, '\n' or '\r' (which is
    // consumed) or at end of input. Returns the token length, 0 if input ended first.
    // s must have room for the token plus the terminator; no bound is checked.
    size_t read_string(char *s)
    {
        size_t len = 0;
        // getchar() returns int so EOF stays distinct from every byte value. Storing it
        // in a char made the read loop below never see a delimiter at end of input, so it
        // spun forever and wrote past the end of s.
        int c = getchar();
        // Skip extra newlines or spaces.
        while (c == ' ' || c == '\n' || c == '\r')
        {
            c = getchar();
        }
        // Read until a delimiter or end of input.
        while (c != EOF && c != ' ' && c != '\n' && c != '\r')
        {
            s[len] = static_cast<char>(c);
            len++;
            c = getchar();
        }
        s[len] = '\0';
        return len;
    }
}

using namespace std;
using namespace hint;
using namespace hint_transform;
// char in[1 << 20];

// Shared I/O buffer. Each input number is read into it and parsed, then the decimal
// product is written backwards so that it ends just before out[1 << 21]; that last slot
// stays '\0' (static zero-initialization) and terminates the string for puts().
// The parentheses matter: `1 << 21 + 1` parses as `1 << 22`, because + binds tighter
// than <<.
char out[(1 << 21) + 1];
// Packed FFT input: limbs of the first factor in the real parts, limbs of the second
// factor in the imaginary parts.
Complex fft_ary[1 << 19];
int main()
{
    size_t len_a = 0, len_b = 0;
    // Width 2097152 == 1 << 21 keeps scanf's characters plus its '\0' inside out[].
    if (scanf("%2097152s", out) != 1)
    {
        return 1;
    }
    // isdigit requires a value representable as unsigned char.
    while (isdigit(static_cast<unsigned char>(out[len_a])))
    {
        len_a++;
    }
    if (len_a == 0)
    {
        // Not a number. Continuing could make len1 + len2 - 1 wrap and hang min_2pow.
        return 1;
    }
    if (len_a == 1 && out[0] == '0')
    {
        printf("0");
        return 0;
    }
    size_t len1 = char_to_real(out, fft_ary, len_a);
    if (scanf("%2097152s", out) != 1)
    {
        return 1;
    }
    while (isdigit(static_cast<unsigned char>(out[len_b])))
    {
        len_b++;
    }
    if (len_b == 0)
    {
        return 1;
    }
    if (len_b == 1 && out[0] == '0')
    {
        printf("0");
        return 0;
    }
    size_t len2 = char_to_imag(out, fft_ary, len_b);
    size_t fft_len = min_2pow(len1 + len2 - 1);
    if (fft_len > sizeof(fft_ary) / sizeof(fft_ary[0]))
    {
        // The product does not fit the static FFT buffer.
        return 1;
    }
    // With z = a + i*b, z^2 = a^2 - b^2 + 2i*a*b, so the imaginary part of the squared
    // spectrum carries twice the convolution of a and b. No bit reversal is needed: the
    // DIF pass leaves the spectrum bit-reversed, and the DIT pass consumes that order.
    fft_dif(fft_ary, fft_len, false);
    for (size_t i = 0; i < fft_len; i++)
    {
        Complex tmp = fft_ary[i];
        fft_ary[i] = std::conj(tmp * tmp);
    }
    // A forward transform of the conjugate equals n * conj(inverse transform). The
    // conjugation flips the sign of the imaginary part, hence the -0.5 / fft_len scale.
    fft_dit(fft_ary, fft_len, false);
    double inv = -0.5 / fft_len;
    UINT_64 carry = 0;
    size_t pos = 1 << 21;
    // Round each coefficient, propagate carries in base 10000 and emit 4 digits per limb,
    // writing from the end of `out` backwards.
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += UINT_64(fft_ary[i].imag() * inv + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, 10000);
        num_to_s(out + pos - 4, num);
        pos -= 4;
    }
    num_to_s(out + pos - 4, carry);
    pos -= 4;
    // Strip leading zeros; the product of two non-zero inputs has a non-zero digit.
    while (out[pos] == '0')
    {
        pos++;
    }
    puts(out + pos);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 1.635 ms | 4 MB + 220 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-15 16:07:54 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠