Skip to content

Commit 6e73ab6

Browse files
committed
[intel] improve pitch and width constexpr folding
1 parent 6f1525f commit 6e73ab6

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
136136
// CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
137137
// CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
138138
// CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
139-
// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
139+
// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[VAL_11]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
140140
// CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_12]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
141141
// CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
142142
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
@@ -199,7 +199,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
199199
// CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
200200
// CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
201201
// CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
202-
// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
202+
// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[VAL_10]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
203203
// CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
204204
// CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
205205
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,42 @@ static int __builtin_ctz(unsigned x) {
3939

4040
namespace {
4141

42+
/// Peels a single cast off `v`.
///
/// If `v` is the result of an LLVM trunc, sext, zext or bitcast operation, the
/// first operand of that operation is returned. Block arguments and values
/// produced by any other operation are returned unchanged. Only one cast layer
/// is removed per call; callers that want to see through a chain must call
/// this repeatedly.
static Value skipCasts(Value v) {
  Operation *producer = v.getDefiningOp();
  if (!producer)
    return v; // No defining op (e.g. a block argument): nothing to peel.

  const bool isCast =
      isa<LLVM::TruncOp, LLVM::SExtOp, LLVM::ZExtOp, LLVM::BitcastOp>(producer);
  return isCast ? producer->getOperand(0) : v;
}
49+
50+
/// Asks the operation defining `v` to fold itself and, if it folds to exactly
/// one existing SSA value, returns that value.
///
/// `v` is returned unchanged when it has no defining operation, when folding
/// fails, when folding does not yield exactly one result, or when the single
/// result is not a Value (for instance an attribute) -- attribute results are
/// not surfaced by this helper.
// NOTE(review): Operation::fold may update the operation in place; confirm
// that is acceptable at the call sites of this helper.
static Value tryFoldOp(Value v) {
  Operation *producer = v.getDefiningOp();
  if (!producer)
    return v;

  SmallVector<OpFoldResult> folded;
  if (!succeeded(producer->fold(folded)) || folded.size() != 1)
    return v;

  // Keep only a fold result that is itself an SSA value.
  auto replacement = dyn_cast_or_null<Value>(folded[0]);
  return replacement ? replacement : v;
}
60+
61+
/// Attempts to resolve `v` to a compile-time integer constant.
///
/// Each round first queries getConstantIntValue() for the current value. On a
/// miss it peels one cast layer (skipCasts) and lets the producer fold to
/// another value (tryFoldOp). The walk stops as soon as neither step makes
/// progress, or after `depth` rounds.
///
/// The constant is read from the value found behind any skipped casts; it is
/// not re-truncated or re-extended to the type of the original `v`.
///
/// \param v      value to evaluate.
/// \param depth  maximum number of peel/fold rounds (default 16).
/// \returns the constant when one is found, std::nullopt otherwise.
static std::optional<int64_t> tryConstEval(Value v, int depth = 16) {
  for (int i = 0; i < depth; ++i) {
    if (auto res = getConstantIntValue(v))
      return res;

    Value next = tryFoldOp(skipCasts(v));
    if (next == v)
      return std::nullopt; // Fixed point reached; `v` was already checked.

    v = next;
  }

  // The value produced by the last round has not been examined yet. Without
  // this check a chain that needs exactly `depth` steps would be reported as
  // non-constant even though it ends in a constant.
  return getConstantIntValue(v);
}
77+
4278
Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
4379
auto tb = TritonLLVMOpBuilder(loc, rewriter);
4480
if (a && b) {
@@ -1590,23 +1626,19 @@ struct LoadOpToBlockIOConversion
15901626
std::swap(baseWidth, baseHeight);
15911627
}
15921628
// HW requires the pitch to be at least 64 bytes.
1593-
std::function<Value(Value)> skipTrunc = [&](Value v) {
1594-
if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp()))
1595-
return skipTrunc(v.getDefiningOp()->getOperand(0));
1596-
return v;
1597-
};
1598-
if (Operation *op = skipTrunc(pitch).getDefiningOp()) {
1599-
std::optional<int64_t> pitchConst =
1600-
mlir::triton::intel::getFoldedConstantValue(op);
1601-
if (pitchConst.has_value()) {
1602-
if ((*pitchConst * elemSizeInBits / 8) < 64)
1603-
return failure();
1604-
}
1629+
if (auto pitchConst = tryConstEval(pitch)) {
1630+
if ((*pitchConst * elemSizeInBits / 8) < 64)
1631+
return failure();
16051632
}
16061633

16071634
baseWidth = b.trunc(i32_ty, baseWidth);
16081635
baseHeight = b.trunc(i32_ty, baseHeight);
16091636

1637+
if (auto widthConst = tryConstEval(baseWidth)) {
1638+
if ((*widthConst * elemSizeInBits / 8) < 64)
1639+
return failure();
1640+
}
1641+
16101642
const unsigned originalElemBits = elemSizeInBits;
16111643
if (isTransposeRequired) {
16121644
// adjust the block io parameter to align HW's limitations on

0 commit comments

Comments
 (0)