Skip to content

Commit d9944ac

Browse files
committed
Add nim product test
1 parent 7c47b8c commit d9944ac

File tree

3 files changed

+174
-1
lines changed

3 files changed

+174
-1
lines changed

cp-algo/number_theory/nimber.hpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#ifndef CP_ALGO_NUMBER_THEORY_NIMBER_HPP
2+
#define CP_ALGO_NUMBER_THEORY_NIMBER_HPP
3+
#include <array>
4+
#include <bit>
5+
#include <cstdint>
6+
#include <immintrin.h>
7+
// Ensure PCLMULQDQ is available at compile time
8+
#if defined(__PCLMUL__)
9+
static constexpr bool CP_ALGO_HAS_PCLMUL = true;
10+
#else
11+
static constexpr bool CP_ALGO_HAS_PCLMUL = false;
12+
#endif
13+
static_assert(CP_ALGO_HAS_PCLMUL,
14+
"PCLMULQDQ intrinsics not available. Enable it with '-mpclmul' or add '#pragma GCC target(\"pclmul\")' or compile with '-march=native' on supported CPUs.");
15+
16+
namespace cp_algo::math::nimber {
17+
inline constexpr std::array<uint64_t, 64> BASIS_COL = {
18+
0x0000000000000001ull, 0x5211145c804b6109ull, 0x7c8bc2cad259879full, 0x565854b4c60c1e0bull,
19+
0x4068acf7104c20c3ull, 0x662d2bd0f2739155ull, 0x7a90c83701fa8323ull, 0x21cfa750247e8755ull,
20+
0x67d1044e545abf47ull, 0x4d9d3b5a8568f839ull, 0x567a9d7331b6b3c6ull, 0x1ca54bfdd6d1ae59ull,
21+
0x454fa483275db25cull, 0x6766df6fec4e9d44ull, 0x35cb621cec1fe7f9ull, 0x4c606d3e52faf263ull,
22+
0x57640dc825a57954ull, 0x7aca87838b7f6315ull, 0x6d53c884ebf2b0edull, 0x3721d998bb50164bull,
23+
0x7aa7c62fd6cd53abull, 0x47cbb2c51f7c040full, 0x132063b7f5e42489ull, 0x0c1b36c8b2993f8aull,
24+
0x60119ecff680497aull, 0x5175da444cc11791ull, 0x5792ff4554765b09ull, 0x0c9fdb8a01334e82ull,
25+
0x2be0a763a68a4725ull, 0x3c2dc8260ad051f6ull, 0x6c4c9fed8816bb9cull, 0x630062753ffaf766ull,
26+
0x7b37d31b5d519225ull, 0x2364f7f79705691cull, 0x453eb8a83e2fec71ull, 0x7c0121b37e828666ull,
27+
0x59190d3250e66011ull, 0x103207f9dda18caeull, 0x28233dce01c69b76ull, 0x4fa519899227a5e7ull,
28+
0x4567ba46ee7bc6cdull, 0x0a284773d021afd5ull, 0x63894079bbe3a824ull, 0x11013c7fdfaaa5c2ull,
29+
0x1aa984f18574f3b0ull, 0x0cbaba126fd0c4dbull, 0x0b8797719e6dc725ull, 0x4a2845680aefaa72ull,
30+
0x536d2535f6934e15ull, 0x01db7a57effcd689ull, 0x7e1ed0ad01e2a5adull, 0x0aedc9b3cee826f6ull,
31+
0x7ba716eccf9f68e1ull, 0x5d5e23bc0f3dc38full, 0x0b5f2a3b88674d83ull, 0x2de9bafc2f00f8d4ull,
32+
0x3b56712ad419c7e0ull, 0x3ab4be8c30c19253ull, 0x2708522ffaa654b0ull, 0x2b8bca57bf643598ull,
33+
0x588825d1a5fa8e1cull, 0x86adf8bf4d45962full, 0x51b4c15d8719dd73ull, 0xe4a2b3b59783d0aaull
34+
};
35+
36+
inline constexpr std::array<uint64_t, 64> INV_COL = {
37+
0x0000000000000001ull, 0x19c9369f278adc02ull, 0xa181e7d66f5ff795ull, 0x5db84357ce785d09ull,
38+
0xa0bae2f9d2430cc8ull, 0xb7ea5a9705b771c0ull, 0xba4f3cd82801769dull, 0x4886cde01b8241d0ull,
39+
0x0a6f43f2aaf612edull, 0xebd0142f98030a32ull, 0xa81f89cda43f3792ull, 0xe99aec6b66ccb814ull,
40+
0xa69d1ff025fc2f82ull, 0x48a81132d25db068ull, 0x4a900f9dcaa9644full, 0xe5ce4ea88259972aull,
41+
0xf7094c336029f04cull, 0xe191dde287bc9c6bull, 0xaacaff12bff239b8ull, 0x49bc5212be1bc1caull,
42+
0xfe57defb454446cfull, 0xa1dffcf944bdf6a7ull, 0xb9f1bdb5cee941eeull, 0x12e5e889275c22deull,
43+
0x5bcb6b117b77eeedull, 0x03eb1ab59d05ae4bull, 0x02a25d7076ddd386ull, 0x53164a606c612245ull,
44+
0xebb33f5822f66059ull, 0xe9be765f5747b93eull, 0x552a78df373a354full, 0xbcf5ac65f31fb8bfull,
45+
0xe411e728becdc77bull, 0xf35c26d7b57cdca6ull, 0x4499da83de4ca5f7ull, 0x40ab25bdca4ae226ull,
46+
0xee004b6f1dff7218ull, 0x0d122da9821c5b41ull, 0x51fbfcb058120efeull, 0xa148b1fa84905b22ull,
47+
0xbb8ed3e647604d8dull, 0xe2d93fef2472776full, 0x4c17a2541a10e6b5ull, 0x1d879e08903708e7ull,
48+
0x0fbe7d0d1934da90ull, 0x5bf977d9c6f61d30ull, 0x06832fc918260412ull, 0x0fe22e843ebf73e3ull,
49+
0x4d7ef4e4fa28d60dull, 0x402250d979afbed5ull, 0x067902b8c8ca2d4full, 0xf38d113fe1d6bb16ull,
50+
0x414f0248b02b5b7dull, 0xf041922915824ce9ull, 0x11a72fb5e30c93d9ull, 0x12e54f4d63102aeeull,
51+
0xbc46ac14b3141c6cull, 0x1f172b3c16c645bbull, 0x584b492ed4e8fa6cull, 0x00a852e9a32cc133ull,
52+
0xa180861bce00a45eull, 0xa194b6bcb4645fb9ull, 0x4509002ad808a4fbull, 0xc5172a0055602f69ull
53+
};
54+
55+
template <const auto& COLS>
56+
consteval auto make_byte_tables() {
57+
std::array<std::array<uint64_t, 1 << 8>, 8> T{};
58+
for (int pos = 0; pos < 8; pos++) {
59+
for (int col = 0; col < 8; col++) {
60+
for (int mask = 0; mask < (1 << col); mask++) {
61+
T[pos][mask | (1 << col)] = T[pos][mask] ^ COLS[pos * 8 + col];
62+
}
63+
}
64+
}
65+
return T;
66+
}
67+
68+
inline constexpr auto INV_BYTE = make_byte_tables<INV_COL>();
69+
inline constexpr auto BASIS_BYTE = make_byte_tables<BASIS_COL>();
70+
71+
[[gnu::always_inline]]
72+
inline uint64_t nim_to_poly(uint64_t x) {
73+
auto xb = std::bit_cast<std::array<uint8_t, 8>>(x);
74+
return INV_BYTE[0][xb[0]] ^ INV_BYTE[1][xb[1]]
75+
^ INV_BYTE[2][xb[2]] ^ INV_BYTE[3][xb[3]]
76+
^ INV_BYTE[4][xb[4]] ^ INV_BYTE[5][xb[5]]
77+
^ INV_BYTE[6][xb[6]] ^ INV_BYTE[7][xb[7]];
78+
}
79+
80+
[[gnu::always_inline]]
81+
inline uint64_t poly_to_nim(uint64_t c) {
82+
auto cb = std::bit_cast<std::array<uint8_t, 8>>(c);
83+
return BASIS_BYTE[0][cb[0]] ^ BASIS_BYTE[1][cb[1]]
84+
^ BASIS_BYTE[2][cb[2]] ^ BASIS_BYTE[3][cb[3]]
85+
^ BASIS_BYTE[4][cb[4]] ^ BASIS_BYTE[5][cb[5]]
86+
^ BASIS_BYTE[6][cb[6]] ^ BASIS_BYTE[7][cb[7]];
87+
}
88+
89+
// Carryless multiply over GF(2) using PCLMULQDQ
90+
[[gnu::always_inline]]
91+
inline __m128i clmul(int64_t a, int64_t b) {
92+
return _mm_clmulepi64_si128(__m128i{a, 0}, __m128i{b, 0}, 0);
93+
}
94+
95+
// Reduce modulo x^64 + x^4 + x^3 + x + 1
96+
[[gnu::always_inline]]
97+
inline uint64_t reduce_mod(__m128i v) {
98+
v[0] ^= v[1] ^ (v[1] << 1) ^ (v[1] << 3) ^ (v[1] << 4);
99+
static constexpr auto RED_OVER = [] {
100+
std::array<uint64_t, 16> red{};
101+
for (int q = 0; q < 16; ++q) {
102+
uint64_t o = q ^ (q >> 1) ^ (q >> 3);
103+
red[q] = o ^ (o << 1) ^ (o << 3) ^ (o << 4);
104+
}
105+
return red;
106+
}();
107+
return v[0] ^ RED_OVER[v[1] >> 60];
108+
}
109+
110+
// Public nimber product via isomorphism (no recursion, no Gauss at runtime)
111+
[[gnu::always_inline]]
112+
inline uint64_t nim_mul(uint64_t a, uint64_t b) {
113+
uint64_t pa = nim_to_poly(a);
114+
uint64_t pb = nim_to_poly(b);
115+
auto prod = clmul(pa, pb);
116+
uint64_t red = reduce_mod(prod);
117+
return poly_to_nim(red);
118+
}
119+
}
120+
121+
#endif // CP_ALGO_NUMBER_THEORY_NIMBER_HPP

cp-algo/util/simd.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,23 @@ CP_ALGO_SIMD_PRAGMA_PUSH
1919
namespace cp_algo {
2020
template<typename T, size_t len>
2121
using simd [[gnu::vector_size(len * sizeof(T))]] = T;
22+
using u64x8 = simd<uint64_t, 8>;
23+
using u32x16 = simd<uint32_t, 16>;
2224
using i64x4 = simd<int64_t, 4>;
2325
using u64x4 = simd<uint64_t, 4>;
2426
using u32x8 = simd<uint32_t, 8>;
27+
using u16x16 = simd<uint16_t, 16>;
2528
using i32x4 = simd<int32_t, 4>;
2629
using u32x4 = simd<uint32_t, 4>;
30+
using u16x8 = simd<uint16_t, 8>;
31+
using u16x4 = simd<uint16_t, 4>;
2732
using i16x4 = simd<int16_t, 4>;
2833
using u8x32 = simd<uint8_t, 32>;
34+
using u8x8 = simd<uint8_t, 8>;
35+
using u8x4 = simd<uint8_t, 4>;
2936
using dx4 = simd<double, 4>;
3037

31-
dx4 abs(dx4 a) {
38+
inline dx4 abs(dx4 a) {
3239
return dx4{
3340
std::abs(a[0]),
3441
std::abs(a[1]),
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// @brief Nim Product
2+
#define PROBLEM "https://judge.yosupo.jp/problem/nim_product_64"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt,pclmul")
5+
#define CP_ALGO_CHECKPOINT
6+
#include <iostream>
7+
#include "blazingio/blazingio.min.hpp"
8+
#include <immintrin.h>
9+
#include "cp-algo/util/big_alloc.hpp"
10+
#include "cp-algo/util/checkpoint.hpp"
11+
#include "cp-algo/number_theory/nimber.hpp"
12+
#include <bits/stdc++.h>
13+
14+
using namespace std;
15+
using namespace cp_algo::math::nimber;
16+
17+
int main() {
18+
ios::sync_with_stdio(0);
19+
cin.tie(0);
20+
int t;
21+
cin >> t;
22+
cp_algo::big_vector<uint64_t> A(t), B(t);
23+
for (int i = 0; i < t; i++) {
24+
cin >> A[i] >> B[i];
25+
}
26+
cp_algo::checkpoint("read");
27+
for (int i = 0; i < t; i++) {
28+
A[i] = nim_to_poly(A[i]);
29+
B[i] = nim_to_poly(B[i]);
30+
}
31+
cp_algo::checkpoint("to_poly");
32+
for (int i = 0; i < t; i++) {
33+
A[i] = reduce_mod(clmul(A[i], B[i]));
34+
}
35+
cp_algo::checkpoint("clmul+reduce");
36+
for (int i = 0; i < t; i++) {
37+
A[i] = poly_to_nim(A[i]);
38+
}
39+
cp_algo::checkpoint("to_nim");
40+
for (int i = 0; i < t; i++) {
41+
cout << A[i] << "\n";
42+
}
43+
cp_algo::checkpoint("print");
44+
cp_algo::checkpoint<1>();
45+
}

0 commit comments

Comments
 (0)