Skip to content

Commit 0f7e9ba

Browse files
authored
Enabled 3D_RTRTRT for cases that lengths are not aligned to 64
1 parent 8503e7f commit 0f7e9ba

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

library/src/include/tree_node.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ enum ComputeScheme
8383
CS_KERNEL_2D_SINGLE,
8484

8585
CS_3D_STRAIGHT,
86+
CS_3D_RTRTRT,
8687
CS_3D_RTRT,
8788
CS_3D_RC,
8889
CS_KERNEL_3D_STOCKHAM_BLOCK_CC,
@@ -259,7 +260,10 @@ class TreeNode
259260
void build_CS_2D_RC();
260261

261262
// 3D node builder:
263+
// 3D 4 node builder, R: 2D FFTs, T: transpose XY_Z, R: row FFTs, T: transpose Z_XY
262264
void build_CS_3D_RTRT();
265+
// 3D 6 node builder, R: row FFTs, T: transpose XY_Z, R: row FFTs, T: transpose XY_Z, R: row FFTs, T: transpose XY_Z
266+
void build_CS_3D_RTRTRT();
263267

264268
// State maintained while traversing the tree.
265269
//
@@ -329,6 +333,10 @@ class TreeNode
329333
OperatingBuffer& flipIn,
330334
OperatingBuffer& flipOut,
331335
OperatingBuffer& obOutBuf);
336+
void assign_buffers_CS_3D_RTRTRT(TraverseState& state,
337+
OperatingBuffer& flipIn,
338+
OperatingBuffer& flipOut,
339+
OperatingBuffer& obOutBuf);
332340

333341
// Set placement variable and in/out array types
334342
void TraverseTreeAssignPlacementsLogicA(rocfft_array_type rootIn, rocfft_array_type rootOut);
@@ -347,6 +355,7 @@ class TreeNode
347355
void assign_params_CS_2D_RTRT();
348356
void assign_params_CS_2D_RC_STRAIGHT();
349357
void assign_params_CS_3D_RTRT();
358+
void assign_params_CS_3D_RTRTRT();
350359
void assign_params_CS_3D_RC_STRAIGHT();
351360

352361
// Determine work memory requirements:

library/src/plan.cpp

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ std::string PrintScheme(ComputeScheme cs)
9090
{ENUMSTR(CS_KERNEL_2D_SINGLE)},
9191

9292
{ENUMSTR(CS_3D_STRAIGHT)},
93+
{ENUMSTR(CS_3D_RTRTRT)},
9394
{ENUMSTR(CS_3D_RTRT)},
9495
{ENUMSTR(CS_3D_RC)},
9596
{ENUMSTR(CS_KERNEL_3D_STOCKHAM_BLOCK_CC)},
@@ -825,6 +826,24 @@ void TreeNode::RecursiveBuildTree()
825826
else
826827
{
827828
scheme = CS_3D_RTRT;
829+
830+
// NB:
831+
// Try to build the 1st child but not really add it in. Switch to
832+
// CS_3D_RTRTRT if the 1st child is CS_2D_RTRT.(Any better idea?)
833+
// And enable this only for cases that lengths are not aligned to
834+
// 64 because perf issue.
835+
// See more comments in assign_params_CS_3D_RTRTRT().
836+
TreeNode* child0 = TreeNode::CreateNode(this);
837+
child0->length = length;
838+
child0->dimension = 2;
839+
child0->RecursiveBuildTree();
840+
if((child0->scheme == CS_2D_RTRT) && (length[0] % 64) && (length[1] % 64)
841+
&& (length[2] % 64))
842+
{
843+
scheme = CS_3D_RTRTRT;
844+
}
845+
846+
DeleteNode(child0);
828847
}
829848

830849
switch(scheme)
@@ -834,6 +853,11 @@ void TreeNode::RecursiveBuildTree()
834853
build_CS_3D_RTRT();
835854
}
836855
break;
856+
case CS_3D_RTRTRT:
857+
{
858+
build_CS_3D_RTRTRT();
859+
}
860+
break;
837861
case CS_3D_RC:
838862
{
839863
// 2d fft
@@ -1927,6 +1951,32 @@ void TreeNode::build_CS_3D_RTRT()
19271951
childNodes.push_back(trans2Plan);
19281952
}
19291953

1954+
void TreeNode::build_CS_3D_RTRTRT()
1955+
{
1956+
scheme = CS_3D_RTRTRT;
1957+
std::vector<size_t> cur_length = length;
1958+
1959+
for(int i = 0; i < 6; i += 2)
1960+
{
1961+
// row ffts
1962+
auto row_plan = TreeNode::CreateNode(this);
1963+
row_plan->length = cur_length;
1964+
row_plan->dimension = 1;
1965+
row_plan->RecursiveBuildTree();
1966+
childNodes.push_back(row_plan);
1967+
1968+
// transpose XY_Z
1969+
auto trans_plan = TreeNode::CreateNode(this);
1970+
trans_plan->length = cur_length;
1971+
trans_plan->scheme = CS_KERNEL_TRANSPOSE_XY_Z;
1972+
trans_plan->dimension = 2;
1973+
childNodes.push_back(trans_plan);
1974+
1975+
std::swap(cur_length[2], cur_length[1]);
1976+
std::swap(cur_length[1], cur_length[0]);
1977+
}
1978+
}
1979+
19301980
struct TreeNode::TraverseState
19311981
{
19321982
TraverseState(const ExecPlan& execPlan)
@@ -2086,6 +2136,9 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(TraverseState& state,
20862136
case CS_3D_RC:
20872137
assign_buffers_CS_RC(state, flipIn, flipOut, obOutBuf);
20882138
break;
2139+
case CS_3D_RTRTRT:
2140+
assign_buffers_CS_3D_RTRTRT(state, flipIn, flipOut, obOutBuf);
2141+
break;
20892142
default:
20902143
if(parent == nullptr)
20912144
{
@@ -2767,6 +2820,67 @@ void TreeNode::assign_buffers_CS_RC(TraverseState& state,
27672820
}
27682821
}
27692822

2823+
void TreeNode::assign_buffers_CS_3D_RTRTRT(TraverseState& state,
2824+
OperatingBuffer& flipIn,
2825+
OperatingBuffer& flipOut,
2826+
OperatingBuffer& obOutBuf)
2827+
{
2828+
assert(scheme == CS_3D_RTRTRT);
2829+
assert(childNodes.size() == 6);
2830+
2831+
obOut = obOutBuf;
2832+
2833+
// TODO: adjust buffer assignment for padding
2834+
2835+
flipIn = obIn;
2836+
flipOut = OB_TEMP;
2837+
2838+
// R
2839+
childNodes[0]->SetInputBuffer(state);
2840+
childNodes[0]->obOut = obOutBuf;
2841+
childNodes[0]->inArrayType = inArrayType;
2842+
childNodes[0]->outArrayType = outArrayType;
2843+
childNodes[0]->TraverseTreeAssignBuffersLogicA(state, flipIn, flipOut, obOutBuf);
2844+
2845+
flipIn = OB_TEMP;
2846+
flipOut = obOut;
2847+
obOutBuf = obOut;
2848+
2849+
// T
2850+
childNodes[1]->SetInputBuffer(state);
2851+
childNodes[1]->obOut = OB_TEMP;
2852+
childNodes[1]->inArrayType = childNodes[0]->outArrayType;
2853+
childNodes[1]->outArrayType = rocfft_array_type_complex_interleaved;
2854+
2855+
// R
2856+
childNodes[2]->inArrayType = rocfft_array_type_complex_interleaved;
2857+
childNodes[2]->outArrayType = rocfft_array_type_complex_interleaved;
2858+
childNodes[2]->SetInputBuffer(state);
2859+
childNodes[2]->obOut = OB_TEMP;
2860+
flipIn = OB_TEMP;
2861+
flipOut = obOutBuf;
2862+
childNodes[2]->TraverseTreeAssignBuffersLogicA(state, flipIn, flipOut, obOutBuf);
2863+
2864+
// T
2865+
childNodes[3]->SetInputBuffer(state);
2866+
childNodes[3]->obOut = obOutBuf;
2867+
childNodes[3]->inArrayType = rocfft_array_type_complex_interleaved;
2868+
childNodes[3]->outArrayType = outArrayType;
2869+
2870+
// R
2871+
childNodes[4]->SetInputBuffer(state);
2872+
childNodes[4]->obOut = flipIn;
2873+
childNodes[4]->TraverseTreeAssignBuffersLogicA(state, flipIn, flipOut, obOutBuf);
2874+
childNodes[4]->inArrayType = childNodes[3]->outArrayType;
2875+
childNodes[4]->outArrayType = rocfft_array_type_complex_interleaved;
2876+
2877+
// T
2878+
childNodes[5]->SetInputBuffer(state);
2879+
childNodes[5]->inArrayType = rocfft_array_type_complex_interleaved;
2880+
childNodes[5]->outArrayType = outArrayType;
2881+
childNodes[5]->obOut = obOutBuf;
2882+
}
2883+
27702884
///////////////////////////////////////////////////////////////////////////////
27712885
/// Set placement variable and in/out array types, if not already set.
27722886
void TreeNode::TraverseTreeAssignPlacementsLogicA(const rocfft_array_type rootIn,
@@ -2933,6 +3047,9 @@ void TreeNode::TraverseTreeAssignParamsLogicA()
29333047
case CS_3D_RTRT:
29343048
assign_params_CS_3D_RTRT();
29353049
break;
3050+
case CS_3D_RTRTRT:
3051+
assign_params_CS_3D_RTRTRT();
3052+
break;
29363053
case CS_3D_RC:
29373054
case CS_3D_STRAIGHT:
29383055
assign_params_CS_3D_RC_STRAIGHT();
@@ -3959,6 +4076,55 @@ void TreeNode::assign_params_CS_3D_RTRT()
39594076
trans2Plan->oDist = oDist;
39604077
}
39614078

4079+
void TreeNode::assign_params_CS_3D_RTRTRT()
4080+
{
4081+
assert(scheme == CS_3D_RTRTRT);
4082+
assert(childNodes.size() == 6);
4083+
// TODO:
4084+
// Need regular transpose padding to improve performance for cases that
4085+
// lengths are aligned to 64, i.e. 512x512x512. However, there are few
4086+
// potential issues need to be fixed first:
4087+
// (1) The performance of current transpose_kernel2_scheme for case
4088+
// 512x512x512 need to be improved.
4089+
// (2) For in-place transform, the user buffer is not big enough for
4090+
// output with padding.
4091+
// (3) We should be able to pad the output of the 1st and 2nd transpose,
4092+
// and naturally the input of the 3rd transpose. However, the current
4093+
// transpose_kernel2_scheme and transpose_tile_device don't work for
4094+
// the 2nd padding transpose.
4095+
// Or the perf of new diagonal transpose is good enough that we don't
4096+
// need padding any more.
4097+
4098+
for(int i = 0; i < 6; i += 2)
4099+
{
4100+
auto row_plan = childNodes[i];
4101+
if(i == 0)
4102+
{
4103+
row_plan->inStride = inStride;
4104+
row_plan->iDist = iDist;
4105+
row_plan->outStride = outStride;
4106+
row_plan->oDist = oDist;
4107+
}
4108+
else
4109+
{
4110+
row_plan->inStride = childNodes[i - 1]->outStride;
4111+
row_plan->iDist = childNodes[i - 1]->oDist;
4112+
row_plan->outStride = row_plan->inStride;
4113+
row_plan->oDist = row_plan->iDist;
4114+
}
4115+
row_plan->TraverseTreeAssignParamsLogicA();
4116+
4117+
auto trans_plan = childNodes[i + 1];
4118+
trans_plan->inStride = row_plan->outStride;
4119+
trans_plan->iDist = row_plan->oDist;
4120+
4121+
trans_plan->outStride.push_back(1);
4122+
trans_plan->outStride.push_back(trans_plan->outStride[0] * trans_plan->length[2]);
4123+
trans_plan->outStride.push_back(trans_plan->outStride[1] * trans_plan->length[0]);
4124+
trans_plan->oDist = trans_plan->iDist;
4125+
}
4126+
}
4127+
39624128
void TreeNode::assign_params_CS_3D_RC_STRAIGHT()
39634129
{
39644130
auto xyPlan = childNodes[0];

0 commit comments

Comments
 (0)