@@ -32,7 +32,9 @@ namespace {
3232SmallVector<unsigned >
3333getWarpsPerTile (tt::DotOp dotOp,
3434 ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
35- const ArrayRef<int64_t > shape, unsigned numWarps) {
35+ const ArrayRef<int64_t > shape, unsigned numWarps,
36+ const SmallVector<unsigned > &order) {
37+
3638 auto filter = [&dotOp](Operation *op) {
3739 return op->getParentRegion () == dotOp->getParentRegion ();
3840 };
@@ -64,7 +66,7 @@ getWarpsPerTile(tt::DotOp dotOp,
6466 uint32_t colRowRatio =
6567 ceil<uint32_t >(dpasCap.executionSize , dpasCap.repeatCount );
6668
67- int rowDim = rank - 2 , colDim = rank - 1 ;
69+ int rowDim = order[ rank - 2 ] , colDim = order[ rank - 1 ] ;
6870 do {
6971 if (ret[rowDim] * ret[colDim] >= numWarps)
7072 break ;
@@ -78,7 +80,6 @@ getWarpsPerTile(tt::DotOp dotOp,
7880 ret[colDim] *= 2 ;
7981 }
8082 } while (true );
81-
8283 return ret;
8384}
8485
@@ -117,8 +118,22 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
117118 Type elemType = oldAType.getElementType ();
118119 unsigned opsPerChan =
119120 ttg::intel::DpasEncodingAttr::getOpsPerChannel (elemType);
121+
122+ SmallVector<unsigned > order = {0 , 1 };
123+ Operation *aOp = a.getDefiningOp ();
124+ if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
125+ auto valueToConvert = aOp->getOperand (0 );
126+ aOp = valueToConvert.getDefiningOp ();
127+ }
128+ if (aOp && isa<tt::LoadOp>(aOp)) {
129+ assert (aOp->getNumResults () == 1 );
130+ Attribute layout =
131+ cast<RankedTensorType>(aOp->getResult (0 ).getType ()).getEncoding ();
132+ order = triton::gpu::getOrder (layout);
133+ }
134+
120135 SmallVector<unsigned > warpsPerTile =
121- getWarpsPerTile (dotOp, dpasCap, retShape, numWarps);
136+ getWarpsPerTile (dotOp, dpasCap, retShape, numWarps, order );
122137 size_t rank = retShape.size ();
123138 SmallVector<unsigned > repCluster (rank, 1 );
124139
0 commit comments