提交记录 21641


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY test. 自定义测试 Runtime Error 0 33.52 us 32 KB C++14 12.66 KB
提交时间 评测时间
2024-04-24 15:51:20 2024-04-24 15:51:22
#include <immintrin.h>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <utility>
#pragma GCC target("fma")
#pragma GCC target("avx2")

// Generic fast exponentiation (square-and-multiply).
// Returns m raised to the power n using T's own operator*=; the accumulator
// starts at T(1), so any n <= 0 yields 1. No modular reduction is applied
// here: that is left to the arithmetic of T itself.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n)
{
    T acc = 1;
    for (; n > 0; n >>= 1)
    {
        // Multiply in the current power of m whenever the lowest exponent bit is set.
        if ((n & 1) != 0)
        {
            acc *= m;
        }
        // Advance m to the next power m^(2^k).
        m *= m;
    }
    return acc;
}

// Floor of log2(n): the largest exponent e such that (T(1) << e) <= n.
// Found by binary search over the bit positions of T; returns -1 when no
// exponent qualifies (e.g. n == 0).
template <typename T>
constexpr int int_log2(const T &n)
{
    constexpr int width = sizeof(T) * 8;
    // Invariant: (T(1) << lo) <= n (or lo == -1) and (T(1) << hi) > n (or hi == width).
    int lo = -1;
    int hi = width;
    while (hi != lo + 1)
    {
        const int probe = (lo + hi) / 2;
        if ((T(1) << probe) > n)
        {
            hi = probe;
        }
        else
        {
            lo = probe;
        }
    }
    return lo;
}

// Rounds n up to a power of two (values that already are one are returned
// unchanged). Works by decrementing, smearing the highest set bit into every
// lower position, then adding one back.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = sizeof(T) * 8;
    --n;
    int shift = 1;
    while (shift < width)
    {
        // After this step the top set bit has been copied 2*shift positions down.
        n |= (n >> shift);
        shift *= 2;
    }
    return n + 1;
}

// Value whose lowest `bits` binary digits are all 1, i.e. 2^bits - 1.
// Built as (2^(bits-1) - 1) + 2^(bits-1) so that bits equal to the full width
// of T never needs a shift by the whole width.
template <typename T>
constexpr T all_one(int bits)
{
    const T high = T(1) << (bits - 1); // the top requested bit
    const T below = high - 1;          // every bit underneath it
    return below + high;
}

// Multiplicative inverse of n modulo 2^pow, computed by Newton iteration
// (x <- x * (2 - n*x)), which doubles the number of correct low bits per step.
//
// n   : value to invert; only odd values have an inverse modulo a power of two.
// pow : exponent of the modulus, meaningful in the range 1..64.
//
// Returns x in [0, 2^pow) with (n * x) mod 2^pow == 1.
// Returns 0 when no inverse exists (even n) or when pow <= 0; previously an
// even n made the loop spin forever (or fail constant evaluation), so callers
// can now detect the problem with a simple check such as n * result == 1.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if ((n & 1) == 0 || pow <= 0)
    {
        return 0;
    }
    // Low `pow` bits set; the full-width case is handled without shifting by 64.
    const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
    uint64_t xn = 1;
    uint64_t t = n & mask; // t tracks (n * xn) mod 2^pow and reaches 1 on convergence
    while (t != 1)
    {
        xn = xn * (2 - t);
        t = (xn * n) & mask;
    }
    return xn & mask;
}

// Scalar integer modulo MOD stored in Montgomery form with R = 2^32.
// Constructing from a plain integer n stores (n * 2^32) mod MOD in `data`;
// converting back (toInt / operator uint32_t) applies one Montgomery reduction.
// The static helpers operate on raw Montgomery words and expect operands that
// are already reduced below MOD.
template <uint32_t MOD>
struct MontInt32
{
    static constexpr uint64_t R = uint64_t(1) << 32;
    static constexpr uint32_t R_MASK = R - 1;
    // MOD^-1 modulo 2^32 (verified by the second static_assert below).
    static constexpr uint32_t MOD_INV = inv_mod2pow(MOD, 32);
    // -(MOD^-1) modulo 2^32, the factor used by redc().
    static constexpr uint32_t MOD_INV_NEG = R - MOD_INV;
    static constexpr uint32_t MOD2 = MOD * 2;
    static_assert(int_log2(MOD) <= 30, "MOD can't be larger than 30 bits");
    static_assert(uint32_t(MOD_INV * MOD) == 1, "Montgomery32 modulus is not correct");

    // Raw Montgomery-form word.
    uint32_t data;

    // Zero.
    constexpr MontInt32() : data(0) {}
    // Converts the plain integer n into Montgomery form.
    constexpr MontInt32(uint32_t n) : data(toMont(n)) {}

    // Plain integer -> Montgomery word: (n * 2^32) mod MOD.
    static constexpr uint32_t toMont(uint32_t n)
    {
        return (uint64_t(n) * R) % MOD;
    }
    // Montgomery reduction: returns input / 2^32 (mod MOD), reduced below MOD
    // with a single conditional subtraction.
    static constexpr uint32_t redc(uint64_t input)
    {
        // q is chosen so that input + q * MOD has its low 32 bits cleared.
        const uint32_t q = uint32_t(input) * MOD_INV_NEG;
        const uint64_t folded = (uint64_t(q) * MOD + input) >> 32;
        if (folded < MOD)
        {
            return folded;
        }
        return folded - MOD;
    }
    // Montgomery word -> plain integer.
    static constexpr uint32_t toInt(uint32_t n)
    {
        return redc(n);
    }
    // (m + n) with one conditional subtraction of MOD.
    static constexpr uint32_t addMont(uint32_t m, uint32_t n)
    {
        const uint32_t sum = m + n;
        return sum >= MOD ? sum - MOD : sum;
    }
    // (m - n); adds MOD back when the unsigned subtraction wrapped around.
    static constexpr uint32_t subMont(uint32_t m, uint32_t n)
    {
        const uint32_t diff = m - n;
        return m < n ? diff + MOD : diff;
    }
    // Montgomery product: redc of the full 64-bit product.
    static constexpr uint32_t mulMont(uint32_t m, uint32_t n)
    {
        const uint64_t wide = uint64_t(m) * n;
        return redc(wide);
    }
    // Replaces the stored value with the Montgomery form of n.
    constexpr void fromInt(uint32_t n)
    {
        data = toMont(n);
    }
    // Plain integer value of this object.
    constexpr uint32_t toInt() const
    {
        return toInt(data);
    }
    // Implicit conversion back to a plain integer.
    constexpr operator uint32_t() const
    {
        return toInt();
    }
    constexpr MontInt32 operator+(MontInt32 rhs) const
    {
        MontInt32 out;
        out.data = addMont(data, rhs.data);
        return out;
    }
    constexpr MontInt32 operator-(MontInt32 rhs) const
    {
        MontInt32 out;
        out.data = subMont(data, rhs.data);
        return out;
    }
    constexpr MontInt32 operator*(MontInt32 rhs) const
    {
        MontInt32 out;
        out.data = mulMont(data, rhs.data);
        return out;
    }
    constexpr MontInt32 &operator+=(const MontInt32 &rhs)
    {
        const uint32_t updated = addMont(data, rhs.data);
        data = updated;
        return *this;
    }
    constexpr MontInt32 &operator-=(const MontInt32 &rhs)
    {
        const uint32_t updated = subMont(data, rhs.data);
        data = updated;
        return *this;
    }
    constexpr MontInt32 &operator*=(const MontInt32 &rhs)
    {
        const uint32_t updated = mulMont(data, rhs.data);
        data = updated;
        return *this;
    }
    // The modulus this type was instantiated with.
    static constexpr uint32_t mod()
    {
        return MOD;
    }
};

// Eight 32-bit Montgomery words (R = 2^32, modulus MOD) packed in one AVX2
// register, lane i living in 32-bit element i of `data`.
//
// addMont/subMont keep lanes in the lazy range [0, 2*MOD); mulMont and toInt
// return lanes fully reduced to [0, MOD). All range corrections use the
// *signed* compare _mm256_cmpgt_epi32.
template <uint32_t MOD>
struct MontInt32X8
{
    using MontInt = MontInt32<MOD>;
    // Lazy lanes may be as large as 2*MOD - 1 and are compared as signed
    // 32-bit integers, so 2*MOD must not reach 2^31; mulMont additionally
    // needs (2*MOD)^2 < MOD * 2^32 for its reduction to land below 2*MOD.
    // Both conditions amount to MOD < 2^30. Larger moduli used to compile and
    // silently produce wrong lanes.
    static_assert(MOD < (uint32_t(1) << 30), "MontInt32X8 requires MOD < 2^30");

    __m256i data;

    // All lanes zero.
    MontInt32X8() { data = _mm256_setzero_si256(); }
    // Broadcasts the raw word x.data to every lane.
    MontInt32X8(MontInt x) { data = _mm256_set1_epi32(x.data); }
    // Lane i is set to xi verbatim (no conversion to Montgomery form).
    MontInt32X8(int32_t x0, int32_t x1, int32_t x2, int32_t x3, int32_t x4, int32_t x5, int32_t x6, int32_t x7)
    {
        data = _mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0);
    }
    // Wraps an existing register; implicit so intrinsic results can be returned directly.
    MontInt32X8(__m256i rhs) : data(rhs) {}
    // Unaligned load of 32 bytes starting at p.
    template <typename T>
    MontInt32X8(const T *p)
    {
        loadu(p);
    }
    // The modulus this type was instantiated with.
    static constexpr uint32_t mod()
    {
        return MOD;
    }
    // Broadcast constants used by the arithmetic below.
    static MontInt32X8 zeroX8()
    {
        return _mm256_setzero_si256();
    }
    static MontInt32X8 modX8()
    {
        return _mm256_set1_epi32(MOD);
    }
    static MontInt32X8 mod1X8()
    {
        return _mm256_set1_epi32(MOD - 1);
    }
    static MontInt32X8 mod2X8()
    {
        return _mm256_set1_epi32(MOD * 2);
    }
    // Every lane holds MontInt::MOD_INV_NEG, the Montgomery reduction factor.
    static MontInt32X8 modNX8()
    {
        return _mm256_set1_epi32(MontInt::MOD_INV_NEG);
    }

    // Unsigned 32x32 -> 64 multiply of the low 32 bits of each 64-bit lane
    // (i.e. of the even-indexed 32-bit elements).
    MontInt32X8 mul64(MontInt32X8 rhs) const
    {
        return _mm256_mul_epu32(data, rhs.data);
    }
    // Logical shifts of each 64-bit lane by n bits.
    MontInt32X8 lShift64(int n) const
    {
        return _mm256_slli_epi64(data, n);
    }
    MontInt32X8 rShift64(int n) const
    {
        return _mm256_srli_epi64(data, n);
    }
    // Keeps 32-bit elements 0,2,4,6 and zeroes the others.
    MontInt32X8 evenElements() const
    {
        return blend<0b10101010>(data, zeroX8());
    }
    // Keeps 32-bit elements 1,3,5,7 (in place) and zeroes the others.
    MontInt32X8 oddElements() const
    {
        return blend<0b01010101>(data, zeroX8());
    }
    // Full 64-bit products: .first holds the products of the even elements,
    // .second the products of the odd elements (odd elements are shifted down
    // into the low halves before multiplying).
    std::pair<MontInt32X8, MontInt32X8> mul64hl(MontInt32X8 rhs) const
    {
        return std::make_pair(mul64(rhs), rShift64(32).mul64(rhs.rShift64(32)));
    }
    // Disabled draft of a Shoup-style multiplication, kept for reference.
    // MontInt32X8 getW1() const
    // {
    //     alignas(32) uint32_t temp[8];
    //     store(temp);
    //     for (auto &&i : temp)
    //     {
    //         i = (uint64_t(i) << 32) / mod;
    //     }
    //     return MontInt32X8(temp);
    // }
    // MontInt32X8 mulModShoup(MontInt32X8 w, MontInt32X8 w1) const
    // {
    //     MontInt32X8 q0, q1, t0, t1;
    //     std::tie(q0, q1) = mul64hl(w1);
    //     std::tie(t0, t1) = mul64hl(w);
    //     q0 = q0.rShift64(32), q1 = q1.rShift64(32);
    //     q0 = q0.mul64(MontInt32X8(mod)), q1 = q1.mul64(MontInt32X8(mod));
    //     t0 = t0.rawSub64(q0), t1 = t1.rawSub64(q1);
    //     return t0 | t1.lShift64(32);
    // }

    // Montgomery reduction of eight 64-bit inputs without the final
    // conditional subtraction. even64 carries the inputs for elements
    // 0,2,4,6 and odd64 those for elements 1,3,5,7, one per 64-bit lane.
    static MontInt32X8 montRedcLazy(MontInt32X8 even64, MontInt32X8 odd64)
    {
        // q = low32(input) * (-MOD^-1); only its low 32 bits matter below.
        MontInt32X8 p0 = even64.mul64(modNX8());
        MontInt32X8 p1 = odd64.mul64(modNX8());
        // input + q * MOD has zero low halves; the result is the high half.
        p0 = p0.mul64(modX8()).rawAdd64(even64).rShift64(32);
        // For the odd inputs the high half is already in the odd 32-bit slot,
        // so no shift is needed before interleaving.
        p1 = p1.mul64(modX8()).rawAdd64(odd64);
        return blend<0b10101010>(p0, p1);
    }
    // Montgomery reduction followed by the conditional subtraction of MOD.
    static MontInt32X8 montRedc(MontInt32X8 even64, MontInt32X8 odd64)
    {
        return montRedcLazy(even64, odd64).largeNorm();
    }
    // Converts every lane from a plain integer to Montgomery form,
    // (x * 2^32) mod MOD, element by element through a temporary array.
    MontInt32X8 toMont() const
    {
        alignas(32) uint32_t temp[8];
        store(temp);
        for (auto &&i : temp)
        {
            i = (uint64_t(i) << 32) % MOD;
        }
        return MontInt32X8(temp);
    }
    // Converts every lane from Montgomery form back to a plain integer by
    // reducing the zero-extended lanes.
    MontInt32X8 toInt() const
    {
        MontInt32X8 e = evenElements();
        MontInt32X8 o = rShift64(32);
        return montRedc(e, o);
    }
    // Per-element select: bit i of N set takes 32-bit element i from b,
    // otherwise from a.
    template <int N>
    static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
    {
        return _mm256_blend_epi32(a.data, b.data, N);
    }
    // Subtracts MOD from every lane that is greater than MOD - 1 (signed compare).
    MontInt32X8 largeNorm() const
    {
        MontInt32X8 sub = (*this > mod1X8()) & modX8();
        return rawSub(sub);
    }
    // Adds MOD to every lane that is negative as a signed 32-bit value.
    MontInt32X8 smallNorm() const
    {
        MontInt32X8 add = (zeroX8() > *this) & modX8();
        return rawAdd(add);
    }
    // Adds 2*MOD to every lane that is negative as a signed 32-bit value.
    MontInt32X8 smallNorm2() const
    {
        MontInt32X8 add = (zeroX8() > *this) & mod2X8();
        return rawAdd(add);
    }
    // Lazy modular addition: computes a + b - 2*MOD and adds 2*MOD back to
    // lanes that went negative, keeping lanes in [0, 2*MOD).
    MontInt32X8 addMont(MontInt32X8 rhs) const
    {
        return rawSub(mod2X8()).rawAdd(rhs).smallNorm2();
    }
    // Lazy modular subtraction: a - b, plus 2*MOD on lanes that went negative.
    MontInt32X8 subMont(MontInt32X8 rhs) const
    {
        return rawSub(rhs).smallNorm2();
    }
    // Montgomery product of corresponding lanes, reduced to [0, MOD).
    MontInt32X8 mulMont(MontInt32X8 rhs) const
    {
        auto mulhl = mul64hl(rhs);
        return montRedc(mulhl.first, mulhl.second);
    }

    // Arithmetic operators forward to the Montgomery routines above.
    MontInt32X8 operator+(MontInt32X8 rhs) const
    {
        return addMont(rhs);
    }
    MontInt32X8 operator-(MontInt32X8 rhs) const
    {
        return subMont(rhs);
    }
    MontInt32X8 operator*(MontInt32X8 rhs) const
    {
        return mulMont(rhs);
    }
    // Plain wrap-around add/sub on 32-bit elements, no modular correction.
    MontInt32X8 rawAdd(MontInt32X8 rhs) const
    {
        return _mm256_add_epi32(data, rhs.data);
    }
    MontInt32X8 rawSub(MontInt32X8 rhs) const
    {
        return _mm256_sub_epi32(data, rhs.data);
    }
    // Plain wrap-around add/sub on 64-bit lanes.
    MontInt32X8 rawAdd64(MontInt32X8 rhs) const
    {
        return _mm256_add_epi64(data, rhs.data);
    }
    MontInt32X8 rawSub64(MontInt32X8 rhs) const
    {
        return _mm256_sub_epi64(data, rhs.data);
    }

    // Comparisons return a lane mask (all ones where true, zero where false).
    // operator> and operator< compare lanes as SIGNED 32-bit integers.
    MontInt32X8 operator>(MontInt32X8 n) const
    {
        return _mm256_cmpgt_epi32(data, n.data);
    }
    MontInt32X8 operator<(MontInt32X8 n) const
    {
        return n > *this;
    }
    MontInt32X8 operator==(MontInt32X8 n) const
    {
        return _mm256_cmpeq_epi32(data, n.data);
    }
    // Bitwise operations on the whole 256-bit register.
    MontInt32X8 operator&(MontInt32X8 n) const
    {
        return _mm256_and_si256(data, n.data);
    }
    MontInt32X8 operator|(MontInt32X8 n) const
    {
        return _mm256_or_si256(data, n.data);
    }
    MontInt32X8 operator^(MontInt32X8 n) const
    {
        return _mm256_xor_si256(data, n.data);
    }
    // Loads 32 bytes from p; no alignment requirement.
    template <typename T>
    void loadu(const T *p)
    {
        data = _mm256_loadu_si256((const __m256i *)p);
    }
    // Loads 32 bytes from p using the aligned-load intrinsic.
    template <typename T>
    void load(const T *p)
    {
        data = _mm256_load_si256((const __m256i *)p);
        // data = *reinterpret_cast<const __m256i *>(p);
    }
    // Stores 32 bytes to p; no alignment requirement.
    template <typename T>
    void storeu(T *p) const
    {
        _mm256_storeu_si256((__m256i *)p, data);
    }
    // Stores 32 bytes to p using the aligned-store intrinsic.
    template <typename T>
    void store(T *p) const
    {
        _mm256_store_si256((__m256i *)p, data);
        // *reinterpret_cast<__m256i *>(p) = data;
    }
    // Debug helpers: print the raw register contents to std::cout as signed
    // 32-bit, unsigned 32-bit or unsigned 64-bit lanes. std::endl is kept so
    // the output is flushed immediately.
    void printI32() const
    {
        alignas(32) int32_t v[8];
        store(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3]
                  << "," << v[4] << "," << v[5]
                  << "," << v[6] << "," << v[7] << "]" << std::endl;
    }
    void printU32() const
    {
        alignas(32) uint32_t v[8];
        store(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3]
                  << "," << v[4] << "," << v[5]
                  << "," << v[6] << "," << v[7] << "]" << std::endl;
    }
    void printU64() const
    {
        alignas(32) uint64_t v[4];
        store(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3] << "]" << std::endl;
    }
};

// void test_x8()
// {
//     constexpr size_t lo = 1e8;
//     constexpr uint32_t mod = 998244353;
//     MontInt32<mod> a(3), b(mod - 2);
//     MontInt32X8<mod> c(a), d(b);
//     MontInt32X8<mod> e(c), f(d);
//     auto t1 = std::chrono::steady_clock::now();
//     for (size_t i = 0; i < lo; i++)
//     {
//         // a = a * b;
//     }
//     auto t2 = std::chrono::steady_clock::now();
//     for (size_t i = 0; i < lo; i++)
//     {
//         a = a - b;
//     }
//     auto t3 = std::chrono::steady_clock::now();
//     // c = c.toMont();
//     // d = d.toMont();
//     for (size_t i = 0; i < lo; i++)
//     {
//         c = c - d;
//     }
//     c = c.toInt();
//     f = e.toInt();
//     auto t4 = std::chrono::steady_clock::now();
//     std::cout << uint32_t(a.toInt()) << "\n";
//     c.printU32();
//     e.printU32();

//     auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
//     auto time2 = std::chrono::duration_cast<std::chrono::duration<double>>(t3 - t2).count();
//     auto time3 = std::chrono::duration_cast<std::chrono::duration<double>>(t4 - t3).count();
//     std::cout << time1 << "s " << time2 << "s " << time3 << "s\n";
// }

void avx2_test()
{
    constexpr size_t len = 1 << 5;
    constexpr uint32_t mod = 998244353;
    using NTTX8 = MontInt32X8<mod>;
    using ModInt = typename NTTX8::MontInt;
    // alignas(64) static ModInt a[len];
    auto a = (ModInt *)_mm_malloc(len * sizeof(ModInt), 32);
    for (size_t i = 0; i < len; i++)
    {
        a[i].data = i;
        // b[i] = i;
    }
    size_t times = 1; // std::max<size_t>(1, (1 << 25) / len);
    auto t1 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < times; i++)
    {
        __m256i x = *(__m256i *)a;
        x = _mm256_add_epi32(x, x);

        // NTTX8 x;
        // x.data = *(__m256i *)a;
        // x = x * x;
        *(__m256i *)a = x;
    }
    auto t2 = std::chrono::steady_clock::now();
    auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
    for (size_t i = 0; i < std::min<size_t>(len, 1024); i++)
    {
        std::cout << i << ":\t" << uint32_t(a[i].data) << "\n";
    }
    std::cout << time1 << "\n";
}

// Program entry point: runs only the AVX2 smoke test; the test_x8 benchmark
// call stays disabled.
int main()
{
    // test_x8();
    avx2_test();
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 33.52 us | 32 KB | Runtime Error | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-07-18 14:59:08 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠