提交记录 21166


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1002. 测测你的多项式乘法 Accepted 100 44.4 ms 40636 KB C++14 25.48 KB
提交时间 评测时间
2024-02-06 18:52:00 2024-02-06 18:52:22
#include <algorithm>
#include <complex>
#include <iostream>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("avx2")
#pragma GCC target("fma")

namespace hint_simd
{
    // Fixed-size array of LEN elements of type T whose storage is aligned to a
    // 4096-byte boundary, so the AVX code can use aligned loads/stores on it.
    template <typename T, size_t LEN>
    class AlignAry
    {
    public:
        constexpr AlignAry() {}

        // Unchecked element access.
        constexpr T &operator[](size_t index) { return ary[index]; }
        constexpr const T &operator[](size_t index) const { return ary[index]; }

        // Number of elements (compile-time constant).
        constexpr size_t size() const { return LEN; }

        // Pointer to the first element.
        T *data() { return ary; }
        const T *data() const { return ary; }

        // Reinterpret the storage as a pointer to Ty (no conversion of the contents).
        template <typename Ty>
        Ty *cast_ptr() { return reinterpret_cast<Ty *>(data()); }
        template <typename Ty>
        const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(data()); }

    private:
        alignas(4096) T ary[LEN];
    };
    // 256-bit AVX vector holding four doubles that are processed in parallel.
    // Lane 0 corresponds to the lowest address when loaded from / stored to memory.
    struct DoubleX4
    {
        __m256d data;
        // All four lanes set to 0.0.
        DoubleX4()
        {
            data = _mm256_setzero_pd();
        }
        // Broadcast one value into all four lanes.
        DoubleX4(double input)
        {
            data = _mm256_set1_pd(input);
        }
        // Wrap a raw AVX register.
        DoubleX4(__m256d input)
        {
            data = input;
        }
        // Defaulted so the type stays trivially copyable.
        DoubleX4(const DoubleX4 &input) = default;
        // Construct from 4 contiguous doubles (unaligned load).
        DoubleX4(double const *ptr)
        {
            loadu(ptr);
        }
        // Construct from four values; a0 ends up in lane 0 and a3 in lane 3.
        DoubleX4(double a3, double a2, double a1, double a0)
        {
            data = _mm256_set_pd(a3, a2, a1, a0);
        }
        // Set all lanes to 0.0.
        void clr()
        {
            data = _mm256_setzero_pd();
        }
        // Aligned load: ptr must be 32-byte aligned.
        void load(double const *ptr)
        {
            data = _mm256_load_pd(ptr);
        }
        // Unaligned load.
        void loadu(double const *ptr)
        {
            data = _mm256_loadu_pd(ptr);
        }
        // Aligned store: ptr must be 32-byte aligned.
        void store(double *ptr) const
        {
            _mm256_store_pd(ptr, data);
        }
        // Unaligned store.
        void storeu(double *ptr) const
        {
            _mm256_storeu_pd(ptr, data);
        }
        // Unaligned load from ptr (same as loadu).
        void operator<<(double const *ptr)
        {
            data = _mm256_loadu_pd(ptr);
        }
        // Print the four lanes (lane 0 first) to stdout.
        void print() const
        {
            double ary[4];
            // A plain local array is only guaranteed 8-byte alignment, so an aligned
            // 32-byte store here would be undefined behaviour: use the unaligned store.
            storeu(ary);
            printf("(%lf,%lf,%lf,%lf)\n",
                   ary[0], ary[1], ary[2], ary[3]);
        }
        // Cross-lane permutation: result lane i = lane ((N >> 2i) & 3) of *this.
        template <int N>
        DoubleX4 permute4x64() const
        {
            return _mm256_permute4x64_pd(data, N);
        }
        // In-lane permutation within each 128-bit half, selected by the bits of N.
        template <int N>
        DoubleX4 permute() const
        {
            return _mm256_permute_pd(data, N);
        }
        // Lanes in reverse order: (d3, d2, d1, d0).
        DoubleX4 reverse() const
        {
            return permute4x64<0b00011011>();
        }
        // Even lanes: *this - input; odd lanes: *this + input.
        DoubleX4 addsub(DoubleX4 input) const
        {
            return _mm256_addsub_pd(data, input.data);
        }
        // Fused multiply-add: mul1 * mul2 + *this.
        DoubleX4 fmadd(DoubleX4 mul1, DoubleX4 mul2) const
        {
            return _mm256_fmadd_pd(mul1.data, mul2.data, data);
        }
        // Fused multiply-subtract: mul1 * mul2 - *this.
        DoubleX4 fmsub(DoubleX4 mul1, DoubleX4 mul2) const
        {
            return _mm256_fmsub_pd(mul1.data, mul2.data, data);
        }
        DoubleX4 operator+(DoubleX4 input) const
        {
            return _mm256_add_pd(data, input.data);
        }
        DoubleX4 operator-(DoubleX4 input) const
        {
            return _mm256_sub_pd(data, input.data);
        }
        DoubleX4 operator*(DoubleX4 input) const
        {
            return _mm256_mul_pd(data, input.data);
        }
        DoubleX4 operator/(DoubleX4 input) const
        {
            return _mm256_div_pd(data, input.data);
        }
        // Negation computed as 0 - x, so lanes holding +0.0 stay +0.0.
        DoubleX4 operator-() const
        {
            return _mm256_sub_pd(_mm256_setzero_pd(), data);
        }
    };
}

#endif

#include <array>
#include <vector>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>

namespace hint
{
    using namespace hint_simd;
    // Floating-point aliases used throughout the FHT code.
    using Float32 = float;
    using Float64 = double;

    // pi and 2*pi in double precision.
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = 2 * HINT_PI;
    // Largest transform length accepted by the FHT dispatchers: 2^21 points.
    constexpr size_t FHT_MAX_LEN = size_t{1} << 21;

    // Largest power of two that is <= n, i.e. n with only its highest set bit kept.
    // Returns 0 for n == 0 (no power of two is <= 0) and for negative signed n.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        // Smear the highest set bit into every lower position: n becomes 2^(k+1) - 1.
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        // (2^(k+1) - 1) - (2^k - 1) == 2^k; an all-zero n stays 0.
        return n - (n >> 1);
    }

    // Smallest power of two that is >= n, intended for n >= 1.
    // Edge cases of the bit trick: n == 0 yields 0, and for unsigned T a result that
    // does not fit wraps around to 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        // Every bit below the highest set bit of (n - 1) is switched on, so mask + 1
        // is the next power of two (or n itself when n already is one).
        T mask = n - 1;
        for (int shift = 1; shift < int(sizeof(mask) * 8); shift <<= 1)
        {
            mask |= mask >> shift;
        }
        return mask + 1;
    }

    // Integer base-2 logarithm: floor(log2(n)) for n > 0, and -1 for n <= 0.
    // Counts right shifts instead of probing T(1) << k. For signed T the probe
    // T(1) << (bits - 1) is negative, which made the previous binary search return
    // bits - 1 instead of bits - 2 for n in [2^(bits-2), 2^(bits-1)).
    template <typename T>
    constexpr int hint_log2(T n)
    {
        // Unary plus applies integral promotion (also turns unscoped enums into integers),
        // so the shifts below are always legal.
        auto value = +n;
        int log = -1;
        while (value > 0)
        {
            value >>= 1;
            ++log;
        }
        return log;
    }
    // Quotient and remainder of dividend / divisor using T's built-in operators,
    // returned as {quotient, remainder}.
    template <typename T>
    constexpr std::pair<T, T> div_mod(T dividend, T divisor)
    {
        const T quotient = dividend / divisor;
        const T remainder = dividend % divisor;
        return std::pair<T, T>(quotient, remainder);
    }
    namespace hint_transform
    {
        // In-place radix-2 butterfly: (sum, diff) <- (sum + diff, sum - diff).
        // Both inputs are copied before either output is written; when sum and diff
        // refer to the same object, the final value is (x - x).
        template <typename T>
        inline void transform_2point(T &sum, T &diff)
        {
            const T lhs = sum;
            const T rhs = diff;
            sum = lhs + rhs;
            diff = lhs - rhs;
        }
        namespace hint_fht
        {
            // Fast Hartley transform (FHT) of compile-time length LEN, computed in place as a
            // radix-2 recursion on two half-length transforms (HalfFHT).
            // LEN = 0, 1, 2, 4, 8 and 16 are specialised below, so this primary template is only
            // used for larger lengths. LEN is assumed to be a power of two (NOTE(review): not checked).
            // Requirements visible in the code:
            //  - DoubleX4 is loaded from / stored to &in_out[...], so FloatTy must be double.
            //  - DoubleX4::load/store are aligned AVX accesses at in_out + 4 and
            //    in_out + half_len + 4, so in_out must be 32-byte aligned.
            // No bit-reversal permutation is done here. dif presumably leaves its output in
            // bit-reversed order and dit expects that order (fht_convolution pairs indices on
            // that assumption); confirm before using the transforms separately.
            template <size_t LEN, typename FloatTy>
            struct FHT
            {
                enum
                {
                    fht_len = LEN,
                    half_len = LEN / 2,
                    quarter_len = LEN / 4,
                    log_len = hint_log2(fht_len)
                };
                using HalfFHT = FHT<half_len, FloatTy>;
                // cos(2*pi*i / LEN)
                static FloatTy cos_omega(size_t i)
                {
                    return std::cos(i * HINT_2PI / fht_len);
                }
                // sin(2*pi*i / LEN)
                static FloatTy sin_omega(size_t i)
                {
                    return std::sin(i * HINT_2PI / fht_len);
                }
                // exp(i * 2*pi*i / LEN) as a complex number; not used by dit/dif below.
                static std::complex<FloatTy> comp_omega(size_t i)
                {
                    return std::polar<FloatTy>(1.0, i * HINT_2PI / fht_len);
                }
                // Decimation-in-time step: transform both halves recursively, then combine them
                // with Hartley butterflies. Operates in place on LEN elements starting at in_out.
                template <typename FloatIt>
                static void dit(FloatIt in_out)
                {
                    static_assert(std::is_same<typename std::iterator_traits<FloatIt>::value_type, FloatTy>::value, "Must be same as the FHT template float type");

                    HalfFHT::dit(in_out + half_len);
                    HalfFHT::dit(in_out);
                    // k = 0 and k = quarter_len are their own mirror partners (half_len - k == k for
                    // k = quarter_len) with trivial twiddles, so a plain butterfly suffices.
                    transform_2point(in_out[0], in_out[half_len]);
                    transform_2point(in_out[quarter_len], in_out[half_len + quarter_len]);
                    // Twiddle factors cos/sin(2*pi*k / LEN) for k = 0..7, computed once per instantiation.
                    static FloatTy cos_arr[8] = {1, cos_omega(1), cos_omega(2), cos_omega(3), cos_omega(4), cos_omega(5), cos_omega(6), cos_omega(7)};
                    static FloatTy sin_arr[8] = {0, sin_omega(1), sin_omega(2), sin_omega(3), sin_omega(4), sin_omega(5), sin_omega(6), sin_omega(7)};
                    // it0 walks forward over k in the first half, it1 backward over the mirror index
                    // half_len - k; it2/it3 are the same positions in the second half.
                    auto it0 = in_out + 1, it1 = in_out + half_len - 1;
                    auto it2 = it0 + half_len, it3 = it1 + half_len;
                    // Scalar part: k = 1, 2, 3.
                    for (; it0 < in_out + 4; ++it0, --it1, ++it2, --it3)
                    {
                        auto c = cos_arr[it0 - in_out], s = sin_arr[it0 - in_out];
                        auto temp0 = it2[0], temp1 = it3[0];
                        auto temp2 = temp0 * c + temp1 * s; //+*+ -(-*-)
                        auto temp3 = temp0 * s - temp1 * c; //-(+)*- -*-(+)
                        temp0 = it0[0], temp1 = it1[0];
                        it0[0] = temp0 + temp2; //+
                        it1[0] = temp1 + temp3; //-
                        it2[0] = temp0 - temp2; //+
                        it3[0] = temp1 - temp3; //-
                    }
                    // Step it1/it3 back so a 4-wide load at it1 covers the mirror indices of
                    // it0[0..3]; those come out in reversed order, hence the reverse() calls below.
                    it1 -= 3, it3 -= 3;
                    // Rotation by 4 steps, used to advance (c4, s4) each iteration.
                    static const DoubleX4 cos_unit(cos_arr[4]), sin_unit(sin_arr[4]);
                    // Twiddles for the 4 indices currently processed, starting at k = 4..7.
                    DoubleX4 c4(cos_arr + 4), s4(sin_arr + 4);
                    // Vector part: 4 indices per iteration, k = 4 .. quarter_len - 1.
                    for (; it0 < in_out + quarter_len; it0 += 4, it1 -= 4, it2 += 4, it3 -= 4)
                    {
                        DoubleX4 temp0, temp1, temp2, temp3;
                        // Second-half values at k (aligned load) and at the mirror indices (unaligned).
                        temp0.load(&it2[0]), temp1.loadu(&it3[0]);
                        temp2 = (temp1.reverse() * s4).fmadd(temp0, c4);
                        temp3 = (c4.reverse() * temp1).fmsub(temp0.reverse(), s4.reverse());
                        temp0.load(&it0[0]), temp1.loadu(&it1[0]);
                        (temp0 + temp2).store(&it0[0]);
                        (temp1 + temp3).storeu(&it1[0]);
                        (temp0 - temp2).store(&it2[0]);
                        (temp1 - temp3).storeu(&it3[0]);
                        // Advance the twiddles by 4 steps with the angle-addition formulas.
                        // NOTE(review): this recurrence may accumulate rounding error for large LEN.
                        temp0 = c4, temp1 = s4;
                        c4 = (temp0 * cos_unit - temp1 * sin_unit);
                        s4 = (temp0 * sin_unit + temp1 * cos_unit);
                    }
                }
                // Decimation-in-frequency step: Hartley butterflies between the two halves first,
                // then recursive half-length transforms. Mirror image of dit; operates in place.
                template <typename FloatIt>
                static void dif(FloatIt in_out)
                {
                    static_assert(std::is_same<typename std::iterator_traits<FloatIt>::value_type, FloatTy>::value, "Must be same as the FHT template float type");

                    // Same iterator layout as in dit: it0/it2 forward over k, it1/it3 over half_len - k.
                    auto it0 = in_out + 1, it1 = in_out + half_len - 1;
                    auto it2 = it0 + half_len, it3 = it1 + half_len;
                    // Twiddle factors cos/sin(2*pi*k / LEN) for k = 0..7, computed once per instantiation.
                    static FloatTy cos_arr[8] = {1, cos_omega(1), cos_omega(2), cos_omega(3), cos_omega(4), cos_omega(5), cos_omega(6), cos_omega(7)};
                    static FloatTy sin_arr[8] = {0, sin_omega(1), sin_omega(2), sin_omega(3), sin_omega(4), sin_omega(5), sin_omega(6), sin_omega(7)};
                    // Scalar part: k = 1, 2, 3.
                    for (; it0 < in_out + 4; ++it0, --it1, ++it2, --it3)
                    {
                        auto c = cos_arr[it0 - in_out], s = sin_arr[it0 - in_out]; //+,-
                        auto temp0 = it0[0], temp1 = it1[0];                       //+,-
                        auto temp2 = it2[0], temp3 = it3[0];                       //+,-
                        it0[0] = temp0 + temp2;                                    //+
                        it1[0] = temp1 + temp3;                                    //-
                        temp0 = temp0 - temp2;                                     //+
                        temp1 = temp1 - temp3;                                     //-
                        it2[0] = temp0 * c + temp1 * s;                            //+*+  -(-*-)
                        it3[0] = temp0 * s - temp1 * c;                            //-(+)*- -*-(+)
                    }
                    // Step it1/it3 back so a 4-wide load at it1 covers the (reversed) mirror indices.
                    it1 -= 3, it3 -= 3;
                    // Rotation by 4 steps, used to advance (c4, s4) each iteration.
                    static const DoubleX4 cos_unit(cos_arr[4]), sin_unit(sin_arr[4]);
                    // Twiddles for the 4 indices currently processed, starting at k = 4..7.
                    DoubleX4 c4(cos_arr + 4), s4(sin_arr + 4);
                    // Vector part: 4 indices per iteration, k = 4 .. quarter_len - 1.
                    for (; it0 < in_out + quarter_len; it0 += 4, it1 -= 4, it2 += 4, it3 -= 4)
                    {
                        DoubleX4 temp0, temp1, temp2, temp3;
                        temp0.load(&it0[0]), temp1.loadu(&it1[0]);
                        temp2.load(&it2[0]), temp3.loadu(&it3[0]);
                        (temp0 + temp2).store(&it0[0]);
                        (temp1 + temp3).storeu(&it1[0]);
                        temp0 = temp0 - temp2;
                        temp1 = temp1 - temp3;
                        temp2 = (temp1.reverse() * s4).fmadd(temp0, c4);
                        temp3 = (c4.reverse() * temp1).fmsub(temp0.reverse(), s4.reverse());
                        temp2.store(&it2[0]), temp3.storeu(&it3[0]);
                        // Advance the twiddles by 4 steps with the angle-addition formulas.
                        temp0 = c4, temp1 = s4;
                        c4 = (temp0 * cos_unit - temp1 * sin_unit);
                        s4 = (temp0 * sin_unit + temp1 * cos_unit);
                    }
                    // k = 0 and k = quarter_len: plain butterflies (trivial twiddles).
                    transform_2point(in_out[0], in_out[half_len]);
                    transform_2point(in_out[quarter_len], in_out[half_len + quarter_len]);
                    HalfFHT::dif(in_out);
                    HalfFHT::dif(in_out + half_len);
                }
            };

            // Length-0 transform: every entry point is a no-op.
            template <typename FloatTy>
            struct FHT<0, FloatTy>
            {
                template <typename FloatIt>
                static void dit(FloatIt /*in_out*/) {}
                template <typename FloatIt>
                static void dif(FloatIt /*in_out*/) {}
                // AVX-named variants, also no-ops for an empty transform.
                template <typename FloatIt>
                static void dit_avx(FloatIt /*in_out*/) {}
                template <typename FloatIt>
                static void dif_avx(FloatIt /*in_out*/) {}
            };

            // Length-1 transform: the Hartley transform of a single value is that value,
            // so both directions leave the data untouched.
            template <typename FloatTy>
            struct FHT<1, FloatTy>
            {
                template <typename FloatIt>
                static void dit(FloatIt /*in_out*/) {}
                template <typename FloatIt>
                static void dif(FloatIt /*in_out*/) {}
            };

            // Length-2 transform: a single sum/difference butterfly; dit and dif coincide.
            template <typename FloatTy>
            struct FHT<2, FloatTy>
            {
                template <typename FloatIt>
                static void dif(FloatIt in_out)
                {
                    transform_2point(*in_out, *(in_out + 1));
                }
                template <typename FloatIt>
                static void dit(FloatIt in_out)
                {
                    dif(in_out);
                }
            };

            // Length-4 transform written out by hand as two butterfly stages.
            template <typename FloatTy>
            struct FHT<4, FloatTy>
            {
                // dit: butterflies on (0,1) and (2,3), then combine the two pairs.
                template <typename FloatIt>
                static void dit(FloatIt in_out)
                {
                    auto x0 = in_out[0], x1 = in_out[1], x2 = in_out[2], x3 = in_out[3];
                    using V = decltype(x0);
                    const V sum01 = x0 + x1, diff01 = x0 - x1;
                    const V sum23 = x2 + x3, diff23 = x2 - x3;
                    in_out[0] = sum01 + sum23;
                    in_out[1] = diff01 + diff23;
                    in_out[2] = sum01 - sum23;
                    in_out[3] = diff01 - diff23;
                }
                // dif: butterflies on (0,2) and (1,3), then combine the two pairs.
                template <typename FloatIt>
                static void dif(FloatIt in_out)
                {
                    auto x0 = in_out[0], x1 = in_out[1], x2 = in_out[2], x3 = in_out[3];
                    using V = decltype(x0);
                    const V sum02 = x0 + x2, diff02 = x0 - x2;
                    const V sum13 = x1 + x3, diff13 = x1 - x3;
                    in_out[0] = sum02 + sum13;
                    in_out[1] = sum02 - sum13;
                    in_out[2] = diff02 + diff13;
                    in_out[3] = diff02 - diff13;
                }
            };

            // Hand-unrolled length-8 transform on local copies of in_out[0..7].
            // NOTE(review): the sqrt(2) scaling of temp5/temp7 presumably stands in for the
            // cos(pi/4) + sin(pi/4) twiddle of the generic recursion; confirm before editing.
            template <typename FloatTy>
            struct FHT<8, FloatTy>
            {
                // dit: length-2 butterflies on (0,1), (2,3), (4,5), (6,7); then (0,2), (1,3),
                // (4,6); scale temp5 and temp7 by sqrt(2); finally butterfly the two halves.
                template <typename FloatIt>
                static void dit(FloatIt in_out)
                {
                    auto temp0 = in_out[0], temp1 = in_out[1];
                    auto temp2 = in_out[2], temp3 = in_out[3];
                    auto temp4 = in_out[4], temp5 = in_out[5];
                    auto temp6 = in_out[6], temp7 = in_out[7];
                    transform_2point(temp0, temp1);
                    transform_2point(temp2, temp3);
                    transform_2point(temp4, temp5);
                    transform_2point(temp6, temp7);
                    transform_2point(temp0, temp2);
                    transform_2point(temp1, temp3);
                    transform_2point(temp4, temp6);
                    static constexpr decltype(temp0) SQRT_2 = 1.4142135623730950488016887242097;
                    temp5 *= SQRT_2, temp7 *= SQRT_2;
                    // Final stage: element j pairs with element j + 4.
                    in_out[0] = temp0 + temp4;
                    in_out[1] = temp1 + temp5;
                    in_out[2] = temp2 + temp6;
                    in_out[3] = temp3 + temp7;
                    in_out[4] = temp0 - temp4;
                    in_out[5] = temp1 - temp5;
                    in_out[6] = temp2 - temp6;
                    in_out[7] = temp3 - temp7;
                }
                // dif: the same steps in reverse order: butterflies between the halves
                // (j, j + 4), then (0,2), (1,3), the sqrt(2) scaling, (4,6), and finally
                // adjacent pairs (2j, 2j + 1).
                template <typename FloatIt>
                static void dif(FloatIt in_out)
                {
                    auto temp0 = in_out[0], temp1 = in_out[1];
                    auto temp2 = in_out[2], temp3 = in_out[3];
                    auto temp4 = in_out[4], temp5 = in_out[5];
                    auto temp6 = in_out[6], temp7 = in_out[7];
                    transform_2point(temp0, temp4);
                    transform_2point(temp1, temp5);
                    transform_2point(temp2, temp6);
                    transform_2point(temp3, temp7);
                    transform_2point(temp0, temp2);
                    transform_2point(temp1, temp3);
                    static constexpr decltype(temp0) SQRT_2 = 1.4142135623730950488016887242097;
                    temp5 *= SQRT_2, temp7 *= SQRT_2;
                    transform_2point(temp4, temp6);
                    in_out[0] = temp0 + temp1;
                    in_out[1] = temp0 - temp1;
                    in_out[2] = temp2 + temp3;
                    in_out[3] = temp2 - temp3;
                    in_out[4] = temp4 + temp5;
                    in_out[5] = temp4 - temp5;
                    in_out[6] = temp6 + temp7;
                    in_out[7] = temp6 - temp7;
                }
            };
            // Hand-unrolled length-16 transform built from one FHT<8> (elements 0..7) and two
            // FHT<4> (elements 8..11 and 12..15) plus a combining stage.
            // NOTE(review): this half + two quarters structure looks like a split-radix
            // decomposition; confirm against the generic FHT<LEN> before editing.
            template <typename FloatTy>
            struct FHT<16, FloatTy>
            {
                // dit: sub-transforms first, then the combining stage (twiddles on the odd
                // elements 9, 11, 13, 15, sqrt(2) scaling of 10 and 14, final butterflies).
                template <typename FloatIt>
                static void dit(FloatIt in_out)
                {
                    using value_type = typename std::iterator_traits<FloatIt>::value_type;
                    FHT<4, FloatTy>::dit(in_out + 12);
                    FHT<4, FloatTy>::dit(in_out + 8);
                    FHT<8, FloatTy>::dit(in_out);
                    static constexpr value_type SQRT_2 = 1.4142135623730950488016887242097;
                    static constexpr value_type COS_16 = 0.9238795325112867561281831893967; // cos(2PI/16);
                    static constexpr value_type SIN_16 = 0.3826834323650897717284599840304; // sin(2PI/16);
                    // Rotate the odd elements of the two quarter blocks by the 2*pi/16 twiddle.
                    auto temp4 = in_out[9], temp5 = in_out[11];
                    auto temp0 = temp4 * COS_16 + temp5 * SIN_16;
                    auto temp2 = temp4 * SIN_16 - temp5 * COS_16;

                    temp4 = in_out[13], temp5 = in_out[15];
                    auto temp1 = temp4 * SIN_16 + temp5 * COS_16;
                    auto temp3 = temp4 * COS_16 - temp5 * SIN_16;

                    transform_2point(temp0, temp1);
                    transform_2point(temp3, temp2);

                    // Combine with the odd elements 1, 3, 5, 7 of the half block.
                    temp4 = in_out[1], temp5 = in_out[3];
                    in_out[1] = temp4 + temp0, in_out[3] = temp5 + temp1;
                    in_out[9] = temp4 - temp0, in_out[11] = temp5 - temp1;

                    temp4 = in_out[5], temp5 = in_out[7];
                    in_out[5] = temp4 + temp2, in_out[7] = temp5 + temp3;
                    in_out[13] = temp4 - temp2, in_out[15] = temp5 - temp3;

                    // Even elements: sqrt(2) scaling and plain butterflies.
                    in_out[10] *= SQRT_2, in_out[14] *= SQRT_2;
                    transform_2point(in_out[8], in_out[12]);
                    transform_2point(in_out[0], in_out[8]);
                    transform_2point(in_out[2], in_out[10]);
                    transform_2point(in_out[4], in_out[12]);
                    transform_2point(in_out[6], in_out[14]);
                }
                // dif: the combining stage first (steps of dit in reverse order), then
                // FHT<8>::dif on 0..7 and FHT<4>::dif on 8..11 and 12..15.
                template <typename FloatIt>
                static void dif(FloatIt in_out)
                {
                    using value_type = typename std::iterator_traits<FloatIt>::value_type;
                    static constexpr value_type SQRT_2 = 1.4142135623730950488016887242097;
                    static constexpr value_type COS_16 = 0.9238795325112867561281831893967; // cos(2PI/16);
                    static constexpr value_type SIN_16 = 0.3826834323650897717284599840304; // sin(2PI/16);
                    // Even elements: plain butterflies and sqrt(2) scaling.
                    transform_2point(in_out[0], in_out[8]);
                    transform_2point(in_out[2], in_out[10]);
                    transform_2point(in_out[4], in_out[12]);
                    transform_2point(in_out[6], in_out[14]);
                    transform_2point(in_out[8], in_out[12]);
                    in_out[10] *= SQRT_2, in_out[14] *= SQRT_2;

                    // Odd elements: butterflies between the halves, then the 2*pi/16 twiddle.
                    auto temp0 = in_out[9], temp1 = in_out[11];
                    auto temp2 = in_out[13], temp3 = in_out[15];

                    auto temp4 = in_out[1], temp5 = in_out[3];
                    in_out[1] = temp4 + temp0;
                    in_out[3] = temp5 + temp1;
                    temp0 = temp4 - temp0;
                    temp1 = temp5 - temp1;

                    temp4 = in_out[5], temp5 = in_out[7];
                    in_out[5] = temp4 + temp2;
                    in_out[7] = temp5 + temp3;
                    temp2 = temp4 - temp2;
                    temp3 = temp5 - temp3;

                    transform_2point(temp0, temp1);
                    transform_2point(temp3, temp2);

                    in_out[9] = temp0 * COS_16 + temp2 * SIN_16;
                    in_out[11] = temp0 * SIN_16 - temp2 * COS_16;
                    in_out[13] = temp1 * SIN_16 + temp3 * COS_16;
                    in_out[15] = temp1 * COS_16 - temp3 * SIN_16;

                    FHT<8, FloatTy>::dif(in_out);
                    FHT<4, FloatTy>::dif(in_out + 8);
                    FHT<4, FloatTy>::dif(in_out + 12);
                }
            };

            // Dispatch helper (runtime length -> compile-time length) for dit: starting at LEN,
            // halve LEN until LEN <= fht_len, then run FHT<LEN, Float64>::dit on input.
            // For a power-of-two fht_len <= LEN this selects FHT<fht_len>; fht_len == 0 ends
            // in the empty LEN == 0 specialisation.
            template <size_t LEN = 1>
            inline void fht_dit_template_alt(Float64 *input, size_t fht_len)
            {
                if (fht_len >= LEN)
                {
                    FHT<LEN, Float64>::dit(input);
                }
                else
                {
                    fht_dit_template_alt<LEN / 2>(input, fht_len);
                }
            }
            // Recursion terminator: nothing to do.
            template <>
            inline void fht_dit_template_alt<0>(Float64 * /*input*/, size_t /*fht_len*/) {}

            // Dispatch helper (runtime length -> compile-time length) for dif: starting at LEN,
            // halve LEN until LEN <= fht_len, then run FHT<LEN, Float64>::dif on input.
            // For a power-of-two fht_len <= LEN this selects FHT<fht_len>; fht_len == 0 ends
            // in the empty LEN == 0 specialisation.
            template <size_t LEN = 1>
            inline void fht_dif_template_alt(Float64 *input, size_t fht_len)
            {
                if (fht_len >= LEN)
                {
                    FHT<LEN, Float64>::dif(input);
                }
                else
                {
                    fht_dif_template_alt<LEN / 2>(input, fht_len);
                }
            }
            // Recursion terminator: nothing to do.
            template <>
            inline void fht_dif_template_alt<0>(Float64 * /*input*/, size_t /*fht_len*/) {}

            // Entry points covering every length up to FHT_MAX_LEN: they dispatch at runtime to
            // the FHT<LEN> instance selected by their fht_len argument.
            // constexpr: the pointers are never reassigned, and a constant pointer lets the
            // compiler call the target directly instead of through a mutable global.
            constexpr auto fht_dit = fht_dit_template_alt<FHT_MAX_LEN>;
            constexpr auto fht_dif = fht_dif_template_alt<FHT_MAX_LEN>;

            // FHT-accelerated cyclic convolution of fht_ary1 and fht_ary2, written to out.
            // Behaviour visible in the code:
            //  - fht_len == 0: no-op. fht_len == 1: out[0] = fht_ary1[0] * fht_ary2[0].
            //  - Otherwise fht_len is rounded DOWN to a power of two (int_floor2); elements past
            //    that length are ignored and out is only written up to it.
            //  - Throws a const char* if the rounded length exceeds FHT_MAX_LEN.
            //  - Both inputs are transformed in place (fht_dif), so their contents are destroyed.
            //    Passing the same pointer twice transforms it only once (fast squaring).
            //  - out may alias an input: every output position is written only after the input
            //    values at that position have been read.
            //  - The transforms use aligned AVX accesses for lengths >= 32, so the pointers are
            //    expected to be 32-byte aligned.
            inline void fht_convolution(Float64 fht_ary1[], Float64 fht_ary2[], Float64 out[], size_t fht_len)
            {
                if (fht_len == 0)
                {
                    return;
                }
                if (fht_len == 1)
                {
                    out[0] = fht_ary1[0] * fht_ary2[0];
                    return;
                }
                fht_len = int_floor2(fht_len);
                if (fht_len > FHT_MAX_LEN)
                {
                    throw("FHT len cannot be larger than FHT_MAX_LEN");
                }
                fht_dif(fht_ary1, fht_len);
                // When both inputs are the same array, transform it only once (speeds up squaring).
                if (fht_ary1 != fht_ary2)
                {
                    fht_dif(fht_ary2, fht_len);
                }
                // 0.5 from the DHT convolution formula, 1 / fht_len to normalise the inverse transform.
                const double inv = 0.5 / fht_len;
                // Positions 0 and 1 are their own mirror partners (presumably frequencies 0 and
                // fht_len / 2 in bit-reversed order), so a plain pointwise product suffices.
                out[0] = fht_ary1[0] * fht_ary2[0] / fht_len;
                out[1] = fht_ary1[1] * fht_ary2[1] / fht_len;
                if (fht_len == 2)
                {
                    return;
                }
                // DHT convolution theorem, with X = fht_ary1 and Y = fht_ary2:
                //   Z[k] = (Y[k] * (X[k] + X[-k]) + Y[-k] * (X[k] - X[-k])) / 2
                // Positions 2 and 3 form one mirror pair.
                auto temp0 = fht_ary1[2], temp1 = fht_ary1[3];
                auto temp2 = fht_ary2[2], temp3 = fht_ary2[3];
                transform_2point(temp0, temp1);
                out[2] = (temp2 * temp0 + temp3 * temp1) * inv;
                out[3] = (temp3 * temp0 - temp2 * temp1) * inv;
                // Remaining blocks [i, 2i): position i + t is paired with 2i - 1 - t (mirror pairs
                // under the assumed bit-reversed ordering). Each iteration handles two pairs.
                for (size_t i = 4; i < fht_len; i *= 2)
                {
                    auto it0 = fht_ary1 + i, it1 = it0 + i - 1;
                    auto it2 = fht_ary2 + i, it3 = it2 + i - 1;
                    auto it4 = out + i, it5 = it4 + i - 1;
                    for (; it0 < it1; it0 += 2, it1 -= 2, it2 += 2, it3 -= 2, it4 += 2, it5 -= 2)
                    {
                        temp0 = *it0, temp1 = *it1, temp2 = *it2, temp3 = *it3;
                        transform_2point(temp0, temp1);
                        *it4 = (temp2 * temp0 + temp3 * temp1) * inv;
                        *it5 = (temp3 * temp0 - temp2 * temp1) * inv;
                        temp0 = *(it1 - 1), temp1 = *(it0 + 1), temp2 = *(it3 - 1), temp3 = *(it2 + 1);
                        transform_2point(temp0, temp1);
                        *(it5 - 1) = (temp2 * temp0 + temp3 * temp1) * inv;
                        *(it4 + 1) = (temp3 * temp0 - temp2 * temp1) * inv;
                    }
                }
                // Back to the coefficient domain (scaling already folded into inv).
                fht_dit(out, fht_len);
            }
        }
    }
}

using namespace std;
using namespace hint;
using namespace hint_simd;
using namespace hint_transform;
using namespace hint_fht;

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    size_t conv_len = m + n + 1, fht_len = int_ceil2(conv_len);
    static AlignAry<Float64, FHT_MAX_LEN> buffer1, buffer2;
    std::copy(a, a + n + 1, buffer1.data());
    std::copy(b, b + n + 1, buffer2.data());
    fht_convolution(buffer1.data(), buffer2.data(), buffer1.data(), FHT_MAX_LEN);
    size_t i = 0, rem = conv_len % 4;
    for (; i < conv_len - rem; i += 4)
    {
        c[i] = buffer1[i] + 0.5;
        c[i + 1] = buffer1[i + 1] + 0.5;
        c[i + 2] = buffer1[i + 2] + 0.5;
        c[i + 3] = buffer1[i + 3] + 0.5;
    }
    for (; i < conv_len; i++)
    {
        c[i] = buffer1[i] + 0.5;
    }
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 44.4 ms | 39 MB + 700 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-07-20 22:37:27 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠