Skip to content

Commit 4155c08

Browse files
tgale96 authored and copybara-github committed
Support configurable LMUL in VectorXoshiro.
PiperOrigin-RevId: 620864942
1 parent 4ce48ca commit 4155c08

File tree

2 files changed

+24
-21
lines changed

2 files changed

+24
-21
lines changed

hwy/contrib/random/random-inl.h

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -170,18 +170,21 @@ class Xoshiro {
170170

171171
} // namespace internal
172172

173+
template <int kPow2 = 1>
173174
class VectorXoshiro {
174175
private:
175-
using VU64 = Vec<ScalableTag<std::uint64_t>>;
176+
using TagU64 = ScalableTag<std::uint64_t, kPow2>;
177+
using TagF64 = ScalableTag<double, kPow2>;
178+
179+
using VU64 = Vec<TagU64>;
176180
using StateType = AlignedNDArray<std::uint64_t, 2>;
177181
#if HWY_HAVE_FLOAT64
178-
using VF64 = Vec<ScalableTag<double>>;
182+
using VF64 = Vec<TagF64>;
179183
#endif
180184
public:
181185
explicit VectorXoshiro(const std::uint64_t seed,
182186
const std::uint64_t threadNumber = 0)
183-
: state_{{internal::Xoshiro::StateSize(),
184-
Lanes(ScalableTag<std::uint64_t>{})}},
187+
: state_{{internal::Xoshiro::StateSize(), Lanes(TagU64{})}},
185188
streams{state_.shape().back()} {
186189
internal::Xoshiro xoshiro{seed};
187190

@@ -202,7 +205,7 @@ class VectorXoshiro {
202205

203206
AlignedVector<std::uint64_t> operator()(const std::size_t n) {
204207
AlignedVector<std::uint64_t> result(n);
205-
const ScalableTag<std::uint64_t> tag{};
208+
const TagU64 tag{};
206209
auto s0 = Load(tag, state_[{0}].data());
207210
auto s1 = Load(tag, state_[{1}].data());
208211
auto s2 = Load(tag, state_[{2}].data());
@@ -221,7 +224,7 @@ class VectorXoshiro {
221224
template <std::uint64_t N>
222225
std::array<std::uint64_t, N> operator()() noexcept {
223226
alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result;
224-
const ScalableTag<std::uint64_t> tag{};
227+
const TagU64 tag{};
225228
auto s0 = Load(tag, state_[{0}].data());
226229
auto s1 = Load(tag, state_[{1}].data());
227230
auto s2 = Load(tag, state_[{2}].data());
@@ -246,7 +249,7 @@ class VectorXoshiro {
246249
#if HWY_HAVE_FLOAT64
247250

248251
HWY_INLINE VF64 Uniform() noexcept {
249-
const ScalableTag<double> real_tag{};
252+
const TagF64 real_tag{};
250253
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
251254
const auto bits = ShiftRight<11>(Next());
252255
const auto real = ConvertTo(real_tag, bits);
@@ -255,8 +258,8 @@ class VectorXoshiro {
255258

256259
AlignedVector<double> Uniform(const std::size_t n) {
257260
AlignedVector<double> result(n);
258-
const ScalableTag<std::uint64_t> tag{};
259-
const ScalableTag<double> real_tag{};
261+
const TagU64 tag{};
262+
const TagF64 real_tag{};
260263
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
261264

262265
auto s0 = Load(tag, state_[{0}].data());
@@ -282,8 +285,8 @@ class VectorXoshiro {
282285
template <std::uint64_t N>
283286
std::array<double, N> Uniform() noexcept {
284287
alignas(HWY_ALIGNMENT) std::array<double, N> result;
285-
const ScalableTag<std::uint64_t> tag{};
286-
const ScalableTag<double> real_tag{};
288+
const TagU64 tag{};
289+
const TagF64 real_tag{};
287290
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
288291

289292
auto s0 = Load(tag, state_[{0}].data());
@@ -326,7 +329,7 @@ class VectorXoshiro {
326329
}
327330

328331
HWY_INLINE VU64 Next() noexcept {
329-
const ScalableTag<std::uint64_t> tag{};
332+
const TagU64 tag{};
330333
auto s0 = Load(tag, state_[{0}].data());
331334
auto s1 = Load(tag, state_[{1}].data());
332335
auto s2 = Load(tag, state_[{2}].data());
@@ -368,7 +371,7 @@ class CachedXoshiro {
368371
}
369372

370373
private:
371-
VectorXoshiro generator_;
374+
VectorXoshiro</*kPow2=*/1> generator_;
372375
alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
373376
std::size_t index_;
374377

hwy/contrib/random/random_test.cc

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ std::uint64_t GetSeed() { return static_cast<uint64_t>(std::time(nullptr)); }
3030
void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
3131
const size_t size) {
3232
const ScalableTag<std::uint64_t> d;
33-
VectorXoshiro generator{seed};
33+
VectorXoshiro<> generator{seed};
3434
for (size_t i = 0; i < size; i += Lanes(d)) {
3535
Store(generator(), d, result + i);
3636
}
@@ -40,7 +40,7 @@ void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
4040
void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,
4141
const size_t size) {
4242
const ScalableTag<double> d;
43-
VectorXoshiro generator{seed};
43+
VectorXoshiro<> generator{seed};
4444
for (size_t i = 0; i < size; i += Lanes(d)) {
4545
Store(generator.Uniform(), d, result + i);
4646
}
@@ -49,7 +49,7 @@ void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,
4949

5050
void TestSeeding() {
5151
const std::uint64_t seed = GetSeed();
52-
VectorXoshiro generator{seed};
52+
VectorXoshiro<> generator{seed};
5353
internal::Xoshiro reference{seed};
5454
const auto& state = generator.GetState();
5555
const ScalableTag<std::uint64_t> d;
@@ -72,7 +72,7 @@ void TestSeeding() {
7272
void TestMultiThreadSeeding() {
7373
const std::uint64_t seed = GetSeed();
7474
const std::uint64_t threadId = std::random_device()() % 1000;
75-
VectorXoshiro generator{seed, threadId};
75+
VectorXoshiro<> generator{seed, threadId};
7676
internal::Xoshiro reference{seed};
7777

7878
for (std::size_t i = 0UL; i < threadId; ++i) {
@@ -146,7 +146,7 @@ void TestUniformDist() {
146146

147147
void TestNextNRandomUint64() {
148148
const std::uint64_t seed = GetSeed();
149-
VectorXoshiro generator{seed};
149+
VectorXoshiro<> generator{seed};
150150
const auto result_array = generator.operator()(tests);
151151
std::vector<internal::Xoshiro> reference;
152152
reference.emplace_back(seed);
@@ -174,7 +174,7 @@ void TestNextNRandomUint64() {
174174

175175
void TestNextFixedNRandomUint64() {
176176
const std::uint64_t seed = GetSeed();
177-
VectorXoshiro generator{seed};
177+
VectorXoshiro<> generator{seed};
178178
const auto result_array = generator.operator()<tests>();
179179
std::vector<internal::Xoshiro> reference;
180180
reference.emplace_back(seed);
@@ -203,7 +203,7 @@ void TestNextFixedNRandomUint64() {
203203
#if HWY_HAVE_FLOAT64
204204
void TestNextNUniformDist() {
205205
const std::uint64_t seed = GetSeed();
206-
VectorXoshiro generator{seed};
206+
VectorXoshiro<> generator{seed};
207207
const auto result_array = generator.Uniform(tests);
208208
internal::Xoshiro reference{seed};
209209
const ScalableTag<double> d;
@@ -222,7 +222,7 @@ void TestNextNUniformDist() {
222222

223223
void TestNextFixedNUniformDist() {
224224
const std::uint64_t seed = GetSeed();
225-
VectorXoshiro generator{seed};
225+
VectorXoshiro<> generator{seed};
226226
const auto result_array = generator.Uniform<tests>();
227227
internal::Xoshiro reference{seed};
228228
const ScalableTag<double> d;

0 commit comments

Comments
 (0)