Skip to content

Commit 0ecff19

Browse files
committed
[DPAS] Pick warpsPerCTA based on fast changing axis of A matrix
1 parent 37b841e commit 0ecff19

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ namespace {
3232
SmallVector<unsigned>
3333
getWarpsPerTile(tt::DotOp dotOp,
3434
ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
35-
const ArrayRef<int64_t> shape, unsigned numWarps) {
35+
const ArrayRef<int64_t> shape, unsigned numWarps,
36+
const SmallVector<unsigned> &order) {
37+
3638
auto filter = [&dotOp](Operation *op) {
3739
return op->getParentRegion() == dotOp->getParentRegion();
3840
};
@@ -64,7 +66,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6466
uint32_t colRowRatio =
6567
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
6668

67-
int rowDim = rank - 2, colDim = rank - 1;
69+
int rowDim = order[rank - 2], colDim = order[rank - 1];
6870
do {
6971
if (ret[rowDim] * ret[colDim] >= numWarps)
7072
break;
@@ -78,7 +80,6 @@ getWarpsPerTile(tt::DotOp dotOp,
7880
ret[colDim] *= 2;
7981
}
8082
} while (true);
81-
8283
return ret;
8384
}
8485

@@ -117,8 +118,22 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
117118
Type elemType = oldAType.getElementType();
118119
unsigned opsPerChan =
119120
ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
121+
122+
SmallVector<unsigned> order = {0, 1};
123+
Operation *aOp = a.getDefiningOp();
124+
if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
125+
auto valueToConvert = aOp->getOperand(0);
126+
aOp = valueToConvert.getDefiningOp();
127+
}
128+
if (aOp && isa<tt::LoadOp>(aOp)) {
129+
assert(aOp->getNumResults() == 1);
130+
Attribute layout =
131+
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
132+
order = triton::gpu::getOrder(layout);
133+
}
134+
120135
SmallVector<unsigned> warpsPerTile =
121-
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
136+
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
122137
size_t rank = retShape.size();
123138
SmallVector<unsigned> repCluster(rank, 1);
124139

0 commit comments

Comments
 (0)