提交记录 28442


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Accepted 100 44.015 ms 16020 KB C++14 29.51 KB
提交时间 评测时间
2025-09-02 23:05:40 2025-09-02 23:05:43
#include <algorithm>
#include <bitset>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

namespace hint
{
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<Float32>;
    using Complex64 = std::complex<Float64>;

    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = HINT_PI * 2;
    // NOTE: despite the name, this literal is cos(pi/4) = sqrt(2)/2.
    constexpr Float64 COS_PI_8 = 0.707106781186547524400844;
    constexpr size_t FFT_MAX_LEN = size_t(1) << 23;

    // Largest power of two that does not exceed n (yields 1 for n == 0).
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        // Smear the highest set bit into every lower position.
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        return (n >> 1) + 1;
    }

    // Smallest power of two that is >= n.
    // For an unsigned T, n == 0 wraps around in the decrement and the result is 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        n--;
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        return n + 1;
    }

    // True when n has exactly one bit set.
    template <typename IntTy>
    constexpr bool is_2pow(IntTy n)
    {
        return n != 0 && (n & (n - 1)) == 0;
    }

    // Integer base-2 logarithm: the largest l with (1 << l) <= n, found by
    // binary search over the bit positions. Returns -1 when n == 0.
    template <typename T>
    constexpr int hint_log2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        int l = -1, r = bits;
        while ((l + 1) != r)
        {
            int mid = (l + r) / 2;
            if ((T(1) << mid) > n)
            {
                r = mid;
            }
            else
            {
                l = mid;
            }
        }
        return l;
    }

    // Count of trailing zero bits of a 32-bit value; returns 32 for x == 0.
    constexpr int hint_ctz(uint32_t x)
    {
        int r0 = 31;
        x &= (-x); // isolate the lowest set bit
        if (x & 0x55555555)
        {
            r0 &= ~1;
        }
        if (x & 0x33333333)
        {
            r0 &= ~2;
        }
        if (x & 0x0F0F0F0F)
        {
            r0 &= ~4;
        }
        if (x & 0x00FF00FF)
        {
            r0 &= ~8;
        }
        if (x & 0x0000FFFF)
        {
            r0 &= ~16;
        }
        r0 += (x == 0);
        return r0;
    }

    // Count of trailing zero bits of a 64-bit value; returns 64 for x == 0.
    constexpr int hint_ctz(uint64_t x)
    {
        int r0 = 63;
        x &= (-x); // isolate the lowest set bit
        if (x & 0x5555555555555555)
        {
            r0 &= ~1; // -1
        }
        if (x & 0x3333333333333333)
        {
            r0 &= ~2; // -2
        }
        if (x & 0x0F0F0F0F0F0F0F0F)
        {
            r0 &= ~4; // -4
        }
        if (x & 0x00FF00FF00FF00FF)
        {
            r0 &= ~8; // -8
        }
        if (x & 0x0000FFFF0000FFFF)
        {
            r0 &= ~16; // -16
        }
        if (x & 0x00000000FFFFFFFF)
        {
            r0 &= ~32; // -32
        }
        r0 += (x == 0);
        return r0;
    }

    // Reverses the order of all 32 bits of n.
    constexpr uint32_t bitrev(uint32_t n)
    {
        constexpr uint32_t mask55 = 0x55555555;
        constexpr uint32_t mask33 = 0x33333333;
        constexpr uint32_t mask0f = 0x0f0f0f0f;
        constexpr uint32_t maskff = 0x00ff00ff;
        n = ((n & mask55) << 1) | ((n >> 1) & mask55);
        n = ((n & mask33) << 2) | ((n >> 2) & mask33);
        n = ((n & mask0f) << 4) | ((n >> 4) & mask0f);
        n = ((n & maskff) << 8) | ((n >> 8) & maskff);
        return (n << 16) | (n >> 16);
    }

    // Reverses the order of all 64 bits of n.
    constexpr uint64_t bitrev(uint64_t n)
    {
        constexpr uint64_t mask5555 = 0x5555555555555555;
        constexpr uint64_t mask3333 = 0x3333333333333333;
        constexpr uint64_t mask0f0f = 0x0f0f0f0f0f0f0f0f;
        constexpr uint64_t mask00ff = 0x00ff00ff00ff00ff;
        constexpr uint64_t maskffff = 0x0000ffff0000ffff;
        n = ((n & mask5555) << 1) | ((n >> 1) & mask5555);
        n = ((n & mask3333) << 2) | ((n >> 2) & mask3333);
        n = ((n & mask0f0f) << 4) | ((n >> 4) & mask0f0f);
        n = ((n & mask00ff) << 8) | ((n >> 8) & mask00ff);
        n = ((n & maskffff) << 16) | ((n >> 16) & maskffff);
        return (n << 32) | (n >> 32);
    }

    // Reverses the lowest `bits` bits of n (higher bits are discarded).
    // A width of zero (or less) has nothing to reverse and yields 0; without
    // this guard the 32-bit path would shift by the full type width, which is
    // undefined behavior.
    constexpr size_t bitrev(size_t n, int bits)
    {
        if (bits <= 0)
        {
            return 0;
        }
        if (bits < 32)
        {
            return bitrev(static_cast<uint32_t>(n)) >> (32 - bits);
        }
        return bitrev(static_cast<uint64_t>(n)) >> (64 - bits);
    }

    // Number of set bits in a 32-bit value (GCC/Clang builtin).
    constexpr int hint_popcnt(uint32_t n)
    {
        return __builtin_popcount(n);
    }

    // Wraps a compile-time constant N of type T as a type.
    template <typename T, T N>
    struct StaticObject
    {
        using Type = T;
        static constexpr Type value = N;
    };

    template <size_t N>
    using StaticSize = StaticObject<size_t, N>;

    template <int N>
    using StaticInt = StaticObject<int, N>;

} // namespace hint
// In-place two-point butterfly: (sum, diff) becomes (sum + diff, sum - diff).
template <typename T>
inline void transform2(T &sum, T &diff)
{
    const T lhs = sum;
    const T rhs = diff;
    sum = lhs + rhs;
    diff = lhs - rhs;
}

// Two-point butterfly with separate outputs: sum = a + b, diff = a - b.
// The inputs are taken by value, so the outputs may alias the caller's inputs.
template <typename T>
inline void transform2(const T a, const T b, T &sum, T &diff)
{
    const T total = a + b;
    const T delta = a - b;
    sum = total;
    diff = delta;
}
using namespace hint;

// Stores the unit complex number with phase -2*pi*indx/rank into `omega`.
template <typename Float>
void twiddle_omega(size_t rank, size_t indx, std::complex<Float> &omega)
{
    const auto phase = -HINT_2PI * indx / rank;
    omega = std::polar<Float>(1.0, phase);
}

// Fills `table` with rank/2 twiddle factors in natural order:
// table[i] = twiddle_omega(rank, i).
template <typename Float>
void init_table(std::vector<std::complex<Float>> &table, size_t rank)
{
    table.resize(rank / 2);
    size_t k = 0;
    for (auto &w : table)
    {
        twiddle_omega(rank, k, w);
        ++k;
    }
}

// Fills `table` with rank/2 twiddle factors in bit-reversed order:
// table[i] = twiddle_omega(rank, bitrev(i, log2(rank) - 1)).
template <typename Float>
void init_table_rev(std::vector<std::complex<Float>> &table, size_t rank)
{
    table.resize(rank / 2);
    const int rev_width = hint_log2(rank) - 1;
    size_t k = 0;
    for (auto &w : table)
    {
        twiddle_omega(rank, bitrev(k, rev_width), w);
        ++k;
    }
}

// Radix-2 butterfly passes over `in_out` with the block size (rank) growing
// from 2 up to len. Each pass rebuilds a natural-order twiddle table and
// multiplies the upper half of every block by the conjugated twiddle before
// the add/subtract. When `norm` is set every element is finally scaled by 1/len.
template <typename Float>
inline void idit(std::complex<Float> in_out[], size_t len, bool norm = false)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> twiddles;
    for (size_t rank = 2; rank <= len; rank *= 2)
    {
        init_table(twiddles, rank);
        const size_t half = rank / 2;
        for (size_t base = 0; base < len; base += rank)
        {
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            for (size_t k = 0; k < half; k++)
            {
                const Complex a = lo[k];
                const Complex b = hi[k] * std::conj(twiddles[k]);
                lo[k] = a + b;
                hi[k] = a - b;
            }
        }
    }
    if (!norm)
    {
        return;
    }
    const double len_inv = 1.0 / len;
    for (size_t k = 0; k < len; k++)
    {
        in_out[k] *= len_inv;
    }
}

// Radix-2 butterfly passes over `in_out` with the block size (rank) shrinking
// from len down to 2. Each pass rebuilds a natural-order twiddle table; the
// difference output of every butterfly is multiplied by the twiddle.
template <typename Float>
inline void dif(std::complex<Float> in_out[], size_t len)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> twiddles;
    for (size_t rank = len; rank >= 2; rank /= 2)
    {
        init_table(twiddles, rank);
        const size_t half = rank / 2;
        for (size_t base = 0; base < len; base += rank)
        {
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            for (size_t k = 0; k < half; k++)
            {
                const Complex a = lo[k];
                const Complex b = hi[k];
                lo[k] = a + b;
                hi[k] = (a - b) * twiddles[k];
            }
        }
    }
}

// Butterfly passes with rank growing from 2 to len, using one bit-reversed
// twiddle table built once for `len`. Every block of a pass uses a single
// twiddle (the table entry at the block's ordinal), applied conjugated to the
// difference output. When `norm` is set every element is scaled by 1/len.
template <typename Float>
inline void idit_rtable(std::complex<Float> in_out[], size_t len, bool norm = false)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> rev_table;
    init_table_rev(rev_table, len);
    for (size_t rank = 2; rank <= len; rank *= 2)
    {
        const size_t half = rank / 2;
        size_t block = 0;
        for (size_t base = 0; base < len; base += rank, ++block)
        {
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            const Complex w = std::conj(rev_table[block]);
            for (size_t k = 0; k < half; k++)
            {
                const Complex a = lo[k];
                const Complex b = hi[k];
                lo[k] = a + b;
                hi[k] = (a - b) * w;
            }
        }
    }
    if (!norm)
    {
        return;
    }
    const double len_inv = 1.0 / len;
    for (size_t k = 0; k < len; k++)
    {
        in_out[k] *= len_inv;
    }
}

// Butterfly passes with rank shrinking from len to 2, using one bit-reversed
// twiddle table built once for `len`. Every block of a pass uses a single
// twiddle (the table entry at the block's ordinal), applied to the upper half
// before the add/subtract.
template <typename Float>
inline void dif_rtable(std::complex<Float> in_out[], size_t len)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> rev_table;
    init_table_rev(rev_table, len);
    for (size_t rank = len; rank >= 2; rank /= 2)
    {
        const size_t half = rank / 2;
        size_t block = 0;
        for (size_t base = 0; base < len; base += rank, ++block)
        {
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            const Complex w = rev_table[block];
            for (size_t k = 0; k < half; k++)
            {
                const Complex a = lo[k];
                const Complex b = hi[k] * w;
                lo[k] = a + b;
                hi[k] = a - b;
            }
        }
    }
}

// Incremental generator of unit-circle twiddle factors indexed in bit-reversed
// order. Instead of evaluating sin/cos for every index, it keeps a small stack
// (`table`) with one partial product per set bit of the current index and
// derives the next value with a single complex multiplication.
//
// reset(i) positions the generator at index i; iterate() returns the factor
// for the current index and advances to the next one.
template <typename FloatTy>
class BinRevTableComplexIterHP
{
public:
    using Complex = std::complex<FloatTy>;
    static constexpr int MAX_LOG_LEN = 32;
    static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;

    // log_max_iter_in: bit width used when bit-reversing the index.
    // log_fft_len_in:  log2 of the length the phases are divided by.
    BinRevTableComplexIterHP(int log_max_iter_in, int log_fft_len_in)
        : index(0), pop(0), log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
    {
        assert(log_max_iter <= log_fft_len);
        assert(log_fft_len <= MAX_LOG_LEN);
        const FloatTy factor = FloatTy(1) / (size_t(1) << (log_fft_len - log_max_iter));
        // units[i] is the rotation applied when bit i of the index becomes set.
        for (int i = 0; i < MAX_LOG_LEN; i++)
        {
            units[i] = getOmega(size_t(1) << (i + 1), 1, factor);
        }
        table[0] = Complex(1, 0);
    }

    // Repositions the generator at index i, rebuilding one stack entry per set
    // bit of i (the entry for the full index on top, then with the lowest set
    // bits removed one at a time).
    void reset(size_t i = 0)
    {
        index = i;
        if (i == 0)
        {
            pop = 0;
            return;
        }
        pop = hint_popcnt(static_cast<uint32_t>(i));
        assert(pop < MAX_LOG_LEN); // table[] holds MAX_LOG_LEN entries
        const size_t len = size_t(1) << log_fft_len;
        for (int p = pop; p > 0; p--)
        {
            table[p] = getOmega(len, bitrev(i, log_max_iter));
            i &= (i - 1); // clear the lowest set bit
        }
    }

    // Returns the factor for the current index, then advances by one.
    Complex iterate()
    {
        Complex res = table[pop];
        index++;
        // The explicit cast selects the 64-bit overload; passing size_t
        // directly is ambiguous on platforms where size_t is a distinct type
        // from both uint32_t and uint64_t.
        const int zero = hint_ctz(static_cast<uint64_t>(index));
        // Incrementing turns `zero` trailing ones into zeros: drop their entries.
        pop -= zero;
        assert(pop >= 0 && pop + 1 < MAX_LOG_LEN);
        table[pop + 1] = table[pop] * units[zero];
        pop++;
        return res;
    }

    // Unit complex number with phase (-2*pi*index/n) * factor.
    static Complex getOmega(size_t n, size_t index, FloatTy factor = 1)
    {
        FloatTy theta = -HINT_2PI * index / n;
        return std::polar<FloatTy>(1, theta * factor);
    }

private:
    Complex units[MAX_LOG_LEN]{}; // per-bit rotation steps
    Complex table[MAX_LOG_LEN]{}; // stack of partial products, top at table[pop]
    size_t index;                 // current position
    int pop;                      // number of set bits in `index`
    int log_max_iter, log_fft_len;
};

template <typename Float>
inline void idit_qtable(std::complex<Float> in_out[], size_t len, bool norm = false)
{
    using Complex = std::complex<Float>;
    BinRevTableComplexIterHP<Float> table(31, 32);
    for (size_t rank = 2; rank <= len; rank *= 2)
    {
        size_t stride = rank / 2;
        auto it0 = in_out, it1 = it0 + stride;
        table.reset(0);
        for (size_t begin = 0; begin < len; begin += rank, it0 += rank, it1 += rank)
        {
            Complex omega = table.iterate();
            for (size_t i = 0; i < stride; i++)
            {
                Complex x = it0[i];
                Complex y = it1[i];
                it0[i] = x + y;
                it1[i] = (x - y) * std::conj(omega);
            }
        }
    }
    if (norm)
    {
        const double len_inv = 1.0 / len;
        for (size_t i = 0; i < len; i++)
        {
            in_out[i] *= len_inv;
        }
    }
}

template <typename Float>
inline void dif_qtable(std::complex<Float> in_out[], size_t len)
{
    using Complex = std::complex<Float>;
    BinRevTableComplexIterHP<Float> table(31, 32);
    for (size_t rank = len; rank >= 2; rank /= 2)
    {
        size_t stride = rank / 2;
        auto it0 = in_out, it1 = it0 + stride;
        table.reset(0);
        for (size_t begin = 0; begin < len; begin += rank, it0 += rank, it1 += rank)
        {
            Complex omega = table.iterate();
            for (size_t i = 0; i < stride; i++)
            {
                Complex x = it0[i];
                Complex y = it1[i] * omega;
                it0[i] = x + y;
                it1[i] = x - y;
            }
        }
    }
}

// Real-array overload: views `in_out` (len scalars) as len/2 complex values
// and forwards to idit_qtable.
template <typename Float>
inline void idit(Float in_out[], size_t len, bool norm = false)
{
    idit_qtable(reinterpret_cast<std::complex<Float> *>(in_out), len / 2, norm);
}

// Real-array overload: views `in_out` (len scalars) as len/2 complex values
// and forwards to dif_qtable.
template <typename Float>
inline void dif(Float in_out[], size_t len)
{
    dif_qtable(reinterpret_cast<std::complex<Float> *>(in_out), len / 2);
}

// Reference convolution: promotes both real inputs to complex, runs dif on
// each, multiplies element-wise, runs the normalised idit and writes the real
// parts back into in_out1. `len` must be a power of two.
template <typename Float>
void real_conv_std(Float in_out1[], Float in2[], size_t len)
{
    using Complex = std::complex<Float>;
    assert(is_2pow(len));
    std::vector<Complex> lhs(in_out1, in_out1 + len);
    std::vector<Complex> rhs(in2, in2 + len);
    dif(lhs.data(), len);
    dif(rhs.data(), len);
    for (size_t k = 0; k < len; k++)
    {
        lhs[k] *= rhs[k];
    }
    idit(lhs.data(), len, true);
    for (size_t k = 0; k < len; k++)
    {
        in_out1[k] = lhs[k].real();
    }
}

// Same scheme as real_conv_std, but using the generator-driven transforms
// dif_qtable / idit_qtable. `len` must be a power of two.
template <typename Float>
void real_conv_rev_table(Float in_out1[], Float in2[], size_t len)
{
    using Complex = std::complex<Float>;
    assert(is_2pow(len));
    std::vector<Complex> lhs(in_out1, in_out1 + len);
    std::vector<Complex> rhs(in2, in2 + len);
    dif_qtable(lhs.data(), len);
    dif_qtable(rhs.data(), len);
    for (size_t k = 0; k < len; k++)
    {
        lhs[k] *= rhs[k];
    }
    idit_qtable(lhs.data(), len, true);
    for (size_t k = 0; k < len; k++)
    {
        in_out1[k] = lhs[k].real();
    }
}

// Processes one pair of complex slots from each of two arrays.
//
// inout0 / inout1 point at the two slots of the first operand (overwritten
// with the result); in0 / in1 point at the matching slots of the second
// operand (read only). Within a slot, element [0] is read as the real part and
// element [RI_DIFF] as the imaginary part.
//
// Steps: both operand pairs go through `combine2` with twiddle `omega0`, the
// two resulting complex values are multiplied pairwise, and the products are
// folded back by the "separate2" step. All four outputs are scaled by `factor`.
// NOTE(review): this looks like the unpack / multiply / repack step for
// spectra of real sequences packed into half-length complex FFTs - confirm.
template <size_t RI_DIFF = 1, typename FloatTy>
inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1,
                     const std::complex<FloatTy> &omega0, const FloatTy factor = 0.125)
{
    using Complex = std::complex<FloatTy>;
    // Maps a slot pair (r0 + i*i0, r1 + i*i1) to two complex values using omega0.
    auto combine2 = [&omega0](auto r0, auto i0, auto r1, auto i1, Complex &out0, Complex &out1)
    {
        auto tr0 = r0 + r1, ti0 = i0 - i1; // sum
        auto tr1 = r0 - r1, ti1 = i0 + i1; // diff

        // r0 / i0 are by-value copies and are reused here as scratch.
        r0 = ti1 * omega0.real() + tr1 * omega0.imag();
        i0 = ti1 * omega0.imag() - tr1 * omega0.real();

        out0.real(tr0 + r0);
        out0.imag(ti0 + i0);

        out1.real(tr0 - r0);
        out1.imag(i0 - ti0);
    };
    Complex x0, x1, x2, x3;
    auto r0 = inout0[0], i0 = inout0[RI_DIFF], r1 = inout1[0], i1 = inout1[RI_DIFF];
    combine2(r0, i0, r1, i1, x0, x1);
    r0 = in0[0], i0 = in0[RI_DIFF], r1 = in1[0], i1 = in1[RI_DIFF];
    combine2(r0, i0, r1, i1, x2, x3);
    // Pairwise products of the combined values.
    x0 *= x2;
    x1 *= x3;
    { // separate2
        r0 = x0.real(), i0 = x0.imag(), r1 = x1.real(), i1 = x1.imag();
        auto tr0 = r0 + r1, ti0 = i0 - i1; // sum
        auto tr1 = r0 - r1, ti1 = i0 + i1; // diff

        auto r = tr1 * omega0.imag() - ti1 * omega0.real();
        auto i = tr1 * omega0.real() + ti1 * omega0.imag();

        r0 = tr0 + r;
        i0 = ti0 + i;

        r1 = tr0 - r;
        i1 = i - ti0;
    }
    // Write back, applying the caller-supplied scale.
    inout0[0] = r0 * factor, inout0[RI_DIFF] = i0 * factor, inout1[0] = r1 * factor, inout1[RI_DIFF] = i1 * factor;
}
// Combines two arrays of `float_len` scalars (each viewed as float_len/2
// complex slots) slot-by-slot into in_out, applying an overall scale of
// 2/float_len.
//
// Slots 0 and 1 are handled specially, slots 2 and 3 form one dot_rfft pair
// with a fixed twiddle, and every later group [begin, 2*begin) is paired from
// its two ends inwards with twiddles from a BinRevTableComplexIterHP generator.
//
// Preconditions: float_len is a power of two and at least 8, because the
// first eight scalars of both arrays are accessed unconditionally.
template <typename Float>
inline void real_dot_binrev(Float in_out[], const Float in[], size_t float_len)
{
    assert(is_2pow(float_len));
    assert(float_len >= 8); // elements [0, 8) are read and written below
    using Complex = std::complex<Float>;
    Float inv = 2.0 / float_len;
    {
        // Slot 0: treated as two independent real values via sum/difference.
        auto r0 = in_out[0], i0 = in_out[1], r1 = in[0], i1 = in[1];
        transform2(r0, i0);
        transform2(r1, i1);
        r0 *= r1, i0 *= i1;
        transform2(r0, i0);
        in_out[0] = r0 * 0.5 * inv, in_out[1] = i0 * 0.5 * inv;
    }
    // Slot 1: plain complex product.
    auto temp = Complex(in_out[2], in_out[3]) * Complex(in[2], in[3]) * inv;
    in_out[2] = temp.real(), in_out[3] = temp.imag();
    // dot_rfft results carry an extra factor that the scale divides out.
    inv /= 8;
    dot_rfft(&in_out[4], &in_out[6], &in[4], &in[6], Complex(COS_PI_8, -COS_PI_8), inv);
    BinRevTableComplexIterHP<Float> table(31, 32);
    for (size_t begin = 8; begin < float_len; begin *= 2)
    {
        table.reset(begin / 2);
        auto it0 = in_out + begin, it1 = it0 + begin - 2, it2 = in + begin, it3 = it2 + begin - 2;
        for (; it0 < it1; it0 += 2, it1 -= 2, it2 += 2, it3 -= 2)
        {
            auto omega = table.iterate();
            dot_rfft(it0, it1, it2, it3, omega, inv);
        }
    }
}

// Complex-slot variant of the slot-wise combination: `len` is the number of
// complex slots in each array. Slot 0 and slot 1 are handled specially, slots
// 2 and 3 form one dot_rfft pair with a fixed twiddle, and every later group
// [begin, 2*begin) is paired from its two ends inwards.
//
// No length normalisation is applied here (dot_rfft runs with its default
// factor), so callers scale the result themselves.
// Precondition: len >= 4, since slots 0..3 are accessed unconditionally.
inline void real_conv_binrev(Complex64 in_out[], Complex64 in[], size_t len)
{
    assert(len >= 4);
    static BinRevTableComplexIterHP<double> table(31, 32);
    const auto as_scalars = [](Complex64 *slot)
    { return reinterpret_cast<double *>(slot); };

    // Slot 0: treated as two independent real values via sum/difference.
    auto t0 = in_out[0], t1 = in[0];
    auto t2 = (t0.real() + t0.imag()) * (t1.real() + t1.imag());
    auto t3 = (t0.real() - t0.imag()) * (t1.real() - t1.imag());
    in_out[0] = Complex64(t2 + t3, t2 - t3) * 0.5;
    // Slot 1: plain complex product.
    in_out[1] *= in[1];
    dot_rfft(as_scalars(&in_out[2]), as_scalars(&in_out[3]), as_scalars(&in[2]), as_scalars(&in[3]),
             Complex64(COS_PI_8, -COS_PI_8));
    for (size_t begin = 4; begin < len; begin *= 2)
    {
        table.reset(begin);
        auto it1 = in_out + begin, it2 = in + begin;
        for (size_t i = 0; i < begin / 2; i++)
        {
            const Complex64 omega = table.iterate();
            dot_rfft(as_scalars(&it1[i]), as_scalars(&it1[begin - i - 1]),
                     as_scalars(&it2[i]), as_scalars(&it2[begin - i - 1]), omega);
        }
    }
}

// Cyclic convolution of two real arrays of length `len` (a power of two).
// The result replaces in_out1. On the FFT path in2 is overwritten with its
// transform; on the short-input path it is left untouched.
//
// real_dot_binrev always accesses the first eight scalars of both arrays, so
// lengths below 8 cannot use the FFT path and are convolved directly instead.
void real_conv_rfft(Float64 in_out1[], Float64 in2[], size_t len)
{
    assert(is_2pow(len));
    constexpr size_t MIN_FFT_LEN = 8;
    if (len < MIN_FFT_LEN)
    {
        Float64 acc[MIN_FFT_LEN] = {};
        for (size_t i = 0; i < len; i++)
        {
            for (size_t j = 0; j < len; j++)
            {
                acc[(i + j) % len] += in_out1[i] * in2[j];
            }
        }
        for (size_t i = 0; i < len; i++)
        {
            in_out1[i] = acc[i];
        }
        return;
    }
    dif(in_out1, len);
    dif(in2, len);
    real_dot_binrev(in_out1, in2, len); // includes the 2/len normalisation
    idit(in_out1, len);
}
// Multiplies two polynomials given as coefficient vectors and returns
// in1.size() + in2.size() coefficients (the last one is always zero padding).
// Coefficients are converted to double, convolved with real_conv_rfft and
// rounded to nearest by adding 0.5 before converting back to T.
template <typename T>
std::vector<T> poly_multiply(const std::vector<T> &in1, const std::vector<T> &in2)
{
    const size_t len1 = in1.size(), len2 = in2.size();
    const size_t conv_len = len1 + len2;
    std::vector<T> res(conv_len);
    if (len1 == 0 || len2 == 0)
    {
        // The product with an empty polynomial is all zeros.
        return res;
    }
    // The FFT path works on at least 8 scalars.
    const size_t float_len = std::max<size_t>(int_ceil2(conv_len), 8);
    // Zero-initialised work buffers, released automatically on every exit path.
    std::vector<double> p1(float_len, 0.0), p2(float_len, 0.0);
    std::copy(in1.begin(), in1.end(), p1.begin());
    std::copy(in2.begin(), in2.end(), p2.begin());
    real_conv_rfft(p1.data(), p2.data(), float_len);
    for (size_t i = 0; i < conv_len; i++)
    {
        res[i] = static_cast<T>(p1[i] + 0.5);
    }
    return res;
}
// Checks `res` against the expected triangular result of convolving two
// constant sequences: (i + 1) * ele1 * ele2 in the first half and
// (len - i - 1) * ele1 * ele2 in the second half. Prints the first mismatch
// ("fail:<index>\t<expected>\t<actual>") or "success".
template <typename T>
void result_test(const std::vector<T> &res, uint64_t ele1, uint64_t ele2)
{
    const size_t len = res.size();
    const size_t half = len / 2;
    for (size_t i = 0; i < len; i++)
    {
        const size_t weight = (i < half) ? (i + 1) : (len - i - 1);
        const uint64_t expected = weight * ele1 * ele2;
        const uint64_t actual = res[i];
        if (expected != actual)
        {
            std::cout << "fail:" << i << "\t" << expected << "\t" << actual << "\n";
            return;
        }
    }
    std::cout << "success\n";
}
// Timing harness for poly_multiply: multiplies two length-len/2 sequences,
// prints every output coefficient, runs result_test and reports the elapsed
// time in microseconds.
// NOTE(review): only the first 2500 entries of each input are filled, while
// result_test expects fully constant inputs - confirm which one is intended.
void perf_conv()
{
    using Clock = std::chrono::high_resolution_clock;
    const int n = 13;
    const size_t len = size_t(1) << n; // transform length
    std::cout << "conv len:" << len << "\n";
    const uint64_t ele1 = 9999, ele2 = 7777;
    // Two inputs of length len/2.
    std::vector<uint64_t> in1(len / 2, 0);
    std::vector<uint64_t> in2(len / 2, 0);
    const size_t filled = 2500;
    for (size_t k = 0; k != filled; ++k)
    {
        in1[k] = ele1;
        in2[k] = ele2;
    }
    const auto started = Clock::now();
    std::vector<uint64_t> res = poly_multiply(in1, in2);
    const auto finished = Clock::now();
    for (auto value : res)
    {
        std::cout << value << " ";
    }
    std::cout << "\n";
    result_test<uint64_t>(res, ele1, ele2); // verify the result
    const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(finished - started);
    std::cout << "cost:" << elapsed.count() << "us\n";
}

void test_rev_tb()
{
    BinRevTableComplexIterHP<double> table(31, 32);
    using cpx = std::complex<double>;
    auto rev_cpx = [](size_t i, int div)
    {
        const int bits = 31;
        const size_t len = size_t(1) << bits;
        cpx ret;
        twiddle_omega(len, bitrev(i, bits) / div, ret);
        return ret;
    };
    size_t end = 1 << 23, begin = 0;
    std::cin >> begin;
    table.reset(begin);
    for (size_t i = begin; i < end; i++)
    {
        cpx a = rev_cpx(i, 2), b = table.iterate();
        if (std::abs(a - b) > 1e-15)
        {
            std::cout << a << b << '\n';
            return;
        }
    }
    std::cout << "end";
}
inline void test_rconv()
{
    using Complex = Complex64;
    static double ary1[1 << 23]{};
    static double ary2[1 << 23]{};
    size_t n = 1 << 23;
    // real_conv_binrev(ary1, ary2, n);
    for (size_t i = 0; i < n / 2; i++)
    {
        ary1[i] = 9999;
        ary2[i] = 9999;
    }

    dif(ary1, n);
    dif(ary2, n);
    auto t1 = std::chrono::high_resolution_clock::now();
    // real_conv_binrev4((double *)ary1, (double *)ary2, n / 2);
    real_conv_binrev((Complex *)ary1, (Complex *)ary2, n / 2);
    auto t2 = std::chrono::high_resolution_clock::now();
    for (size_t i = 0; i < n; i++)
    {
        ary1[i] = ary1[i] * (0.25 / n);
    }
    idit(ary1, n, false);
    for (size_t i = 0; i < n; i++)
    {
        std::cout << uint64_t((ary1)[i] + 0.5) / (9999ull * 9999) << "\t";
    }
    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() << "us" << std::endl;
}

// Variant of test_rconv that stores the data as complex arrays and uses the
// complex overloads of dif / idit. The first n/2 scalars of each array are set
// to 9999; only the real_conv_binrev call is timed. The scaled,
// inverse-transformed result is printed divided by 9999 * 9999.
inline void test_rconv1()
{
    using Complex = Complex64;
    using Clock = std::chrono::high_resolution_clock;

    static Complex ary1[1 << 23]{};
    static Complex ary2[1 << 23]{};
    const size_t n = 1 << 23;
    double *scalars1 = reinterpret_cast<double *>(ary1);
    double *scalars2 = reinterpret_cast<double *>(ary2);
    for (size_t k = 0; k < n / 2; k++)
    {
        scalars1[k] = 9999;
        scalars2[k] = 9999;
    }

    dif(ary1, n / 2);
    dif(ary2, n / 2);
    const auto started = Clock::now();
    real_conv_binrev(ary1, ary2, n / 2);
    const auto finished = Clock::now();
    const double scale = 0.25 / n;
    for (size_t k = 0; k < n / 2; k++)
    {
        ary1[k] = ary1[k] * scale;
    }
    idit(ary1, n / 2, false);
    for (size_t k = 0; k < n; k++)
    {
        std::cout << uint64_t(scalars1[k] + 0.5) / (9999ull * 9999) << "\t";
    }
    const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(finished - started);
    std::cout << elapsed.count() << "us" << std::endl;
}

void test_rdot()
{
    using Complex = Complex64;
    static double ary1[1 << 23]{};
    static double ary2[1 << 23]{};
    static double cary1[1 << 23]{};
    static double cary2[1 << 23]{};

    size_t n = 1 << 5;
    srand(0);
    for (size_t i = 0; i < n; i++)
    {
        ary1[i] = rand() % 10;
        ary2[i] = rand() % 10;
        cary1[i] = ary1[i];
        cary2[i] = ary2[i];
    }
    real_conv_std(ary1, ary2, n);
    dif(ary1, n);

    dif(cary1, n);
    dif(cary2, n);
    real_dot_binrev(cary1, cary2, n);
    // real_conv_binrev((Complex *)cary1, (Complex *)cary2, n / 2);
    auto cp1 = reinterpret_cast<Complex *>(cary1), cp2 = reinterpret_cast<Complex *>(ary1);
    for (size_t i = 0; i < n / 2; i++)
    {
        if (std::abs(cp1[i] - cp2[i]) > 1e-10)
        {
            std::cout << i << "\t" << cp1[i] << " " << cp2[i] << std::endl;
        }
    }
    std::cout << "end\n";
}

// Lookup table mapping every value in [0, 10000) to its four ASCII digits
// packed into one 32-bit word, with the thousands digit in the lowest byte.
// Copying that word to memory yields the digits in reading order only on a
// little-endian machine.
class ItoStrBase10000
{
public:
    // Packs the four decimal digits of num (taken modulo 10000) as ASCII bytes.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        const uint32_t thousands = num / 1000 % 10;
        const uint32_t hundreds = num / 100 % 10;
        const uint32_t tens = num / 10 % 10;
        const uint32_t ones = num % 10;
        const uint32_t packed = thousands | (hundreds << 8) | (tens << 16) | (ones << 24);
        // Add '0' to each of the four bytes at once.
        return packed + '0' * 0x1010101;
    }

    constexpr ItoStrBase10000()
    {
        for (size_t value = 0; value < 10000; value++)
        {
            table[value] = itosbase10000(value);
        }
    }

    // Writes the four digits of num to str (no terminator). num must be < 10000.
    void tostr(char *str, uint32_t num) const
    {
        std::memcpy(str, table + num, sizeof(uint32_t));
    }

    // Returns the packed four-digit word for num. num must be < 10000.
    uint32_t tostr(uint32_t num) const
    {
        return table[num];
    }

private:
    uint32_t table[10000]{};
};

// Lookup table mapping two ASCII digit characters (read as one 16-bit word) to
// their value in [0, 100). Any other byte pair maps to UINT16_MAX.
class StrtoIBase100
{
private:
    static constexpr size_t TABLE_SIZE = size_t(1) << 15;
    uint16_t table[TABLE_SIZE]{};

public:
    // Packs the two decimal digits of num as ASCII bytes, tens digit in the
    // low byte (the inverse of the lookup performed by toInt).
    static constexpr uint16_t itosbase100(uint16_t num)
    {
        uint16_t res = (num / 10 % 10) | ((num % 10) << 8);
        return res + '0' * 0x0101;
    }
    constexpr StrtoIBase100()
    {
        // Mark every slot invalid, then fill in the 100 valid digit pairs.
        for (size_t i = 0; i < TABLE_SIZE; i++)
        {
            table[i] = UINT16_MAX;
        }
        for (size_t i = 0; i < 100; i++)
        {
            table[itosbase100(i)] = i;
        }
    }
    // Converts the two characters at str to a number in [0, 100), or returns
    // UINT16_MAX when they are not both decimal digits. Byte pairs whose
    // 16-bit value lies beyond the table (e.g. bytes >= 0x80) are rejected
    // instead of being used as an out-of-range index.
    uint16_t toInt(const char *str) const
    {
        uint16_t num;
        std::memcpy(&num, str, sizeof(num));
        return num < TABLE_SIZE ? table[num] : static_cast<uint16_t>(UINT16_MAX);
    }
};

// Global conversion tables, filled in by the constexpr constructors of their
// classes. `itosbase10000` is the table used by conv_to_str_base10000.
constexpr ItoStrBase10000 itosbase10000{};
constexpr StrtoIBase100 strtoibase100{};

// Converts the four ASCII digits at s[0..3] to their numeric value.
// The characters are not validated.
constexpr uint32_t stobase10000(const char *s)
{
    // Horner evaluation of the raw character codes, then remove the four '0' offsets.
    return ((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3] - '0' * 1111;
}

#include <immintrin.h>
// Owning, move-only buffer of `len` objects of type T whose storage is
// obtained from _mm_malloc with ALIGN-byte alignment and released with
// _mm_free. The elements are NOT constructed or zero-initialised.
//
// Copying is disabled because two owners of the same pointer would both free
// it; ownership can be transferred by moving.
template <typename T, size_t ALIGN = 64>
class AlignMem
{
public:
    using Ptr = T *;
    using ConstPtr = const T *;

    // Empty buffer: begin() == end() == nullptr.
    AlignMem() : ptr(nullptr), len(0) {}

    // Allocates room for n elements. If the allocation fails the buffer is
    // left empty (begin() returns nullptr).
    AlignMem(size_t n)
        : ptr(static_cast<Ptr>(_mm_malloc(n * sizeof(T), ALIGN))), len(ptr ? n : 0) {}

    AlignMem(const AlignMem &) = delete;
    AlignMem &operator=(const AlignMem &) = delete;

    AlignMem(AlignMem &&other) noexcept : ptr(other.ptr), len(other.len)
    {
        other.ptr = nullptr;
        other.len = 0;
    }

    AlignMem &operator=(AlignMem &&other) noexcept
    {
        if (this != &other)
        {
            release();
            ptr = other.ptr;
            len = other.len;
            other.ptr = nullptr;
            other.len = 0;
        }
        return *this;
    }

    ~AlignMem()
    {
        release();
    }

    T &operator[](size_t i)
    {
        return ptr[i];
    }
    const T &operator[](size_t i) const
    {
        return ptr[i];
    }
    Ptr begin()
    {
        return ptr;
    }
    Ptr end()
    {
        return ptr + len;
    }
    ConstPtr begin() const
    {
        return ptr;
    }
    ConstPtr end() const
    {
        return ptr + len;
    }

private:
    // Frees the storage, if any, and resets to the empty state.
    void release() noexcept
    {
        if (ptr)
        {
            _mm_free(ptr);
        }
        ptr = nullptr;
        len = 0;
    }

    T *ptr;
    size_t len;
};

// Sets every byte of the objects in [begin, end) to zero.
template <typename T>
void fill_zero(T *begin, T *end)
{
    const size_t byte_count = (end - begin) * sizeof(T);
    std::memset(begin, 0, byte_count);
}

// Splits a decimal digit string into base-10000 blocks of four digits,
// starting from the first (most significant) character, and stores them in
// ary[0..]. A trailing group of fewer than four digits is padded on the right
// with zeros. Returns the number of padding zeros added (0 to 3).
template <typename T>
size_t str_num_to_array_base10000(const char *str, size_t len, T *ary)
{
    constexpr size_t BLOCK = 4;
    const size_t full_blocks = len / BLOCK;
    const size_t tail = len % BLOCK;
    for (size_t b = 0; b < full_blocks; b++)
    {
        ary[b] = stobase10000(str + b * BLOCK);
    }
    if (tail == 0)
    {
        return 0;
    }
    const char *rest = str + full_blocks * BLOCK;
    int value = 0;
    for (size_t k = 0; k < tail; k++)
    {
        value = value * 10 + rest[k] - '0';
    }
    // Left-align the partial block within four digits.
    for (size_t k = tail; k < BLOCK; k++)
    {
        value *= 10;
    }
    ary[full_blocks] = value;
    return BLOCK - tail;
}

// Turns base-10000 convolution output into decimal text.
//
// ary:      conv_len block values, most significant first; each is rounded to
//           the nearest integer before carry propagation.
// shift:    number of padding zeros at the end of the digit string that are
//           not part of the number (from right-padding the operands).
// res:      output buffer with room for (conv_len + 1) * 4 characters.
// res_len:  receives the number of significant digits.
// Returns the offset of the first significant digit inside res.
//
// Leading zeros are stripped, but at least one digit is always kept, so a
// zero product yields "0" instead of scanning past the end of the buffer.
template <typename T>
size_t conv_to_str_base10000(const T *ary, size_t conv_len, size_t shift, char *res, size_t &res_len)
{
    constexpr size_t BLOCK = 4, BASE = 10000;
    const size_t total = (conv_len + 1) * BLOCK;
    char *cursor = res + total;
    uint64_t carry = 0;
    // Walk from the least significant block upwards, propagating carries.
    for (size_t i = conv_len; i > 0;)
    {
        i--;
        cursor -= BLOCK;
        carry += uint64_t(ary[i] + 0.5);
        itosbase10000.tostr(cursor, carry % BASE);
        carry /= BASE;
    }
    assert(carry < BASE);
    // The extra leading block holds the final carry.
    cursor -= BLOCK;
    itosbase10000.tostr(cursor, carry);

    // The number itself occupies res[0, total - shift).
    const size_t digits = total > shift ? total - shift : 0;
    if (digits == 0)
    {
        res_len = 0;
        return 0;
    }
    const char *last_digit = res + digits - 1;
    while (cursor < last_digit && *cursor == '0')
    {
        cursor++;
    }
    const size_t offset = cursor - res;
    res_len = digits - offset;
    return offset;
}

// Multiplies two non-negative decimal integers given as digit strings.
//
// str1/len1, str2/len2: the operands (digits only, not necessarily
//                       NUL-terminated).
// res:     output buffer of at least preserve_strlen(len1, len2) characters.
// res_len: receives the length of the product's digit string.
// Returns a pointer into res where the product begins (no terminator written).
//
// An empty operand produces an empty result (res_len == 0).
char *big_mul(const char *str1, size_t len1, const char *str2, size_t len2, char *res, size_t &res_len)
{
    constexpr size_t BLOCK = 4;
    // real_conv_rfft's FFT path unconditionally works on the first 8 scalars.
    constexpr size_t MIN_FFT_LEN = 8;
    if (len1 == 0 || len2 == 0)
    {
        // Without this guard conv_len below would underflow.
        res_len = 0;
        return res;
    }
    const size_t block_len1 = (len1 + BLOCK - 1) / BLOCK, block_len2 = (len2 + BLOCK - 1) / BLOCK;
    const size_t conv_len = block_len1 + block_len2 - 1;
    const size_t fft_len = std::max(hint::int_ceil2(conv_len), MIN_FFT_LEN);
    AlignMem<Float64> ary1(fft_len), ary2(fft_len);
    // Each operand may be right-padded to a whole block; the padding zeros of
    // both end up as extra trailing zeros of the product.
    size_t shift = str_num_to_array_base10000(str1, len1, &ary1[0]);
    shift += str_num_to_array_base10000(str2, len2, &ary2[0]);
    fill_zero(ary1.begin() + block_len1, ary1.end());
    fill_zero(ary2.begin() + block_len2, ary2.end());
    real_conv_rfft(ary1.begin(), ary2.begin(), fft_len);
    return res + conv_to_str_base10000(ary1.begin(), conv_len, shift, res, res_len);
}

// Buffer size, in characters, to reserve for the product of a len1-digit and
// a len2-digit number: four characters per four-digit block of each operand.
size_t preserve_strlen(size_t len1, size_t len2)
{
    constexpr size_t BLOCK = 4;
    const auto blocks_for = [](size_t digits)
    { return (digits + BLOCK - 1) / BLOCK; };
    const size_t total_blocks = blocks_for(len1) + blocks_for(len2);
    return total_blocks * BLOCK;
}

// Reads two whitespace-separated decimal integers from stdin and writes their
// product to stdout (no trailing newline). Nothing is written when fewer than
// two tokens can be read.
void mul_test()
{
    std::string s1, s2;
    if (!(std::cin >> s1 >> s2))
    {
        return;
    }
    const size_t len1 = s1.size(), len2 = s2.size();
    size_t res_len = preserve_strlen(len1, len2);
    std::vector<char> res(res_len, '0');
    // big_mul updates res_len to the length of the product.
    const char *begin = big_mul(s1.data(), len1, s2.data(), len2, res.data(), res_len);
    fwrite(begin, 1, res_len, stdout);
}


// Program entry point: runs the stdin-to-stdout multiplication.
int main()
{
    mul_test();
    return 0;
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 44.015 ms | 15 MB + 660 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-18 21:51:32 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠