Skip to content

Commit 5fa7f61

Browse files
committed
Use block loads for post-dpas vector computation 1/?
1 parent b70c7f7 commit 5fa7f61

File tree

1 file changed

+210
-5
lines changed

1 file changed

+210
-5
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 210 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,9 @@ struct LoadOpConversion
499499
auto tensorType = cast<RankedTensorType>(resultType);
500500

501501
// Only lower loadOp with dpas layout encoding.
502-
if (!hasDotDpasEncoding(tensorType))
502+
auto encoding = tensorType.getEncoding();
503+
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
504+
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
503505
return failure();
504506

505507
Attribute blockIOAttr =
@@ -514,8 +516,8 @@ struct LoadOpConversion
514516
"Only row_major or column_major is supported");
515517
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
516518

517-
DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
518-
auto dotOrder = dotLayout.getThreadOrder();
519+
auto dpasLayout = hasDpasLayout ? cast<DpasEncodingAttr>(encoding) : cast<DpasEncodingAttr>(getDotEncoding(tensorType).value().getParent());
520+
auto dotOrder = dpasLayout.getThreadOrder();
519521
size_t rank = dotOrder.size();
520522
const bool valueRowMajor =
521523
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
@@ -524,9 +526,17 @@ struct LoadOpConversion
524526
"Only row_major or column_major is allowed");
525527
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
526528

527-
auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
529+
// auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
530+
auto getOpIdx = [&]() -> unsigned {
531+
if (hasDpasLayout) {
532+
return 2;
533+
} else {
534+
auto dotLayout = getDotEncoding(tensorType).value();
535+
return dotLayout.getOpIdx();
536+
}
537+
};
528538

529-
const unsigned opIdx = dotLayout.getOpIdx();
539+
const unsigned opIdx = getOpIdx();
530540
Type eltTy = tensorType.getElementType();
531541
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
532542
unsigned numElems = getTotalElemsPerThread(resultType);
@@ -543,6 +553,198 @@ struct LoadOpConversion
543553
SmallVector<Value> multiDimWarpId =
544554
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
545555

556+
#if 1
557+
if (hasDpasLayout) {
558+
llvm::errs() << "rewriting tensor pointer load for dpas user!\n";
559+
#if 1
560+
MLIRContext *ctx = rewriter.getContext();
561+
562+
Type eltTy = tensorType.getElementType();
563+
llvm::errs() << "Element type: " << eltTy << "\n";
564+
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
565+
Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
566+
567+
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
568+
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
569+
Type load2DGenXType =
570+
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
571+
elemsPerLane); // make it opaque type.
572+
llvm::errs() << "load 2d gen x type: " << load2DGenXType << "\n";
573+
574+
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
575+
offsetBaseY] =
576+
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
577+
baseWidth = trunc(i32_ty, baseWidth);
578+
baseHeight = trunc(i32_ty, baseHeight);
579+
580+
// always row order coming out of DPAS
581+
auto pitch = trunc(i32_ty, rowStride);
582+
583+
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
584+
unsigned outerDimWarpNum = std::min<unsigned>(
585+
warpsPerCTA[rank - 2],
586+
mlir::ceil<unsigned>(tensorShape[rank - 2], repClusterShape[rank - 2]));
587+
unsigned innerDimWarpNum = std::min<unsigned>(
588+
warpsPerCTA[rank - 1],
589+
mlir::ceil<unsigned>(tensorShape[rank - 1], repClusterShape[rank - 1]));
590+
Value outerDimWarpId =
591+
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
592+
Value innerDimWarpId =
593+
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
594+
int64_t numRepOuter = numReps[1];
595+
int64_t numRepInner = numReps[2];
596+
597+
598+
std::array<unsigned, 2> replicaStride = {
599+
outerDimWarpNum * repClusterShape[rank - 2],
600+
innerDimWarpNum * repClusterShape[rank - 1]};
601+
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
602+
repClusterShape[rank - 1]};
603+
604+
Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
605+
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
606+
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
607+
Value warpId1Offset = add(dimWarpId1, offsetBaseX);
608+
609+
llvm::errs() << "elemsPerInstr: " << elemsPerInstr[0] << ", " << elemsPerInstr[1] << "\n";
610+
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
611+
unsigned valOffset = 0;
612+
613+
SmallVector<Value> unpackedLoadedVals;
614+
615+
for (int m = 0; m < numRepOuter; ++m) {
616+
for (int n = 0; n < numRepInner; ++n) {
617+
for (int repM = 0; repM < repCluster[0]; ++repM) {
618+
619+
Value offsetY = add(warpId0Offset, i32_val(m * replicaStride[0] +
620+
repM * elemsPerInstr[0]));
621+
for (int repN = 0; repN < repCluster[1]; ++repN) {
622+
llvm::errs() << "m, n, repM, repN: " << m << ", " << n << ", " << repM << ", " << repN << "\n";
623+
Value offsetX =
624+
add(warpId1Offset,
625+
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
626+
627+
assert(!isTransposeRequired);
628+
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
629+
loc, load2DGenXType,
630+
/*ptr*/ base,
631+
/*base_width*/ mul(baseWidth, elemSizeInBytes),
632+
/*base_height*/ baseHeight,
633+
/*base_pitch*/ mul(pitch, elemSizeInBytes),
634+
/*x*/ trunc(i32_ty, offsetX),
635+
/*y*/ trunc(i32_ty, offsetY),
636+
/*elem_size_in_bits*/ elemSizeInBits,
637+
/*tile_width*/ elemsPerInstr[1],
638+
/*tile_height*/ elemsPerInstr[0],
639+
/*v_blocks*/ 1,
640+
/*transpose*/ isTransposeRequired,
641+
/*vnni_transform*/false /*
642+
(usePackedType && !isOperandA && !isTransposeRequired &&
643+
eltTy.getIntOrFloatBitWidth() != 32)*/);
644+
if (failed(load2dOp.verify())) {
645+
// Explicitly invoke verifier because `triton_gen` ops are
646+
// immediately lowered further to a builtin call.
647+
return failure();
648+
}
649+
650+
651+
#if 0
652+
llvm::errs() << "elemsPerLane: " << elemsPerLane << "\n";
653+
SmallVector<int32_t> indices(elemsPerLane);
654+
for (int elemIdx = 0; elemIdx < elemsPerLane;
655+
++elemIdx) {
656+
indices[elemIdx] = elemIdx * n;
657+
}
658+
659+
#if 0
660+
llvm::errs() << "indices: ";
661+
for (size_t i = 0; i < indices.size(); i++) {
662+
llvm::errs() << " " << i;
663+
}
664+
llvm::errs() << "\n";
665+
#endif
666+
DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
667+
Value loadVal = rewriter.create<LLVM::ShuffleVectorOp>(
668+
loc, load2DGenXType, load2dOp, load2dOp, attr);
669+
#endif
670+
Value ret = bitcast(load2dOp, LLVM::getFixedVectorType(eltTy,
671+
elemsPerLane));
672+
llvm::errs() << "ret: " << ret << "\n";
673+
// each load should give us one column
674+
for(size_t i = 0; i < elemsPerLane; i++) {
675+
Value loaded =
676+
extract_element(eltTy, ret, i32_val(i));
677+
unpackedLoadedVals.push_back(loaded);
678+
}
679+
680+
// loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
681+
// rep * packedRowNum + row,
682+
// k + vblk * packedColNumPerVBlock + col}] =
683+
// bitcast(loadVal, unpackedDPASOperandType);
684+
#if 0
685+
Value storeVal = rewriter.create<LLVM::UndefOp>(
686+
loc, LLVM::getFixedVectorType(typeConverter->convertType(eltTy),
687+
elemsPerLane));
688+
for (size_t i = 0; i < elemsPerLane; ++i) {
689+
storeVal = insert_element(storeVal, vals[valOffset], i32_val(i));
690+
++valOffset;
691+
}
692+
693+
auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
694+
loc,
695+
/*ptr*/ base,
696+
/*base_width*/ baseWidth,
697+
/*base_height*/ height,
698+
/*base_pitch*/ basePitch,
699+
/*x*/ trunc(i32_ty, offsetX),
700+
/*y*/ trunc(i32_ty, offsetY),
701+
/*elem_size_in_bits*/ elemSizeInBits,
702+
/*tile_width*/ elemsPerInstr[1],
703+
/*tile_height*/ elemsPerInstr[0],
704+
/*v_blocks*/ 1,
705+
/*stored_val*/ bitcast(storeVal, store2DGenXType));
706+
707+
if (failed(newOp.verify())) {
708+
// Explicitly invoke verifier because `triton_gen` ops are
709+
// immediately lowered further to a builtin call.
710+
return failure();
711+
}
712+
#endif
713+
}
714+
}
715+
}
716+
}
717+
#if 0
718+
return failure();
719+
#else
720+
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
721+
Type llvmResultStructTy = typeConverter->convertType(op.getType());
722+
Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
723+
rewriter, llvmResultStructTy);
724+
rewriter.replaceOp(op, {resultStruct});
725+
726+
return success();
727+
#endif
728+
#else
729+
730+
ValueTable loadVals;
731+
732+
unsigned numRepOuter = numReps[2];
733+
// TODO: calculate this instead of guessing
734+
numOperandsPer2DLoadM = 1;
735+
numOperandsPer2DloadN = 1;
736+
for (int outer = 0; outer < numRepOuter; ++outer) {
737+
for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
738+
for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
739+
llvm::errs() << "load: " << outer << ", " << rep << ", " << k << "\n";
740+
}
741+
}
742+
}
743+
return failure();
744+
#endif
745+
}
746+
#endif
747+
546748
bool isOperandA = (opIdx == 0);
547749
SmallVector<unsigned> dpasInstShape = isOperandA
548750
? dpasLayout.getDPASInstShapeA()
@@ -586,6 +788,7 @@ struct LoadOpConversion
586788

587789
Type packedDPASOperandType = LLVM::getFixedVectorType(
588790
loadResultElemType, packedElemsPerLanePerDPASInst);
791+
llvm::errs() << "packed DPAS operand type: " << packedDPASOperandType << "\n";
589792

590793
// Outer dim: Dim M or N. Inner dim: Dim K.
591794
// Round the warp id fit into the tensor shape.
@@ -705,9 +908,11 @@ struct LoadOpConversion
705908
Value elemSizeInBytes = i32_val(originalElemBits / 8);
706909

707910
ValueTable loadVals;
911+
llvm::errs() << "Generating 2D block load for op: " << opIdx << "\n";
708912
for (int outer = 0; outer < numRepOuter; ++outer) {
709913
for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
710914
for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
915+
llvm::errs() << "outer, rep, k = " << outer << ", " << rep << ", " << k << "\n";
711916
Value offsetX, offsetY;
712917
if (opIdx == 0) {
713918
// A

0 commit comments

Comments
 (0)