#include <cstdint>
#include <iostream>
#include <chrono>
#include <immintrin.h>
#include <tuple>
#pragma GCC target("fma")
#pragma GCC target("avx2")
// Generic fast exponentiation (square-and-multiply).
// Returns m raised to the n-th power; when n <= 0 the loop never runs and 1 is returned.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n)
{
    T acc = 1;
    for (; n > 0; n >>= 1)
    {
        // The lowest remaining exponent bit decides whether the current square contributes.
        if ((n & 1) != 0)
        {
            acc *= m;
        }
        m *= m;
    }
    return acc;
}
// Floor of log2(n): the index of the highest set bit of n.
// Returns -1 when n <= 0 (no power of two is <= n).
template <typename T>
constexpr int int_log2(const T &n)
{
    constexpr int bits = sizeof(T) * 8;
    if (n < T(1))
    {
        return -1;
    }
    // Binary search for the largest l with 2^l <= n; invariant: 2^l <= n < 2^r.
    int l = -1, r = bits;
    while ((l + 1) != r)
    {
        const int mid = (l + r) / 2;
        // For positive n, (n >> mid) == 0 is equivalent to 2^mid > n. Shifting n right
        // (instead of forming T(1) << mid) avoids overflowing into the sign bit of a
        // signed T, which previously made the result bits-1 for n >= 2^(bits-2).
        if ((n >> mid) == 0)
        {
            r = mid;
        }
        else
        {
            l = mid;
        }
    }
    return l;
}
// Rounds n up to a power of two: n itself when it already is one, otherwise the next larger one.
// Works by decrementing, copying the highest set bit into every lower position, then adding 1.
// For an unsigned n == 0 the decrement wraps around and the result is 0.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = sizeof(T) * 8;
    n -= 1;
    int shift = 1;
    while (shift < width)
    {
        // Doubling the shift each round smears the top bit across all lower bits.
        n |= (n >> shift);
        shift <<= 1;
    }
    return n + 1;
}
// Returns the number whose lowest `bits` binary digits are all 1, i.e. 2^bits - 1.
// bits <= 0 yields 0 and bits >= the width of T yields a value with every bit set,
// so the shift below is never performed with an out-of-range count.
template <typename T>
constexpr T all_one(int bits)
{
    constexpr int width = sizeof(T) * 8;
    if (bits <= 0)
    {
        return T(0);
    }
    if (bits >= width)
    {
        return static_cast<T>(~T(0));
    }
    // 2^(bits-1) - 1 + 2^(bits-1) == 2^bits - 1, computed without ever forming 2^bits.
    const T top = T(1) << (bits - 1);
    return top - 1 + top;
}
// Returns n^-1 mod 2^pow, computed by Newton iteration.
// An inverse modulo a power of two exists only for odd n; for even n (or pow <= 0)
// the iteration could never reach n * x == 1, so 0 is returned instead of looping forever.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if (pow <= 0 || (n & 1) == 0)
    {
        return 0;
    }
    const uint64_t mask = all_one<uint64_t>(pow);
    // xn is the current approximation of the inverse, t = xn * n mod 2^pow.
    uint64_t xn = 1, t = n & mask;
    while (t != 1)
    {
        // Newton step: x <- x * (2 - n * x); arithmetic wraps modulo 2^64 and is masked afterwards.
        xn = (xn * (2 - t));
        t = (xn * n) & mask;
    }
    return xn & mask;
}
// Integer modulo MOD kept in Montgomery form with R = 2^32.
// `data` stores value * R mod MOD (see toMont); toInt()/redc convert back.
template <uint32_t MOD>
struct MontInt32
{
    static constexpr uint64_t R = uint64_t(1) << 32;
    static constexpr uint32_t R_MASK = R - 1;
    // MOD^-1 mod R and its negation mod R, used by redc.
    static constexpr uint32_t MOD_INV = inv_mod2pow(MOD, 32);
    static constexpr uint32_t MOD_INV_NEG = R - MOD_INV;
    static constexpr uint32_t MOD2 = MOD * 2;
    // NOTE(review): the condition accepts any MOD whose highest set bit index is <= 30
    // (i.e. MOD < 2^31, 31 bits), while the message speaks of 30 bits - confirm the intent.
    static_assert(int_log2(MOD) <= 30, "MOD can't be larger than 30 bits");
    static_assert(uint32_t(MOD_INV * MOD) == 1, "Montgomery32 modulus is not correct");

    // Montgomery representation of the value.
    uint32_t data;

    // Zero.
    constexpr MontInt32() : data(0) {}
    // Converts an ordinary integer into Montgomery form.
    constexpr MontInt32(uint32_t n) : data(toMont(n)) {}

    // Ordinary integer -> Montgomery form: n * R mod MOD.
    static constexpr uint32_t toMont(uint32_t n)
    {
        const uint64_t scaled = uint64_t(n) << 32;
        return uint32_t(scaled % MOD);
    }

    // Montgomery reduction: returns input * R^-1 mod MOD, conditionally subtracting MOD once.
    static constexpr uint32_t redc(uint64_t input)
    {
        // 32-bit multiplication wraps modulo R, which is exactly the "mod R" the algorithm needs.
        const uint32_t q = uint32_t(input) * MOD_INV_NEG;
        // input + q * MOD has its low 32 bits equal to zero, so the shift divides exactly by R.
        const uint64_t reduced = (uint64_t(q) * MOD + input) >> 32;
        if (reduced >= MOD)
        {
            return uint32_t(reduced - MOD);
        }
        return uint32_t(reduced);
    }

    // Montgomery form -> ordinary integer.
    static constexpr uint32_t toInt(uint32_t n)
    {
        return redc(uint64_t(n));
    }

    // (m + n), with MOD subtracted once when the sum reaches MOD.
    static constexpr uint32_t addMont(uint32_t m, uint32_t n)
    {
        const uint32_t sum = m + n;
        return sum >= MOD ? sum - MOD : sum;
    }

    // (m - n), with MOD added back when the unsigned subtraction wrapped around.
    static constexpr uint32_t subMont(uint32_t m, uint32_t n)
    {
        const uint32_t diff = m - n;
        // A wrapped difference is larger than the minuend.
        return diff > m ? diff + MOD : diff;
    }

    // Product of two Montgomery-form values, again in Montgomery form.
    static constexpr uint32_t mulMont(uint32_t m, uint32_t n)
    {
        const uint64_t wide = uint64_t(m) * n;
        return redc(wide);
    }

    // Replaces the stored value with the Montgomery form of n.
    constexpr void fromInt(uint32_t n)
    {
        data = toMont(n);
    }

    // Stored value as an ordinary integer.
    constexpr uint32_t toInt() const
    {
        return toInt(data);
    }

    // Implicit conversion to the ordinary integer value.
    constexpr operator uint32_t() const
    {
        return toInt();
    }

    constexpr MontInt32 operator+(MontInt32 rhs) const
    {
        MontInt32 out(*this);
        out += rhs;
        return out;
    }

    constexpr MontInt32 operator-(MontInt32 rhs) const
    {
        MontInt32 out(*this);
        out -= rhs;
        return out;
    }

    constexpr MontInt32 operator*(MontInt32 rhs) const
    {
        MontInt32 out(*this);
        out *= rhs;
        return out;
    }

    constexpr MontInt32 &operator+=(const MontInt32 &rhs)
    {
        const uint32_t updated = addMont(data, rhs.data);
        data = updated;
        return *this;
    }

    constexpr MontInt32 &operator-=(const MontInt32 &rhs)
    {
        const uint32_t updated = subMont(data, rhs.data);
        data = updated;
        return *this;
    }

    constexpr MontInt32 &operator*=(const MontInt32 &rhs)
    {
        const uint32_t updated = mulMont(data, rhs.data);
        data = updated;
        return *this;
    }

    // The modulus this type works with.
    static constexpr uint32_t mod()
    {
        return MOD;
    }
};
// Eight 32-bit lanes packed into one __m256i, with AVX2 helpers for Montgomery
// arithmetic modulo MOD (R = 2^32). Lane i corresponds to element i in memory order.
template <uint32_t MOD>
struct MontInt32X8
{
    using MontInt = MontInt32<MOD>;
    // largeNorm/smallNorm/smallNorm2 use the signed comparison _mm256_cmpgt_epi32 on lanes
    // that may hold values up to 2*MOD, so 2*MOD has to stay below 2^31.
    static_assert(MOD < (uint32_t(1) << 30),
                  "MontInt32X8 requires MOD < 2^30: lanes holding up to 2*MOD are compared as signed 32-bit integers");

    __m256i data;

    // All lanes zero.
    MontInt32X8() : data(_mm256_setzero_si256()) {}
    // Broadcasts the raw (Montgomery-form) word of x to every lane.
    MontInt32X8(MontInt x) : data(_mm256_set1_epi32(x.data)) {}
    // x0 goes to lane 0, ..., x7 to lane 7 (_mm256_set_epi32 takes the highest lane first).
    MontInt32X8(int32_t x0, int32_t x1, int32_t x2, int32_t x3, int32_t x4, int32_t x5, int32_t x6, int32_t x7)
        : data(_mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0))
    {
    }
    MontInt32X8(__m256i rhs) : data(rhs) {}
    // Loads 32 bytes from p with an unaligned load.
    template <typename T>
    MontInt32X8(const T *p)
    {
        loadu(p);
    }

    static constexpr uint32_t mod()
    {
        return MOD;
    }
    // Broadcast constants: 0, MOD, MOD - 1, 2 * MOD and -MOD^-1 mod 2^32.
    static MontInt32X8 zeroX8()
    {
        return _mm256_setzero_si256();
    }
    static MontInt32X8 modX8()
    {
        return _mm256_set1_epi32(MOD);
    }
    static MontInt32X8 mod1X8()
    {
        return _mm256_set1_epi32(MOD - 1);
    }
    static MontInt32X8 mod2X8()
    {
        return _mm256_set1_epi32(MOD * 2);
    }
    static MontInt32X8 modNX8()
    {
        return _mm256_set1_epi32(MontInt::MOD_INV_NEG);
    }

    // Multiplies the low unsigned 32 bits of each 64-bit lane, giving four 64-bit products
    // (i.e. the products of the even 32-bit elements).
    MontInt32X8 mul64(MontInt32X8 rhs) const
    {
        return _mm256_mul_epu32(data, rhs.data);
    }
    // Logical shifts applied to each 64-bit lane.
    MontInt32X8 lShift64(int n) const
    {
        return _mm256_slli_epi64(data, n);
    }
    MontInt32X8 rShift64(int n) const
    {
        return _mm256_srli_epi64(data, n);
    }
    // Keeps elements 0, 2, 4, 6 and zeroes the odd ones.
    MontInt32X8 evenElements() const
    {
        return blend<0b10101010>(data, zeroX8());
    }
    // Keeps elements 1, 3, 5, 7 and zeroes the even ones.
    MontInt32X8 oddElements() const
    {
        return blend<0b01010101>(data, zeroX8());
    }
    // Full 64-bit products: first = products of the even elements,
    // second = products of the odd elements (both operands shifted down by 32 first).
    std::pair<MontInt32X8, MontInt32X8> mul64hl(MontInt32X8 rhs) const
    {
        return std::make_pair(mul64(rhs), rShift64(32).mul64(rhs.rShift64(32)));
    }
    // NOTE(review): disabled draft of a Shoup-style multiplication, kept verbatim for reference.
    // MontInt32X8 getW1() const
    // {
    // alignas(32) uint32_t temp[8];
    // store(temp);
    // for (auto &&i : temp)
    // {
    // i = (uint64_t(i) << 32) / mod;
    // }
    // return MontInt32X8(temp);
    // }
    // MontInt32X8 mulModShoup(MontInt32X8 w, MontInt32X8 w1) const
    // {
    // MontInt32X8 q0, q1, t0, t1;
    // std::tie(q0, q1) = mul64hl(w1);
    // std::tie(t0, t1) = mul64hl(w);
    // q0 = q0.rShift64(32), q1 = q1.rShift64(32);
    // q0 = q0.mul64(MontInt32X8(mod)), q1 = q1.mul64(MontInt32X8(mod));
    // t0 = t0.rawSub64(q0), t1 = t1.rawSub64(q1);
    // return t0 | t1.lShift64(32);
    // }

    // Montgomery reduction of eight 64-bit values without the final conditional subtraction.
    // even64 / odd64 hold the 64-bit inputs belonging to the even / odd result elements.
    static MontInt32X8 montRedcLazy(MontInt32X8 even64, MontInt32X8 odd64)
    {
        // mul64 only reads the low 32 bits of each 64-bit lane, so the "mod 2^32" of the
        // first product is implicit when it is fed into the second multiplication.
        MontInt32X8 p0 = even64.mul64(modNX8());
        MontInt32X8 p1 = odd64.mul64(modNX8());
        // Even results are moved down into the low half of their lane; odd results are
        // already in the high half, which is where odd elements live.
        p0 = p0.mul64(modX8()).rawAdd64(even64).rShift64(32);
        p1 = p1.mul64(modX8()).rawAdd64(odd64);
        return blend<0b10101010>(p0, p1);
    }
    // Montgomery reduction followed by largeNorm (subtract MOD from lanes above MOD - 1).
    static MontInt32X8 montRedc(MontInt32X8 even64, MontInt32X8 odd64)
    {
        return montRedcLazy(even64, odd64).largeNorm();
    }
    // Converts every lane from an ordinary integer to Montgomery form (scalar loop).
    MontInt32X8 toMont() const
    {
        alignas(32) uint32_t temp[8];
        store(temp);
        for (auto &&i : temp)
        {
            i = (uint64_t(i) << 32) % MOD;
        }
        return MontInt32X8(temp);
    }
    // Converts every lane from Montgomery form back to an ordinary integer.
    MontInt32X8 toInt() const
    {
        MontInt32X8 e = evenElements();
        MontInt32X8 o = rShift64(32);
        return montRedc(e, o);
    }
    // Per-element select: bit i of N set -> element i of b, otherwise element i of a.
    template <int N>
    static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
    {
        return _mm256_blend_epi32(a.data, b.data, N);
    }
    // Subtracts MOD from every lane that is (signed) greater than MOD - 1.
    MontInt32X8 largeNorm() const
    {
        MontInt32X8 sub = (*this > mod1X8()) & modX8();
        return rawSub(sub);
    }
    // Adds MOD to every (signed) negative lane.
    MontInt32X8 smallNorm() const
    {
        MontInt32X8 add = (zeroX8() > *this) & modX8();
        return rawAdd(add);
    }
    // Adds 2 * MOD to every (signed) negative lane.
    MontInt32X8 smallNorm2() const
    {
        MontInt32X8 add = (zeroX8() > *this) & mod2X8();
        return rawAdd(add);
    }
    // this - 2*MOD + rhs, with 2*MOD added back to lanes that went negative.
    MontInt32X8 addMont(MontInt32X8 rhs) const
    {
        return rawSub(mod2X8()).rawAdd(rhs).smallNorm2();
    }
    // this - rhs, with 2*MOD added to lanes that went negative.
    MontInt32X8 subMont(MontInt32X8 rhs) const
    {
        return rawSub(rhs).smallNorm2();
    }
    // Lane-wise Montgomery product.
    MontInt32X8 mulMont(MontInt32X8 rhs) const
    {
        auto mulhl = mul64hl(rhs);
        return montRedc(mulhl.first, mulhl.second);
    }
    MontInt32X8 operator+(MontInt32X8 rhs) const
    {
        return addMont(rhs);
    }
    MontInt32X8 operator-(MontInt32X8 rhs) const
    {
        return subMont(rhs);
    }
    MontInt32X8 operator*(MontInt32X8 rhs) const
    {
        return mulMont(rhs);
    }
    // Plain wrapping add/sub on 32-bit and 64-bit lanes (no modular correction).
    MontInt32X8 rawAdd(MontInt32X8 rhs) const
    {
        return _mm256_add_epi32(data, rhs.data);
    }
    MontInt32X8 rawSub(MontInt32X8 rhs) const
    {
        return _mm256_sub_epi32(data, rhs.data);
    }
    MontInt32X8 rawAdd64(MontInt32X8 rhs) const
    {
        return _mm256_add_epi64(data, rhs.data);
    }
    MontInt32X8 rawSub64(MontInt32X8 rhs) const
    {
        return _mm256_sub_epi64(data, rhs.data);
    }
    // Comparisons return a per-lane mask (all ones when true, zero otherwise);
    // operator> and operator< compare lanes as signed 32-bit integers.
    MontInt32X8 operator>(MontInt32X8 n) const
    {
        return _mm256_cmpgt_epi32(data, n.data);
    }
    MontInt32X8 operator<(MontInt32X8 n) const
    {
        return n > *this;
    }
    MontInt32X8 operator==(MontInt32X8 n) const
    {
        return _mm256_cmpeq_epi32(data, n.data);
    }
    // Bitwise operations on the whole 256-bit register.
    MontInt32X8 operator&(MontInt32X8 n) const
    {
        return _mm256_and_si256(data, n.data);
    }
    MontInt32X8 operator|(MontInt32X8 n) const
    {
        return _mm256_or_si256(data, n.data);
    }
    MontInt32X8 operator^(MontInt32X8 n) const
    {
        return _mm256_xor_si256(data, n.data);
    }
    // Unaligned 32-byte load from p.
    template <typename T>
    void loadu(const T *p)
    {
        data = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(p));
    }
    // Aligned 32-byte load; p must be 32-byte aligned.
    template <typename T>
    void load(const T *p)
    {
        data = _mm256_load_si256(reinterpret_cast<const __m256i *>(p));
        // data = *reinterpret_cast<const __m256i *>(p);
    }
    // Unaligned 32-byte store to p.
    template <typename T>
    void storeu(T *p) const
    {
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(p), data);
    }
    // Aligned 32-byte store; p must be 32-byte aligned.
    template <typename T>
    void store(T *p) const
    {
        _mm256_store_si256(reinterpret_cast<__m256i *>(p), data);
        // *reinterpret_cast<__m256i *>(p) = data;
    }
    // Debug printers: dump the register as 8 signed, 8 unsigned or 4 unsigned 64-bit values.
    void printI32() const
    {
        alignas(32) int32_t v[8];
        store(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3]
                  << "," << v[4] << "," << v[5]
                  << "," << v[6] << "," << v[7] << "]" << std::endl;
    }
    void printU32() const
    {
        alignas(32) uint32_t v[8];
        store(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3]
                  << "," << v[4] << "," << v[5]
                  << "," << v[6] << "," << v[7] << "]" << std::endl;
    }
    void printU64() const
    {
        alignas(32) uint64_t v[4];
        store(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3] << "]" << std::endl;
    }
};
// void test_x8()
// {
// constexpr size_t lo = 1e8;
// constexpr uint32_t mod = 998244353;
// MontInt32<mod> a(3), b(mod - 2);
// MontInt32X8<mod> c(a), d(b);
// MontInt32X8<mod> e(c), f(d);
// auto t1 = std::chrono::steady_clock::now();
// for (size_t i = 0; i < lo; i++)
// {
// // a = a * b;
// }
// auto t2 = std::chrono::steady_clock::now();
// for (size_t i = 0; i < lo; i++)
// {
// a = a - b;
// }
// auto t3 = std::chrono::steady_clock::now();
// // c = c.toMont();
// // d = d.toMont();
// for (size_t i = 0; i < lo; i++)
// {
// c = c - d;
// }
// c = c.toInt();
// f = e.toInt();
// auto t4 = std::chrono::steady_clock::now();
// std::cout << uint32_t(a.toInt()) << "\n";
// c.printU32();
// e.printU32();
// auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
// auto time2 = std::chrono::duration_cast<std::chrono::duration<double>>(t3 - t2).count();
// auto time3 = std::chrono::duration_cast<std::chrono::duration<double>>(t4 - t3).count();
// std::cout << time1 << "s " << time2 << "s " << time3 << "s\n";
// }
void avx2_test()
{
constexpr size_t len = 1 << 5;
constexpr uint32_t mod = 998244353;
using NTTX8 = MontInt32X8<mod>;
using ModInt = typename NTTX8::MontInt;
// alignas(64) static ModInt a[len];
auto a = (ModInt *)_mm_malloc(len * sizeof(ModInt), 32);
for (size_t i = 0; i < len; i++)
{
a[i].data = i;
// b[i] = i;
}
size_t times = 1; // std::max<size_t>(1, (1 << 25) / len);
auto t1 = std::chrono::steady_clock::now();
for (size_t i = 0; i < times; i++)
{
__m256i x = _mm256_set1_epi32(2);
x = _mm256_add_epi32(x, x);
// NTTX8 x;
// x.data = *(__m256i *)a;
// x = x * x;
*(__m256i *)a = x;
}
auto t2 = std::chrono::steady_clock::now();
auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
for (size_t i = 0; i < std::min<size_t>(len, 1024); i++)
{
std::cout << i << ":\t" << uint32_t(a[i].data) << "\n";
}
std::cout << time1 << "\n";
}
// Entry point: runs the AVX2 smoke test, but only on a CPU that can execute AVX2.
int main()
{
    // test_x8();
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
    // The code is built with target("avx2") regardless of the host, so running it on a CPU
    // without AVX2 would die with an illegal-instruction fault. Report that case instead.
    if (!__builtin_cpu_supports("avx2"))
    {
        std::cerr << "AVX2 is not supported by this CPU; avx2_test skipped\n";
        return 1;
    }
#endif
    avx2_test();
    return 0;
}
// Online-judge result pasted after the program (not part of the source code):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 34.22 us | 32 KB | Runtime Error | Score: 0 | Show more |