Skip to content

Commit d3c798c

Browse files
Revert "Revert "Allow build_CS_3D_BLOCK_RC to also have composite sub-schemes. ""
This reverts commit 4341e77.
1 parent 4341e77 commit d3c798c

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

clients/tests/accuracy_test_adhoc.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ std::vector<std::vector<size_t>> adhoc_sizes = {
3636
// SBRC 192 with special param
3737
{192, 192, 192},
3838
{192, 84, 84},
39+
40+
// Failure with build_CS_3D_BLOCK_RC
41+
{680, 128, 128},
3942
};
4043

4144
const static std::vector<std::vector<size_t>> stride_range = {{1}};

library/src/assignment_policy.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,14 @@ void AssignmentPolicy::PadPlan(ExecPlan& execPlan)
11031103
// SBCR plans combine higher dimensions in ways that confuse padding
11041104
if(u.node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CR)
11051105
return;
1106+
// transpose kernels don't handle arbitrary strides,
1107+
// and with 4 or more lengths either choice of
1108+
// padding dim will trigger incorrect behaviour
1109+
if((u.node.scheme == CS_KERNEL_TRANSPOSE
1110+
|| u.node.scheme == CS_KERNEL_TRANSPOSE_XY_Z
1111+
|| u.node.scheme == CS_KERNEL_TRANSPOSE_Z_XY)
1112+
&& u.node.length.size() > 3)
1113+
return;
11061114
}
11071115

11081116
// Ensure that if we're forced to pad along one dimension

library/src/tree_node_3D.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -473,13 +473,6 @@ void BLOCKRC3DNode::AssignParams_internal()
473473
node->oDist = node->outStride[2] * node->length[0];
474474
break;
475475
}
476-
case CS_KERNEL_STOCKHAM:
477-
{
478-
node->outStride = node->inStride;
479-
node->oDist = node->iDist;
480-
node->AssignParams();
481-
break;
482-
}
483476
case CS_KERNEL_TRANSPOSE_XY_Z:
484477
{
485478
node->outStride.push_back(1);
@@ -497,8 +490,12 @@ void BLOCKRC3DNode::AssignParams_internal()
497490
break;
498491
}
499492
default:
500-
// build_CS_3D_BLOCK_RC should not have created any other node types
501-
throw std::runtime_error("Scheme Assertion Failed, unexpected node scheme.");
493+
{
494+
node->outStride = node->inStride;
495+
node->oDist = node->iDist;
496+
node->AssignParams();
497+
break;
498+
}
502499
}
503500
prev_outStride = node->outStride;
504501
prev_oDist = node->oDist;

0 commit comments

Comments
 (0)