@@ -90,6 +90,7 @@ std::string PrintScheme(ComputeScheme cs)
9090 {ENUMSTR (CS_KERNEL_2D_SINGLE)},
9191
9292 {ENUMSTR (CS_3D_STRAIGHT)},
93+ {ENUMSTR (CS_3D_RTRTRT)},
9394 {ENUMSTR (CS_3D_RTRT)},
9495 {ENUMSTR (CS_3D_RC)},
9596 {ENUMSTR (CS_KERNEL_3D_STOCKHAM_BLOCK_CC)},
@@ -825,6 +826,24 @@ void TreeNode::RecursiveBuildTree()
825826 else
826827 {
827828 scheme = CS_3D_RTRT;
829+
830+ // NB:
831+ // Try to build the 1st child but not really add it in. Switch to
832+ // CS_3D_RTRTRT if the 1st child is CS_2D_RTRT.(Any better idea?)
833+ // And enable this only for cases that lengths are not aligned to
834+ // 64 because perf issue.
835+ // See more comments in assign_params_CS_3D_RTRTRT().
836+ TreeNode* child0 = TreeNode::CreateNode (this );
837+ child0->length = length;
838+ child0->dimension = 2 ;
839+ child0->RecursiveBuildTree ();
840+ if ((child0->scheme == CS_2D_RTRT) && (length[0 ] % 64 ) && (length[1 ] % 64 )
841+ && (length[2 ] % 64 ))
842+ {
843+ scheme = CS_3D_RTRTRT;
844+ }
845+
846+ DeleteNode (child0);
828847 }
829848
830849 switch (scheme)
@@ -834,6 +853,11 @@ void TreeNode::RecursiveBuildTree()
834853 build_CS_3D_RTRT ();
835854 }
836855 break ;
856+ case CS_3D_RTRTRT:
857+ {
858+ build_CS_3D_RTRTRT ();
859+ }
860+ break ;
837861 case CS_3D_RC:
838862 {
839863 // 2d fft
@@ -1927,6 +1951,32 @@ void TreeNode::build_CS_3D_RTRT()
19271951 childNodes.push_back (trans2Plan);
19281952}
19291953
1954+ void TreeNode::build_CS_3D_RTRTRT ()
1955+ {
1956+ scheme = CS_3D_RTRTRT;
1957+ std::vector<size_t > cur_length = length;
1958+
1959+ for (int i = 0 ; i < 6 ; i += 2 )
1960+ {
1961+ // row ffts
1962+ auto row_plan = TreeNode::CreateNode (this );
1963+ row_plan->length = cur_length;
1964+ row_plan->dimension = 1 ;
1965+ row_plan->RecursiveBuildTree ();
1966+ childNodes.push_back (row_plan);
1967+
1968+ // transpose XY_Z
1969+ auto trans_plan = TreeNode::CreateNode (this );
1970+ trans_plan->length = cur_length;
1971+ trans_plan->scheme = CS_KERNEL_TRANSPOSE_XY_Z;
1972+ trans_plan->dimension = 2 ;
1973+ childNodes.push_back (trans_plan);
1974+
1975+ std::swap (cur_length[2 ], cur_length[1 ]);
1976+ std::swap (cur_length[1 ], cur_length[0 ]);
1977+ }
1978+ }
1979+
19301980struct TreeNode ::TraverseState
19311981{
19321982 TraverseState (const ExecPlan& execPlan)
@@ -2086,6 +2136,9 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(TraverseState& state,
20862136 case CS_3D_RC:
20872137 assign_buffers_CS_RC (state, flipIn, flipOut, obOutBuf);
20882138 break ;
2139+ case CS_3D_RTRTRT:
2140+ assign_buffers_CS_3D_RTRTRT (state, flipIn, flipOut, obOutBuf);
2141+ break ;
20892142 default :
20902143 if (parent == nullptr )
20912144 {
@@ -2767,6 +2820,67 @@ void TreeNode::assign_buffers_CS_RC(TraverseState& state,
27672820 }
27682821}
27692822
2823+ void TreeNode::assign_buffers_CS_3D_RTRTRT (TraverseState& state,
2824+ OperatingBuffer& flipIn,
2825+ OperatingBuffer& flipOut,
2826+ OperatingBuffer& obOutBuf)
2827+ {
2828+ assert (scheme == CS_3D_RTRTRT);
2829+ assert (childNodes.size () == 6 );
2830+
2831+ obOut = obOutBuf;
2832+
2833+ // TODO: adjust buffer assignment for padding
2834+
2835+ flipIn = obIn;
2836+ flipOut = OB_TEMP;
2837+
2838+ // R
2839+ childNodes[0 ]->SetInputBuffer (state);
2840+ childNodes[0 ]->obOut = obOutBuf;
2841+ childNodes[0 ]->inArrayType = inArrayType;
2842+ childNodes[0 ]->outArrayType = outArrayType;
2843+ childNodes[0 ]->TraverseTreeAssignBuffersLogicA (state, flipIn, flipOut, obOutBuf);
2844+
2845+ flipIn = OB_TEMP;
2846+ flipOut = obOut;
2847+ obOutBuf = obOut;
2848+
2849+ // T
2850+ childNodes[1 ]->SetInputBuffer (state);
2851+ childNodes[1 ]->obOut = OB_TEMP;
2852+ childNodes[1 ]->inArrayType = childNodes[0 ]->outArrayType ;
2853+ childNodes[1 ]->outArrayType = rocfft_array_type_complex_interleaved;
2854+
2855+ // R
2856+ childNodes[2 ]->inArrayType = rocfft_array_type_complex_interleaved;
2857+ childNodes[2 ]->outArrayType = rocfft_array_type_complex_interleaved;
2858+ childNodes[2 ]->SetInputBuffer (state);
2859+ childNodes[2 ]->obOut = OB_TEMP;
2860+ flipIn = OB_TEMP;
2861+ flipOut = obOutBuf;
2862+ childNodes[2 ]->TraverseTreeAssignBuffersLogicA (state, flipIn, flipOut, obOutBuf);
2863+
2864+ // T
2865+ childNodes[3 ]->SetInputBuffer (state);
2866+ childNodes[3 ]->obOut = obOutBuf;
2867+ childNodes[3 ]->inArrayType = rocfft_array_type_complex_interleaved;
2868+ childNodes[3 ]->outArrayType = outArrayType;
2869+
2870+ // R
2871+ childNodes[4 ]->SetInputBuffer (state);
2872+ childNodes[4 ]->obOut = flipIn;
2873+ childNodes[4 ]->TraverseTreeAssignBuffersLogicA (state, flipIn, flipOut, obOutBuf);
2874+ childNodes[4 ]->inArrayType = childNodes[3 ]->outArrayType ;
2875+ childNodes[4 ]->outArrayType = rocfft_array_type_complex_interleaved;
2876+
2877+ // T
2878+ childNodes[5 ]->SetInputBuffer (state);
2879+ childNodes[5 ]->inArrayType = rocfft_array_type_complex_interleaved;
2880+ childNodes[5 ]->outArrayType = outArrayType;
2881+ childNodes[5 ]->obOut = obOutBuf;
2882+ }
2883+
27702884// /////////////////////////////////////////////////////////////////////////////
27712885// / Set placement variable and in/out array types, if not already set.
27722886void TreeNode::TraverseTreeAssignPlacementsLogicA (const rocfft_array_type rootIn,
@@ -2933,6 +3047,9 @@ void TreeNode::TraverseTreeAssignParamsLogicA()
29333047 case CS_3D_RTRT:
29343048 assign_params_CS_3D_RTRT ();
29353049 break ;
3050+ case CS_3D_RTRTRT:
3051+ assign_params_CS_3D_RTRTRT ();
3052+ break ;
29363053 case CS_3D_RC:
29373054 case CS_3D_STRAIGHT:
29383055 assign_params_CS_3D_RC_STRAIGHT ();
@@ -3959,6 +4076,55 @@ void TreeNode::assign_params_CS_3D_RTRT()
39594076 trans2Plan->oDist = oDist;
39604077}
39614078
4079+ void TreeNode::assign_params_CS_3D_RTRTRT ()
4080+ {
4081+ assert (scheme == CS_3D_RTRTRT);
4082+ assert (childNodes.size () == 6 );
4083+ // TODO:
4084+ // Need regular transpose padding to improve performance for cases that
4085+ // lengths are aligned to 64, i.e. 512x512x512. However, there are few
4086+ // potential issues need to be fixed first:
4087+ // (1) The performance of current transpose_kernel2_scheme for case
4088+ // 512x512x512 need to be improved.
4089+ // (2) For in-place transform, the user buffer is not big enough for
4090+ // output with padding.
4091+ // (3) We should be able to pad the output of the 1st and 2nd transpose,
4092+ // and naturally the input of the 3rd transpose. However, the current
4093+ // transpose_kernel2_scheme and transpose_tile_device don't work for
4094+ // the 2nd padding transpose.
4095+ // Or the perf of new diagonal transpose is good enough that we don't
4096+ // need padding any more.
4097+
4098+ for (int i = 0 ; i < 6 ; i += 2 )
4099+ {
4100+ auto row_plan = childNodes[i];
4101+ if (i == 0 )
4102+ {
4103+ row_plan->inStride = inStride;
4104+ row_plan->iDist = iDist;
4105+ row_plan->outStride = outStride;
4106+ row_plan->oDist = oDist;
4107+ }
4108+ else
4109+ {
4110+ row_plan->inStride = childNodes[i - 1 ]->outStride ;
4111+ row_plan->iDist = childNodes[i - 1 ]->oDist ;
4112+ row_plan->outStride = row_plan->inStride ;
4113+ row_plan->oDist = row_plan->iDist ;
4114+ }
4115+ row_plan->TraverseTreeAssignParamsLogicA ();
4116+
4117+ auto trans_plan = childNodes[i + 1 ];
4118+ trans_plan->inStride = row_plan->outStride ;
4119+ trans_plan->iDist = row_plan->oDist ;
4120+
4121+ trans_plan->outStride .push_back (1 );
4122+ trans_plan->outStride .push_back (trans_plan->outStride [0 ] * trans_plan->length [2 ]);
4123+ trans_plan->outStride .push_back (trans_plan->outStride [1 ] * trans_plan->length [0 ]);
4124+ trans_plan->oDist = trans_plan->iDist ;
4125+ }
4126+ }
4127+
39624128void TreeNode::assign_params_CS_3D_RC_STRAIGHT ()
39634129{
39644130 auto xyPlan = childNodes[0 ];
0 commit comments