提交记录 18828


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1002. 测测你的多项式乘法 Wrong Answer 0 252.529 ms 73384 KB C++ 8.21 KB
提交时间 评测时间
2022-12-25 20:41:00 2022-12-25 20:41:04
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
// Fixed-width integer aliases used throughout this file. They are spelled
// through namespace std, which is where <cstdint> is guaranteed to declare them.
using UINT_8 = std::uint8_t;
using UINT_16 = std::uint16_t;
using UINT_32 = std::uint32_t;
using UINT_64 = std::uint64_t;
using INT_32 = std::int32_t;
using INT_64 = std::int64_t;
using ULONG = unsigned long;
using LONG = long;
// Complex number type used by every FFT routine.
using Complex = std::complex<double>;
constexpr double HINT_PI = 3.1415926535897932384626433832795;
// Derived from HINT_PI so the two constants cannot drift apart; doubling is
// exact in binary floating point.
constexpr double HINT_2PI = 2 * HINT_PI;
// Point on the unit circle at angle theta (radians): cos(theta) + i*sin(theta).
inline Complex unit_root(double theta)
{
    const double radius = 1.0;
    return std::polar(radius, theta);
}
class CosTable
{
private:
    Complex *table = nullptr;
    size_t log_size = 0;
    CosTable(const CosTable &) = delete;
    CosTable &operator=(const CosTable &) = delete;

public:
    ~CosTable()
    {
        if (table != nullptr)
        {
            delete[] table;
            table = nullptr;
        }
    }
    // 初始化可以生成平分圆1<<shift份产生的单位根的表
    CosTable(size_t shift)
    {
        shift = std::max<size_t>(shift, 3);
        log_size = shift;
        size_t ary_size = (1ull << (shift - 1)) - 2;
        table = new Complex[ary_size];
        shift -= 2;
        for (size_t i = 1; i <= shift; i++)
        {
            size_t len = 1ull << i;
            size_t begin = len - 2;
            for (size_t pos = 0; pos < len; pos++)
            {
                table[pos + begin] = unit_root(pos * HINT_PI / (len * 2));
            }
        }
    }
    // shift表示圆平分为1<<shift份,n表示第几个单位根
    Complex get_complex(size_t shift, size_t n) const
    {
        size_t rank = 1ull << shift;
        n &= (rank - 1);
        size_t zone = (n << 2) >> shift;
        if (((n << 2) & (rank - 1)) == 0)
        {
            Complex ary[4] = {Complex(1, 0), Complex(0, 1), Complex(-1, 0), Complex(0, -1)};
            return ary[zone];
        }
        rank >>= 2;
        const Complex *ptr = table + rank - 2;
        Complex tmp;
        switch (zone)
        {
        case 0:
            tmp = ptr[n];
            break;
        case 1:
            tmp = ptr[(rank << 1) - n];
            tmp = Complex(-tmp.real(), tmp.imag());
            break;
        case 2:
            tmp = -ptr[n - (rank << 1)];
            break;
        case 3:
            tmp = std::conj(ptr[(rank << 2) - n]);
            break;
        default:
            break;
        }
        return tmp;
    }
};
// log2 of the largest circle subdivision the global table supports; the
// radix-2/radix-4 LUT routines throw for FFT lengths above 1 << lut_max_len.
constexpr size_t lut_max_len{21};
// Global unit-root table shared by the FFT routines (built at start-up).
const CosTable TABLE{lut_max_len};
// Largest power of two that is <= n; returns 0 for n == 0 (same contract as
// C++20 std::bit_floor). Pure integer arithmetic: floor(log2(n)) in double
// rounds up for values just below large powers of two (e.g. 2^50 - 1) and is
// undefined for n == 0.
inline std::uint64_t max_2pow(std::uint64_t n)
{
    // Copy the highest set bit into every lower bit position...
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    n |= n >> 32;
    // ...then keep only that highest bit.
    return n - (n >> 1);
}
// Smallest power of two that is >= n; returns 1 for n == 0 (same contract as
// C++20 std::bit_ceil). Requires n <= 2^63; larger inputs wrap to 0.
// Pure integer arithmetic: ceil(log2(n)) in double rounds down for values just
// above large powers of two (e.g. 2^50 + 1) and is undefined for n == 0.
inline std::uint64_t min_2pow(std::uint64_t n)
{
    if (n <= 1)
    {
        return 1;
    }
    n -= 1; // so exact powers of two map to themselves
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    n |= n >> 32;
    return n + 1;
}
// True when the lowest bit of x is set, i.e. x is odd for integral x.
template <typename T>
constexpr bool is_odd(T x)
{
    return (x & 1) != 0;
}
// Copies len elements from source into target and returns target.
// The ranges must not overlap. Unlike a raw memcpy this is also well-defined
// for non-trivially-copyable T; for trivial T it compiles to the same bulk copy.
template <typename T>
inline T *ary_copy(T *target, const T *source, size_t len)
{
    std::copy_n(source, len, target);
    return target;
}
// Radix-2 butterfly on the pair (input[pos], input[pos + rank]).
// With a, b the values before the call:
//   input[pos]        <- a + b * omega
//   input[pos + rank] <- a - b * omega
inline void fft_radix2_butterfly(Complex omega, Complex *input, size_t pos, size_t rank)
{
    const Complex even = input[pos];
    const Complex odd = input[pos + rank] * omega;
    input[pos] = even + odd;
    input[pos + rank] = even - odd;
}
// Radix-4 butterfly over the four elements input[pos + k * rank], k = 0..3.
// omega, omega_sqr and omega_cube are the twiddles applied to elements 1, 2
// and 3. The quarter-turn factor is +i, computed as t * i = (-t.imag, t.real).
inline void fft_radix4_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                 Complex *input, size_t pos, size_t rank)
{
    Complex *const p0 = input + pos;
    Complex *const p1 = p0 + rank;
    Complex *const p2 = p1 + rank;
    Complex *const p3 = p2 + rank;

    // Twiddled inputs.
    const Complex x0 = *p0;
    const Complex x1 = *p1 * omega;
    const Complex x2 = *p2 * omega_sqr;
    const Complex x3 = *p3 * omega_cube;

    const Complex sum02 = x0 + x2;
    const Complex sum13 = x1 + x3;
    const Complex diff02 = x0 - x2;
    const Complex diff13 = x1 - x3;
    const Complex diff13_i(-diff13.imag(), diff13.real()); // (x1 - x3) * i

    *p0 = sum02 + sum13;
    *p1 = diff02 + diff13_i;
    *p2 = sum02 - sum13;
    *p3 = diff02 - diff13_i;
}
// Permutes ary[0 .. len) into bit-reversed index order (the input order an
// in-place radix-2 decimation-in-time FFT expects). len should be a power of two.
template <typename T>
void binary_inverse_swap(T *ary, size_t len)
{
    // Lengths 0 and 1 are already in bit-reversed order; returning early also
    // keeps rev below non-empty (rev[0] is written unconditionally).
    if (len < 2)
    {
        return;
    }
    size_t log_n = 0; // floor(log2(len))
    while ((len >> (log_n + 1)) != 0)
    {
        log_n++;
    }
    // rev[i] = bit reversal of i; only the first half is kept because each
    // index is derived from rev[i >> 1].
    std::vector<size_t> rev(len / 2);
    rev[0] = 0;
    for (size_t i = 1; i < len; i++)
    {
        const size_t index = (rev[i >> 1] >> 1) | ((i & 1) << (log_n - 1));
        if (i < len / 2)
        {
            rev[i] = index;
        }
        if (i < index)
        {
            std::swap(ary[i], ary[index]);
        }
    }
}
// Base-4 digit reversal: permutes ary[0 .. len) so index i swaps with the index
// whose base-4 digits are those of i reversed (the input order an in-place
// radix-4 decimation-in-time FFT expects). len should be a power of four.
template <typename T>
void quaternary_inverse_swap(T *ary, size_t len)
{
    // Lengths below 4 are already in digit-reversed order; returning early also
    // keeps rev below non-empty (rev[0] is written unconditionally).
    if (len < 4)
    {
        return;
    }
    size_t log_n = 0; // floor(log2(len))
    while ((len >> (log_n + 1)) != 0)
    {
        log_n++;
    }
    // rev[i] = base-4 reversal of i; only the first quarter is kept because each
    // index is derived from rev[i >> 2].
    std::vector<size_t> rev(len / 4);
    rev[0] = 0;
    for (size_t i = 1; i < len; i++)
    {
        const size_t index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2));
        if (i < len / 4)
        {
            rev[i] = index;
        }
        if (i < index)
        {
            std::swap(ary[i], ary[index]);
        }
    }
}
// Replaces each of the first fft_len elements with its complex conjugate
// divided by div (div defaults to 1, i.e. plain conjugation).
inline void fft_conj(Complex *input, size_t fft_len, double div = 1)
{
    Complex *const end = input + fft_len;
    for (Complex *it = input; it != end; ++it)
    {
        *it = std::conj(*it) / div;
    }
}
// Radix-2 decimation-in-time FFT using the global unit-root table, in place.
// fft_len is rounded down to a power of two; lengths above 1 << lut_max_len
// are rejected by throwing a C string.
void fft_radix2_lut(Complex *input, size_t fft_len)
{
    fft_len = max_2pow(fft_len);
    if (fft_len > (1ull << lut_max_len))
    {
        throw "fft length too long for lut\n";
    }
    binary_inverse_swap(input, fft_len);
    // Each stage merges pairs of length-`half` transforms into length-`span` ones;
    // log_span = log2(span) selects the matching roots from TABLE.
    UINT_8 log_span = 1;
    for (size_t half = 1; half < fft_len; half <<= 1, log_span++)
    {
        const size_t span = half << 1;
        for (size_t block = 0; block < fft_len; block += span)
        {
            for (size_t k = 0; k < half; ++k)
            {
                const Complex omega = TABLE.get_complex(log_span, k);
                fft_radix2_butterfly(omega, input, block + k, half);
            }
        }
    }
}
// Radix-4 decimation-in-time FFT using the global unit-root table, in place.
// fft_len is rounded down to a power of four; lengths above 1 << lut_max_len
// are rejected by throwing a C string.
void fft_radix4_lut(Complex *input, size_t fft_len)
{
    const size_t log4_len = std::log2(fft_len) / 2;
    fft_len = 1ull << (log4_len * 2);
    if (fft_len > (1ull << lut_max_len))
    {
        throw "fft length too long for lut\n";
    }
    quaternary_inverse_swap(input, fft_len);
    // Each stage merges four length-`quarter` transforms into one of length
    // `span`; log_span = log2(span) selects the matching roots from TABLE.
    UINT_16 log_span = 2;
    for (size_t quarter = 1; quarter < fft_len; quarter <<= 2, log_span += 2)
    {
        const size_t span = quarter << 2;
        for (size_t block = 0; block < fft_len; block += span)
        {
            for (size_t k = 0; k < quarter; ++k)
            {
                const Complex w1 = TABLE.get_complex(log_span, k);
                const Complex w2 = TABLE.get_complex(log_span, k << 1);
                const Complex w3 = w1 * w2;
                fft_radix4_butterfly(w1, w2, w3, input, block + k, quarter);
            }
        }
    }
}
// Forward FFT of input[0 .. fft_len) in place, using the e^(+i*theta) roots
// from TABLE. fft_len is rounded down to a power of two.
// Lengths with even log2 go straight to the radix-4 kernel. Odd ones are split
// into even/odd-indexed halves, each half is transformed with radix-4, and the
// halves are merged with one radix-2 stage.
void fft_lut(Complex *input, size_t fft_len)
{
    // A transform of length 0 or 1 is the identity; returning early also keeps
    // max_2pow() and log2() away from the degenerate argument 0.
    if (fft_len < 2)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    const size_t log_len = static_cast<size_t>(std::log2(fft_len));
    if (!is_odd(log_len))
    {
        fft_radix4_lut(input, fft_len);
        return;
    }
    const size_t half_len = fft_len / 2;
    // Deinterleave: even-indexed terms to the front half, odd-indexed terms to
    // the back half. Reading index 2i after writing i < 2i is safe.
    std::vector<Complex> odd_terms(half_len);
    for (size_t i = 0; i < half_len; i++)
    {
        input[i] = input[2 * i];
        odd_terms[i] = input[2 * i + 1];
    }
    ary_copy(input + half_len, odd_terms.data(), half_len);
    fft_radix4_lut(input, half_len);
    fft_radix4_lut(input + half_len, half_len);

    // Final radix-2 stage combining the two half-length transforms.
    for (size_t i = 0; i < half_len; i++)
    {
        const Complex omega = TABLE.get_complex(log_len, i);
        fft_radix2_butterfly(omega, input, i, half_len);
    }
}

// Inverse of fft_lut, in place: computes conj(fft(conj(x))) / fft_len, which
// undoes fft_lut up to rounding. fft_len is rounded down to a power of two.
void ifft_lut(Complex *input, size_t fft_len)
{
    // Length 0 or 1: the inverse transform is the identity; returning early
    // also avoids calling max_2pow(0).
    if (fft_len < 2)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    fft_conj(input, fft_len);
    fft_lut(input, fft_len);
    fft_conj(input, fft_len, fft_len);
}

// Element-wise product: out[i] = in1[i] * in2[i] for i in [0, len).
// out may alias in1 and/or in2 element-for-element, because each output
// depends only on the inputs at the same index.
template <typename T>
inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
{
    const T *lhs = in1;
    const T *rhs = in2;
    T *const stop = out + len;
    for (T *dst = out; dst != stop; ++dst, ++lhs, ++rhs)
    {
        *dst = *lhs * *rhs;
    }
}

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    if (n < m)
    {
        std::swap(a, b);
        std::swap(n, m);
    }
    size_t fft_len = min_2pow(n + m - 1);
    Complex *fft_ary = new Complex[fft_len];
    UINT_32 tmp = 0;
    size_t pos = 0;
    while (pos < m)
    {
        fft_ary[pos] = Complex(a[pos], b[pos]);
        pos++;
    }
    while (pos < n)
    {
        fft_ary[pos] = Complex(a[pos], 0);
        pos++;
    }
    fft_lut(fft_ary, fft_len);
    ary_mul(fft_ary, fft_ary, fft_ary, fft_len);
    ifft_lut(fft_ary, fft_len);
    for (size_t i = 0; i < m + n - 1; i++)
    {
        c[i] = static_cast<unsigned>(fft_ary[i].imag() / 2 + 0.5);
    }
    delete[] fft_ary;
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 252.529 ms | 71 MB + 680 KB | Wrong Answer | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-16 17:01:22 | Loaded in 0 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠