Skip to content

Commit 788c706

Browse files
committed
Use block loads for post-dpas vector computation 2/?
1 parent 5fa7f61 commit 788c706

File tree

1 file changed

+93
-170
lines changed

1 file changed

+93
-170
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 93 additions & 170 deletions
Original file line number | Diff line number | Diff line change
@@ -516,7 +516,10 @@ struct LoadOpConversion
516516
"Only row_major or column_major is supported");
517517
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
518518

519-
auto dpasLayout = hasDpasLayout ? cast<DpasEncodingAttr>(encoding) : cast<DpasEncodingAttr>(getDotEncoding(tensorType).value().getParent());
519+
auto dpasLayout = hasDpasLayout
520+
? cast<DpasEncodingAttr>(encoding)
521+
: cast<DpasEncodingAttr>(
522+
getDotEncoding(tensorType).value().getParent());
520523
auto dotOrder = dpasLayout.getThreadOrder();
521524
size_t rank = dotOrder.size();
522525
const bool valueRowMajor =
@@ -553,79 +556,78 @@ struct LoadOpConversion
553556
SmallVector<Value> multiDimWarpId =
554557
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
555558

556-
#if 1
559+
// TODO: de-duplicate with above and below code
557560
if (hasDpasLayout) {
558-
llvm::errs() << "rewriting tensor pointer load for dpas user!\n";
559-
#if 1
560-
MLIRContext *ctx = rewriter.getContext();
561-
562-
Type eltTy = tensorType.getElementType();
563-
llvm::errs() << "Element type: " << eltTy << "\n";
564-
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
565-
Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
566-
567-
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
568-
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
569-
Type load2DGenXType =
570-
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
571-
elemsPerLane); // make it opaque type.
572-
llvm::errs() << "load 2d gen x type: " << load2DGenXType << "\n";
573-
574-
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
575-
offsetBaseY] =
576-
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
577-
baseWidth = trunc(i32_ty, baseWidth);
578-
baseHeight = trunc(i32_ty, baseHeight);
579-
580-
// always row order coming out of DPAS
581-
auto pitch = trunc(i32_ty, rowStride);
582-
583-
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
584-
unsigned outerDimWarpNum = std::min<unsigned>(
585-
warpsPerCTA[rank - 2],
586-
mlir::ceil<unsigned>(tensorShape[rank - 2], repClusterShape[rank - 2]));
587-
unsigned innerDimWarpNum = std::min<unsigned>(
588-
warpsPerCTA[rank - 1],
589-
mlir::ceil<unsigned>(tensorShape[rank - 1], repClusterShape[rank - 1]));
590-
Value outerDimWarpId =
591-
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
592-
Value innerDimWarpId =
593-
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
594-
int64_t numRepOuter = numReps[1];
595-
int64_t numRepInner = numReps[2];
596-
597-
598-
std::array<unsigned, 2> replicaStride = {
599-
outerDimWarpNum * repClusterShape[rank - 2],
600-
innerDimWarpNum * repClusterShape[rank - 1]};
601-
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
602-
repClusterShape[rank - 1]};
603-
604-
Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
605-
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
606-
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
607-
Value warpId1Offset = add(dimWarpId1, offsetBaseX);
608-
609-
llvm::errs() << "elemsPerInstr: " << elemsPerInstr[0] << ", " << elemsPerInstr[1] << "\n";
610-
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
611-
unsigned valOffset = 0;
612-
613-
SmallVector<Value> unpackedLoadedVals;
614-
615-
for (int m = 0; m < numRepOuter; ++m) {
616-
for (int n = 0; n < numRepInner; ++n) {
617-
for (int repM = 0; repM < repCluster[0]; ++repM) {
618-
619-
Value offsetY = add(warpId0Offset, i32_val(m * replicaStride[0] +
620-
repM * elemsPerInstr[0]));
621-
for (int repN = 0; repN < repCluster[1]; ++repN) {
622-
llvm::errs() << "m, n, repM, repN: " << m << ", " << n << ", " << repM << ", " << repN << "\n";
623-
Value offsetX =
624-
add(warpId1Offset,
625-
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
626-
627-
assert(!isTransposeRequired);
628-
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
561+
// llvm::errs() << "rewriting tensor pointer load for dpas user!\n";
562+
MLIRContext *ctx = rewriter.getContext();
563+
564+
Type eltTy = tensorType.getElementType();
565+
// llvm::errs() << "Element type: " << eltTy << "\n";
566+
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
567+
Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
568+
569+
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
570+
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
571+
Type load2DGenXType =
572+
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
573+
elemsPerLane); // make it opaque type.
574+
// llvm::errs() << "load 2d gen x type: " << load2DGenXType << "\n";
575+
576+
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
577+
offsetBaseY] =
578+
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
579+
baseWidth = trunc(i32_ty, baseWidth);
580+
baseHeight = trunc(i32_ty, baseHeight);
581+
582+
// always row order coming out of DPAS
583+
auto pitch = trunc(i32_ty, rowStride);
584+
585+
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
586+
unsigned outerDimWarpNum =
587+
std::min<unsigned>(warpsPerCTA[rank - 2],
588+
mlir::ceil<unsigned>(tensorShape[rank - 2],
589+
repClusterShape[rank - 2]));
590+
unsigned innerDimWarpNum =
591+
std::min<unsigned>(warpsPerCTA[rank - 1],
592+
mlir::ceil<unsigned>(tensorShape[rank - 1],
593+
repClusterShape[rank - 1]));
594+
Value outerDimWarpId =
595+
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
596+
Value innerDimWarpId =
597+
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
598+
int64_t numRepOuter = numReps[1];
599+
int64_t numRepInner = numReps[2];
600+
601+
std::array<unsigned, 2> replicaStride = {
602+
outerDimWarpNum * repClusterShape[rank - 2],
603+
innerDimWarpNum * repClusterShape[rank - 1]};
604+
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
605+
repClusterShape[rank - 1]};
606+
607+
Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
608+
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
609+
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
610+
Value warpId1Offset = add(dimWarpId1, offsetBaseX);
611+
612+
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
613+
unsigned valOffset = 0;
614+
615+
SmallVector<Value> unpackedLoadedVals;
616+
617+
for (int m = 0; m < numRepOuter; ++m) {
618+
for (int n = 0; n < numRepInner; ++n) {
619+
for (int repM = 0; repM < repCluster[0]; ++repM) {
620+
621+
Value offsetY =
622+
add(warpId0Offset,
623+
i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
624+
for (int repN = 0; repN < repCluster[1]; ++repN) {
625+
Value offsetX =
626+
add(warpId1Offset,
627+
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
628+
629+
assert(!isTransposeRequired);
630+
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
629631
loc, load2DGenXType,
630632
/*ptr*/ base,
631633
/*base_width*/ mul(baseWidth, elemSizeInBytes),
@@ -641,109 +643,33 @@ struct LoadOpConversion
641643
/*vnni_transform*/false /*
642644
(usePackedType && !isOperandA && !isTransposeRequired &&
643645
eltTy.getIntOrFloatBitWidth() != 32)*/);
644-
if (failed(load2dOp.verify())) {
645-
// Explicitly invoke verifier because `triton_gen` ops are
646-
// immediately lowered further to a builtin call.
647-
return failure();
648-
}
649-
650-
651-
#if 0
652-
llvm::errs() << "elemsPerLane: " << elemsPerLane << "\n";
653-
SmallVector<int32_t> indices(elemsPerLane);
654-
for (int elemIdx = 0; elemIdx < elemsPerLane;
655-
++elemIdx) {
656-
indices[elemIdx] = elemIdx * n;
646+
if (failed(load2dOp.verify())) {
647+
// Explicitly invoke verifier because `triton_gen` ops are
648+
// immediately lowered further to a builtin call.
649+
return failure();
657650
}
658651

659-
#if 0
660-
llvm::errs() << "indices: ";
661-
for (size_t i = 0; i < indices.size(); i++) {
662-
llvm::errs() << " " << i;
663-
}
664-
llvm::errs() << "\n";
665-
#endif
666-
DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
667-
Value loadVal = rewriter.create<LLVM::ShuffleVectorOp>(
668-
loc, load2DGenXType, load2dOp, load2dOp, attr);
669-
#endif
670-
Value ret = bitcast(load2dOp, LLVM::getFixedVectorType(eltTy,
671-
elemsPerLane));
672-
llvm::errs() << "ret: " << ret << "\n";
673-
// each load should give us one column
674-
for(size_t i = 0; i < elemsPerLane; i++) {
675-
Value loaded =
676-
extract_element(eltTy, ret, i32_val(i));
677-
unpackedLoadedVals.push_back(loaded);
678-
}
679-
680-
// loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
681-
// rep * packedRowNum + row,
682-
// k + vblk * packedColNumPerVBlock + col}] =
683-
// bitcast(loadVal, unpackedDPASOperandType);
684-
#if 0
685-
Value storeVal = rewriter.create<LLVM::UndefOp>(
686-
loc, LLVM::getFixedVectorType(typeConverter->convertType(eltTy),
687-
elemsPerLane));
688-
for (size_t i = 0; i < elemsPerLane; ++i) {
689-
storeVal = insert_element(storeVal, vals[valOffset], i32_val(i));
690-
++valOffset;
691-
}
692-
693-
auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
694-
loc,
695-
/*ptr*/ base,
696-
/*base_width*/ baseWidth,
697-
/*base_height*/ height,
698-
/*base_pitch*/ basePitch,
699-
/*x*/ trunc(i32_ty, offsetX),
700-
/*y*/ trunc(i32_ty, offsetY),
701-
/*elem_size_in_bits*/ elemSizeInBits,
702-
/*tile_width*/ elemsPerInstr[1],
703-
/*tile_height*/ elemsPerInstr[0],
704-
/*v_blocks*/ 1,
705-
/*stored_val*/ bitcast(storeVal, store2DGenXType));
706-
707-
if (failed(newOp.verify())) {
708-
// Explicitly invoke verifier because `triton_gen` ops are
709-
// immediately lowered further to a builtin call.
710-
return failure();
652+
Value ret = bitcast(
653+
load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));
654+
// llvm::errs() << "ret: " << ret << "\n";
655+
// each load should give us one column
656+
for (size_t i = 0; i < elemsPerLane; i++) {
657+
Value loaded = extract_element(eltTy, ret, i32_val(i));
658+
unpackedLoadedVals.push_back(loaded);
659+
}
711660
}
712-
#endif
713661
}
714662
}
715663
}
716-
}
717-
#if 0
718-
return failure();
719-
#else
720-
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
721-
Type llvmResultStructTy = typeConverter->convertType(op.getType());
722-
Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
723-
rewriter, llvmResultStructTy);
724-
rewriter.replaceOp(op, {resultStruct});
725-
726-
return success();
727-
#endif
728-
#else
729664

730-
ValueTable loadVals;
665+
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
666+
Type llvmResultStructTy = typeConverter->convertType(op.getType());
667+
Value resultStruct = packLLElements(
668+
loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
669+
rewriter.replaceOp(op, {resultStruct});
731670

732-
unsigned numRepOuter = numReps[2];
733-
// TODO: calculate this instead of guessing
734-
numOperandsPer2DLoadM = 1;
735-
numOperandsPer2DloadN = 1;
736-
for (int outer = 0; outer < numRepOuter; ++outer) {
737-
for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
738-
for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
739-
llvm::errs() << "load: " << outer << ", " << rep << ", " << k << "\n";
740-
}
741-
}
742-
}
743-
return failure();
744-
#endif
671+
return success();
745672
}
746-
#endif
747673

748674
bool isOperandA = (opIdx == 0);
749675
SmallVector<unsigned> dpasInstShape = isOperandA
@@ -788,7 +714,6 @@ struct LoadOpConversion
788714

789715
Type packedDPASOperandType = LLVM::getFixedVectorType(
790716
loadResultElemType, packedElemsPerLanePerDPASInst);
791-
llvm::errs() << "packed DPAS operand type: " << packedDPASOperandType << "\n";
792717

793718
// Outer dim: Dim M or N. Inner dim: Dim K.
794719
// Round the warp id fit into the tensor shape.
@@ -908,11 +833,9 @@ struct LoadOpConversion
908833
Value elemSizeInBytes = i32_val(originalElemBits / 8);
909834

910835
ValueTable loadVals;
911-
llvm::errs() << "Generating 2D block load for op: " << opIdx << "\n";
912836
for (int outer = 0; outer < numRepOuter; ++outer) {
913837
for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
914838
for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
915-
llvm::errs() << "outer, rep, k = " << outer << ", " << rep << ", " << k << "\n";
916839
Value offsetX, offsetY;
917840
if (opIdx == 0) {
918841
// A

0 commit comments

Comments
 (0)