Skip to content

Commit a470ba6

Browse files
authored
transpose: implement large1D twiddle multiply for length < 256
* transpose: implement large1D twiddle multiply for length < 256 * rocfft-test: remove 1D prime sizes that are covered by radX * rocfft-test: allow array type + placement to be overridden by test suites * rocfft-test: test all 1D C2C sizes < 8k
1 parent 7e4c5f0 commit a470ba6

File tree

4 files changed

+140
-3
lines changed

4 files changed

+140
-3
lines changed

clients/tests/accuracy_test.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ inline auto param_generator_base(const std::vector<rocfft_transform_type>& typ
403403
const std::vector<std::vector<size_t>>& v_lengths,
404404
const std::vector<rocfft_precision>& precision_range,
405405
const std::vector<size_t>& batch_range,
406+
decltype(generate_types) types_generator,
406407
const stride_generator& istride,
407408
const stride_generator& ostride,
408409
const std::vector<std::vector<size_t>>& ioffset_range,
@@ -424,7 +425,7 @@ inline auto param_generator_base(const std::vector<rocfft_transform_type>& typ
424425
{
425426
for(const auto batch : batch_range)
426427
{
427-
for(const auto& types : generate_types(transform_type, place_range))
428+
for(const auto& types : types_generator(transform_type, place_range))
428429
{
429430
for(const auto& istride_dist : istride.generate(lengths, batch))
430431
{
@@ -485,6 +486,7 @@ inline auto param_generator(const std::vector<std::vector<size_t>>& v_length
485486
v_lengths,
486487
precision_range,
487488
batch_range,
489+
generate_types,
488490
istride,
489491
ostride,
490492
ioffset_range,
@@ -507,6 +509,7 @@ inline auto param_generator_complex(const std::vector<std::vector<size_t>>&
507509
v_lengths,
508510
precision_range,
509511
batch_range,
512+
generate_types,
510513
istride,
511514
ostride,
512515
ioffset_range,
@@ -529,6 +532,7 @@ inline auto param_generator_real(const std::vector<std::vector<size_t>>& v_l
529532
v_lengths,
530533
precision_range,
531534
batch_range,
535+
generate_types,
532536
istride,
533537
ostride,
534538
ioffset_range,

clients/tests/accuracy_test_1D.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,32 @@ const static std::vector<size_t> mix_range
5353
900, 1250, 1500, 1875, 2160, 2187, 2250, 2500, 3000, 4000, 12000, 24000, 72000};
5454

5555
const static std::vector<size_t> prime_range
56-
= {7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97};
56+
= {17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97};
57+
58+
static std::vector<size_t> small_1D_sizes()
59+
{
60+
static const size_t SMALL_1D_MAX = 8192;
61+
62+
// generate a list of sizes from 2 and up, skipping any sizes that are already covered
63+
std::vector<size_t> covered_sizes;
64+
std::copy(pow2_range.begin(), pow2_range.end(), std::back_inserter(covered_sizes));
65+
std::copy(pow3_range.begin(), pow3_range.end(), std::back_inserter(covered_sizes));
66+
std::copy(pow5_range.begin(), pow5_range.end(), std::back_inserter(covered_sizes));
67+
std::copy(radX_range.begin(), radX_range.end(), std::back_inserter(covered_sizes));
68+
std::copy(mix_range.begin(), mix_range.end(), std::back_inserter(covered_sizes));
69+
std::copy(prime_range.begin(), prime_range.end(), std::back_inserter(covered_sizes));
70+
std::sort(covered_sizes.begin(), covered_sizes.end());
71+
72+
std::vector<size_t> output;
73+
for(size_t i = 2; i < SMALL_1D_MAX; ++i)
74+
{
75+
if(!std::binary_search(covered_sizes.begin(), covered_sizes.end(), i))
76+
{
77+
output.push_back(i);
78+
}
79+
}
80+
return output;
81+
}
5782

5883
const static std::vector<std::vector<size_t>> stride_range = {{1}};
5984

@@ -225,6 +250,30 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_offset_mix_1D,
225250
place_range)),
226251
accuracy_test::TestName);
227252

253+
// small 1D sizes just need to make sure our factorization isn't
254+
// completely broken, so we just check simple C2C outplace interleaved
255+
INSTANTIATE_TEST_SUITE_P(small_1D,
256+
accuracy_test,
257+
::testing::ValuesIn(param_generator_base(
258+
{rocfft_transform_type_complex_forward},
259+
{small_1D_sizes()},
260+
{rocfft_precision_single},
261+
{1},
262+
[](rocfft_transform_type t,
263+
const std::vector<rocfft_result_placement>& place_range) {
264+
return std::vector<type_place_io_t>{
265+
std::make_tuple(t,
266+
place_range[0],
267+
rocfft_array_type_complex_interleaved,
268+
rocfft_array_type_complex_interleaved)};
269+
},
270+
stride_range,
271+
stride_range,
272+
ioffset_range_zero,
273+
ooffset_range_zero,
274+
{rocfft_placement_notinplace})),
275+
accuracy_test::TestName);
276+
228277
// NB:
229278
// We have known non-unit strides issues for 1D:
230279
// - C2C middle size(for instance, single precision, 8192)

library/src/device/kernels/transpose.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,18 @@
3030
#define TRANSPOSE_TWIDDLE_MUL(tmp) \
3131
if(WITH_TWL) \
3232
{ \
33-
if(TWL == 2) \
33+
if(TWL == 1) \
34+
{ \
35+
if(DIR == -1) \
36+
{ \
37+
TWIDDLE_STEP_MUL_FWD(TWLstep1, twiddles_large, (gx + tx1) * (gy + ty1 + i), tmp); \
38+
} \
39+
else \
40+
{ \
41+
TWIDDLE_STEP_MUL_INV(TWLstep1, twiddles_large, (gx + tx1) * (gy + ty1 + i), tmp); \
42+
} \
43+
} \
44+
else if(TWL == 2) \
3445
{ \
3546
if(DIR == -1) \
3647
{ \

library/src/device/transpose.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,77 @@ rocfft_status rocfft_transpose_outofplace_template(size_t m,
148148
&HIP_KERNEL_NAME(
149149
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 0, 1, false, false, false>));
150150

151+
// twl=1:
152+
tmap.emplace(
153+
std::make_tuple(1, -1, true, true, true),
154+
&HIP_KERNEL_NAME(
155+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, true, true, true>));
156+
tmap.emplace(
157+
std::make_tuple(1, -1, false, true, true),
158+
&HIP_KERNEL_NAME(
159+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, false, true, true>));
160+
tmap.emplace(
161+
std::make_tuple(1, -1, true, false, true),
162+
&HIP_KERNEL_NAME(
163+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, true, false, true>));
164+
tmap.emplace(
165+
std::make_tuple(1, -1, false, false, true),
166+
&HIP_KERNEL_NAME(
167+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, false, false, true>));
168+
169+
tmap.emplace(
170+
std::make_tuple(1, 1, true, true, true),
171+
&HIP_KERNEL_NAME(
172+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, true, true, true>));
173+
tmap.emplace(
174+
std::make_tuple(1, 1, false, true, true),
175+
&HIP_KERNEL_NAME(
176+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, false, true, true>));
177+
178+
tmap.emplace(
179+
std::make_tuple(1, 1, true, false, true),
180+
&HIP_KERNEL_NAME(
181+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, true, false, true>));
182+
tmap.emplace(
183+
std::make_tuple(1, 1, false, false, true),
184+
&HIP_KERNEL_NAME(
185+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, false, false, true>));
186+
187+
tmap.emplace(
188+
std::make_tuple(1, -1, true, true, false),
189+
&HIP_KERNEL_NAME(
190+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, true, true, false>));
191+
tmap.emplace(
192+
std::make_tuple(1, -1, false, true, false),
193+
&HIP_KERNEL_NAME(
194+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, false, true, false>));
195+
tmap.emplace(
196+
std::make_tuple(1, -1, true, false, false),
197+
&HIP_KERNEL_NAME(
198+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, true, false, false>));
199+
tmap.emplace(
200+
std::make_tuple(1, -1, false, false, false),
201+
&HIP_KERNEL_NAME(
202+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, -1, false, false, false>));
203+
204+
tmap.emplace(
205+
std::make_tuple(1, 1, true, true, false),
206+
&HIP_KERNEL_NAME(
207+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, true, true, false>));
208+
tmap.emplace(
209+
std::make_tuple(1, 1, false, true, false),
210+
&HIP_KERNEL_NAME(
211+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, false, true, false>));
212+
213+
tmap.emplace(
214+
std::make_tuple(1, 1, true, false, false),
215+
&HIP_KERNEL_NAME(
216+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, true, false, false>));
217+
tmap.emplace(
218+
std::make_tuple(1, 1, false, false, false),
219+
&HIP_KERNEL_NAME(
220+
transpose_kernel2<T, TA, TB, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 1, 1, false, false, false>));
221+
151222
// twl=2:
152223
tmap.emplace(
153224
std::make_tuple(2, -1, true, true, true),
@@ -578,6 +649,8 @@ void rocfft_internal_transpose_var2(const void* data_p, void* back_p)
578649
twl = 3;
579650
else if(data->node->large1D > (size_t)256)
580651
twl = 2;
652+
else if(data->node->large1D > 0)
653+
twl = 1;
581654
else
582655
twl = 0;
583656

0 commit comments

Comments
 (0)