Skip to content

Commit 5532f2a

Browse files
committed
Use block loads for post-dpas vector computation 3/?
1 parent 788c706 commit 5532f2a

File tree

1 file changed

+24
-27
lines changed

1 file changed

+24
-27
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,6 @@ struct LoadOpConversion
529529
"Only row_major or column_major is allowed");
530530
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
531531

532-
// auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
533532
auto getOpIdx = [&]() -> unsigned {
534533
if (hasDpasLayout) {
535534
return 2;
@@ -541,6 +540,8 @@ struct LoadOpConversion
541540

542541
const unsigned opIdx = getOpIdx();
543542
Type eltTy = tensorType.getElementType();
543+
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
544+
544545
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
545546
unsigned numElems = getTotalElemsPerThread(resultType);
546547
SmallVector<int64_t> numReps =
@@ -556,30 +557,30 @@ struct LoadOpConversion
556557
SmallVector<Value> multiDimWarpId =
557558
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
558559

559-
// TODO: de-duplicate with above and below code
560560
if (hasDpasLayout) {
561-
// llvm::errs() << "rewriting tensor pointer load for dpas user!\n";
561+
if (isTransposeRequired) {
562+
// TODO: this would likely require a shuffle to match the expected
563+
// ordering coming out of the DPAS layout and requires more
564+
// investigation
565+
return failure();
566+
}
567+
562568
MLIRContext *ctx = rewriter.getContext();
563569

564-
Type eltTy = tensorType.getElementType();
565-
// llvm::errs() << "Element type: " << eltTy << "\n";
566-
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
567570
Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
568571

569572
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
570573
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
571574
Type load2DGenXType =
572575
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
573576
elemsPerLane); // make it opaque type.
574-
// llvm::errs() << "load 2d gen x type: " << load2DGenXType << "\n";
575577

576578
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
577579
offsetBaseY] =
578580
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
579581
baseWidth = trunc(i32_ty, baseWidth);
580582
baseHeight = trunc(i32_ty, baseHeight);
581583

582-
// always row order coming out of DPAS
583584
auto pitch = trunc(i32_ty, rowStride);
584585

585586
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
@@ -626,7 +627,6 @@ struct LoadOpConversion
626627
add(warpId1Offset,
627628
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
628629

629-
assert(!isTransposeRequired);
630630
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
631631
loc, load2DGenXType,
632632
/*ptr*/ base,
@@ -636,13 +636,11 @@ struct LoadOpConversion
636636
/*x*/ trunc(i32_ty, offsetX),
637637
/*y*/ trunc(i32_ty, offsetY),
638638
/*elem_size_in_bits*/ elemSizeInBits,
639-
/*tile_width*/ elemsPerInstr[1],
640-
/*tile_height*/ elemsPerInstr[0],
639+
/*tile_width*/ elemsPerInstr[1],
640+
/*tile_height*/ elemsPerInstr[0],
641641
/*v_blocks*/ 1,
642-
/*transpose*/ isTransposeRequired,
643-
/*vnni_transform*/false /*
644-
(usePackedType && !isOperandA && !isTransposeRequired &&
645-
eltTy.getIntOrFloatBitWidth() != 32)*/);
642+
/*transpose*/ false,
643+
/*vnni_transform*/ false);
646644
if (failed(load2dOp.verify())) {
647645
// Explicitly invoke verifier because `triton_gen` ops are
648646
// immediately lowered further to a builtin call.
@@ -651,8 +649,7 @@ struct LoadOpConversion
651649

652650
Value ret = bitcast(
653651
load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));
654-
// llvm::errs() << "ret: " << ret << "\n";
655-
// each load should give us one column
652+
656653
for (size_t i = 0; i < elemsPerLane; i++) {
657654
Value loaded = extract_element(eltTy, ret, i32_val(i));
658655
unpackedLoadedVals.push_back(loaded);
@@ -701,11 +698,11 @@ struct LoadOpConversion
701698
// input operands to DPAS.
702699
// TODO: add support for int4 and int2.
703700
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
704-
unsigned elemBits = eltTy.getIntOrFloatBitWidth();
705-
if ((opsPerChannel == 4 && elemBits == 8) ||
706-
(opsPerChannel == 2 && elemBits == 16) ||
707-
(opsPerChannel == 1 && elemBits == 32)) {
708-
loadResultElemType = (isOperandA && elemBits != 32) ? i16_ty : i32_ty;
701+
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
702+
(opsPerChannel == 2 && elemSizeInBits == 16) ||
703+
(opsPerChannel == 1 && elemSizeInBits == 32)) {
704+
loadResultElemType =
705+
(isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty;
709706
packedElemsPerLanePerDPASInst =
710707
isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1)
711708
: elemsPerLanePerDPASInst / opsPerChannel;
@@ -779,7 +776,7 @@ struct LoadOpConversion
779776

780777
// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
781778
// by enlarging the vBlocks.
782-
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
779+
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
783780
numOperandsPer2DloadN =
784781
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
785782
vBlocks = numOperandsPer2DloadN;
@@ -823,12 +820,12 @@ struct LoadOpConversion
823820
baseWidth = trunc(i32_ty, baseWidth);
824821
baseHeight = trunc(i32_ty, baseHeight);
825822

826-
unsigned originalElemBits = elemBits;
823+
const unsigned originalElemBits = elemSizeInBits;
827824
if (isTransposeRequired) {
828825
// adjust the block io parameter to align HW's limitations on
829826
// transposing load.
830827
tileWidth = tileWidth / (32 / originalElemBits);
831-
elemBits = 32;
828+
elemSizeInBits = 32;
832829
}
833830
Value elemSizeInBytes = i32_val(originalElemBits / 8);
834831

@@ -872,14 +869,14 @@ struct LoadOpConversion
872869
/*base_pitch*/ mul(pitch, elemSizeInBytes),
873870
/*x*/ trunc(i32_ty, offsetX),
874871
/*y*/ trunc(i32_ty, offsetY),
875-
/*elem_size_in_bits*/ elemBits,
872+
/*elem_size_in_bits*/ elemSizeInBits,
876873
/*tile_width*/ tileWidth,
877874
/*tile_height*/ tileHeight,
878875
/*v_blocks*/ vBlocks,
879876
/*transpose*/ isTransposeRequired,
880877
/*vnni_transform*/
881878
(usePackedType && !isOperandA && !isTransposeRequired &&
882-
eltTy.getIntOrFloatBitWidth() != 32));
879+
originalElemBits != 32));
883880
if (failed(load2dOp.verify())) {
884881
// Explicitly invoke verifier because `triton_gen` ops are
885882
// immediately lowered further to a builtin call.

0 commit comments

Comments
 (0)