@@ -516,7 +516,10 @@ struct LoadOpConversion
516516 " Only row_major or column_major is supported" );
517517 const bool memoryRowMajor = (memoryLayoutInfo == " row_major" );
518518
519- auto dpasLayout = hasDpasLayout ? cast<DpasEncodingAttr>(encoding) : cast<DpasEncodingAttr>(getDotEncoding (tensorType).value ().getParent ());
519+ auto dpasLayout = hasDpasLayout
520+ ? cast<DpasEncodingAttr>(encoding)
521+ : cast<DpasEncodingAttr>(
522+ getDotEncoding (tensorType).value ().getParent ());
520523 auto dotOrder = dpasLayout.getThreadOrder ();
521524 size_t rank = dotOrder.size ();
522525 const bool valueRowMajor =
@@ -553,79 +556,78 @@ struct LoadOpConversion
553556 SmallVector<Value> multiDimWarpId =
554557 delinearize (rewriter, loc, warpId, warpsPerCTA, dpasOrder);
555558
556- # if 1
559+ // TODO: de-duplicate with above and below code
557560 if (hasDpasLayout) {
558- llvm::errs () << " rewriting tensor pointer load for dpas user!\n " ;
559- #if 1
560- MLIRContext *ctx = rewriter.getContext ();
561-
562- Type eltTy = tensorType.getElementType ();
563- llvm::errs () << " Element type: " << eltTy << " \n " ;
564- unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
565- Value elemSizeInBytes = i32_val (elemSizeInBits / 8 );
566-
567- SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
568- int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
569- Type load2DGenXType =
570- LLVM::getFixedVectorType (IntegerType::get (ctx, elemSizeInBits),
571- elemsPerLane); // make it opaque type.
572- llvm::errs () << " load 2d gen x type: " << load2DGenXType << " \n " ;
573-
574- auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
575- offsetBaseY] =
576- getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
577- baseWidth = trunc (i32_ty, baseWidth);
578- baseHeight = trunc (i32_ty, baseHeight);
579-
580- // always row order coming out of DPAS
581- auto pitch = trunc (i32_ty, rowStride);
582-
583- SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
584- unsigned outerDimWarpNum = std::min<unsigned >(
585- warpsPerCTA[rank - 2 ],
586- mlir::ceil<unsigned >(tensorShape[rank - 2 ], repClusterShape[rank - 2 ]));
587- unsigned innerDimWarpNum = std::min<unsigned >(
588- warpsPerCTA[rank - 1 ],
589- mlir::ceil<unsigned >(tensorShape[rank - 1 ], repClusterShape[rank - 1 ]));
590- Value outerDimWarpId =
591- urem (multiDimWarpId[rank - 2 ], i32_val (outerDimWarpNum));
592- Value innerDimWarpId =
593- urem (multiDimWarpId[rank - 1 ], i32_val (innerDimWarpNum));
594- int64_t numRepOuter = numReps[1 ];
595- int64_t numRepInner = numReps[2 ];
596-
597-
598- std::array<unsigned , 2 > replicaStride = {
599- outerDimWarpNum * repClusterShape[rank - 2 ],
600- innerDimWarpNum * repClusterShape[rank - 1 ]};
601- std::array<unsigned , 2 > warpStride = {repClusterShape[rank - 2 ],
602- repClusterShape[rank - 1 ]};
603-
604- Value dimWarpId0 = mul (outerDimWarpId, i32_val (warpStride[0 ]));
605- Value dimWarpId1 = mul (innerDimWarpId, i32_val (warpStride[1 ]));
606- Value warpId0Offset = add (dimWarpId0, offsetBaseY);
607- Value warpId1Offset = add (dimWarpId1, offsetBaseX);
608-
609- llvm::errs () << " elemsPerInstr: " << elemsPerInstr[0 ] << " , " << elemsPerInstr[1 ] << " \n " ;
610- ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
611- unsigned valOffset = 0 ;
612-
613- SmallVector<Value> unpackedLoadedVals;
614-
615- for (int m = 0 ; m < numRepOuter; ++m) {
616- for (int n = 0 ; n < numRepInner; ++n) {
617- for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
618-
619- Value offsetY = add (warpId0Offset, i32_val (m * replicaStride[0 ] +
620- repM * elemsPerInstr[0 ]));
621- for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
622- llvm::errs () << " m, n, repM, repN: " << m << " , " << n << " , " << repM << " , " << repN << " \n " ;
623- Value offsetX =
624- add (warpId1Offset,
625- i32_val (n * replicaStride[1 ] + repN * elemsPerInstr[1 ]));
626-
627- assert (!isTransposeRequired);
628- auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
561+ // llvm::errs() << "rewriting tensor pointer load for dpas user!\n";
562+ MLIRContext *ctx = rewriter.getContext ();
563+
564+ Type eltTy = tensorType.getElementType ();
565+ // llvm::errs() << "Element type: " << eltTy << "\n";
566+ unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
567+ Value elemSizeInBytes = i32_val (elemSizeInBits / 8 );
568+
569+ SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
570+ int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
571+ Type load2DGenXType =
572+ LLVM::getFixedVectorType (IntegerType::get (ctx, elemSizeInBits),
573+ elemsPerLane); // make it opaque type.
574+ // llvm::errs() << "load 2d gen x type: " << load2DGenXType << "\n";
575+
576+ auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
577+ offsetBaseY] =
578+ getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
579+ baseWidth = trunc (i32_ty, baseWidth);
580+ baseHeight = trunc (i32_ty, baseHeight);
581+
582+ // always row order coming out of DPAS
583+ auto pitch = trunc (i32_ty, rowStride);
584+
585+ SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
586+ unsigned outerDimWarpNum =
587+ std::min<unsigned >(warpsPerCTA[rank - 2 ],
588+ mlir::ceil<unsigned >(tensorShape[rank - 2 ],
589+ repClusterShape[rank - 2 ]));
590+ unsigned innerDimWarpNum =
591+ std::min<unsigned >(warpsPerCTA[rank - 1 ],
592+ mlir::ceil<unsigned >(tensorShape[rank - 1 ],
593+ repClusterShape[rank - 1 ]));
594+ Value outerDimWarpId =
595+ urem (multiDimWarpId[rank - 2 ], i32_val (outerDimWarpNum));
596+ Value innerDimWarpId =
597+ urem (multiDimWarpId[rank - 1 ], i32_val (innerDimWarpNum));
598+ int64_t numRepOuter = numReps[1 ];
599+ int64_t numRepInner = numReps[2 ];
600+
601+ std::array<unsigned , 2 > replicaStride = {
602+ outerDimWarpNum * repClusterShape[rank - 2 ],
603+ innerDimWarpNum * repClusterShape[rank - 1 ]};
604+ std::array<unsigned , 2 > warpStride = {repClusterShape[rank - 2 ],
605+ repClusterShape[rank - 1 ]};
606+
607+ Value dimWarpId0 = mul (outerDimWarpId, i32_val (warpStride[0 ]));
608+ Value dimWarpId1 = mul (innerDimWarpId, i32_val (warpStride[1 ]));
609+ Value warpId0Offset = add (dimWarpId0, offsetBaseY);
610+ Value warpId1Offset = add (dimWarpId1, offsetBaseX);
611+
612+ ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
613+ unsigned valOffset = 0 ;
614+
615+ SmallVector<Value> unpackedLoadedVals;
616+
617+ for (int m = 0 ; m < numRepOuter; ++m) {
618+ for (int n = 0 ; n < numRepInner; ++n) {
619+ for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
620+
621+ Value offsetY =
622+ add (warpId0Offset,
623+ i32_val (m * replicaStride[0 ] + repM * elemsPerInstr[0 ]));
624+ for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
625+ Value offsetX =
626+ add (warpId1Offset,
627+ i32_val (n * replicaStride[1 ] + repN * elemsPerInstr[1 ]));
628+
629+ assert (!isTransposeRequired);
630+ auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
629631 loc, load2DGenXType,
630632 /* ptr*/ base,
631633 /* base_width*/ mul (baseWidth, elemSizeInBytes),
@@ -641,109 +643,33 @@ struct LoadOpConversion
641643 /* vnni_transform*/ false /*
642644 (usePackedType && !isOperandA && !isTransposeRequired &&
643645 eltTy.getIntOrFloatBitWidth() != 32)*/ );
644- if (failed (load2dOp.verify ())) {
645- // Explicitly invoke verifier because `triton_gen` ops are
646- // immediately lowered further to a builtin call.
647- return failure ();
648- }
649-
650-
651- #if 0
652- llvm::errs() << "elemsPerLane: " << elemsPerLane << "\n";
653- SmallVector<int32_t> indices(elemsPerLane);
654- for (int elemIdx = 0; elemIdx < elemsPerLane;
655- ++elemIdx) {
656- indices[elemIdx] = elemIdx * n;
646+ if (failed (load2dOp.verify ())) {
647+ // Explicitly invoke verifier because `triton_gen` ops are
648+ // immediately lowered further to a builtin call.
649+ return failure ();
657650 }
658651
659- #if 0
660- llvm::errs() << "indices: ";
661- for (size_t i = 0; i < indices.size(); i++) {
662- llvm::errs() << " " << i;
663- }
664- llvm::errs() << "\n";
665- #endif
666- DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
667- Value loadVal = rewriter.create<LLVM::ShuffleVectorOp>(
668- loc, load2DGenXType, load2dOp, load2dOp, attr);
669- #endif
670- Value ret = bitcast (load2dOp, LLVM::getFixedVectorType (eltTy,
671- elemsPerLane));
672- llvm::errs () << " ret: " << ret << " \n " ;
673- // each load should give us one column
674- for (size_t i = 0 ; i < elemsPerLane; i++) {
675- Value loaded =
676- extract_element (eltTy, ret, i32_val (i));
677- unpackedLoadedVals.push_back (loaded);
678- }
679-
680- // loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
681- // rep * packedRowNum + row,
682- // k + vblk * packedColNumPerVBlock + col}] =
683- // bitcast(loadVal, unpackedDPASOperandType);
684- #if 0
685- Value storeVal = rewriter.create<LLVM::UndefOp>(
686- loc, LLVM::getFixedVectorType(typeConverter->convertType(eltTy),
687- elemsPerLane));
688- for (size_t i = 0; i < elemsPerLane; ++i) {
689- storeVal = insert_element(storeVal, vals[valOffset], i32_val(i));
690- ++valOffset;
691- }
692-
693- auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
694- loc,
695- /*ptr*/ base,
696- /*base_width*/ baseWidth,
697- /*base_height*/ height,
698- /*base_pitch*/ basePitch,
699- /*x*/ trunc(i32_ty, offsetX),
700- /*y*/ trunc(i32_ty, offsetY),
701- /*elem_size_in_bits*/ elemSizeInBits,
702- /*tile_width*/ elemsPerInstr[1],
703- /*tile_height*/ elemsPerInstr[0],
704- /*v_blocks*/ 1,
705- /*stored_val*/ bitcast(storeVal, store2DGenXType));
706-
707- if (failed(newOp.verify())) {
708- // Explicitly invoke verifier because `triton_gen` ops are
709- // immediately lowered further to a builtin call.
710- return failure();
652+ Value ret = bitcast (
653+ load2dOp, LLVM::getFixedVectorType (eltTy, elemsPerLane));
654+ // llvm::errs() << "ret: " << ret << "\n";
655+ // each load should give us one column
656+ for (size_t i = 0 ; i < elemsPerLane; i++) {
657+ Value loaded = extract_element (eltTy, ret, i32_val (i));
658+ unpackedLoadedVals.push_back (loaded);
659+ }
711660 }
712- #endif
713661 }
714662 }
715663 }
716- }
717- #if 0
718- return failure();
719- #else
720- TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
721- Type llvmResultStructTy = typeConverter->convertType (op.getType ());
722- Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
723- rewriter, llvmResultStructTy);
724- rewriter.replaceOp (op, {resultStruct});
725-
726- return success ();
727- #endif
728- #else
729664
730- ValueTable loadVals;
665+ TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
666+ Type llvmResultStructTy = typeConverter->convertType (op.getType ());
667+ Value resultStruct = packLLElements (
668+ loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
669+ rewriter.replaceOp (op, {resultStruct});
731670
732- unsigned numRepOuter = numReps[2];
733- // TODO: calculate this instead of guessing
734- numOperandsPer2DLoadM = 1;
735- numOperandsPer2DloadN = 1;
736- for (int outer = 0; outer < numRepOuter; ++outer) {
737- for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
738- for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
739- llvm::errs() << "load: " << outer << ", " << rep << ", " << k << "\n";
740- }
741- }
742- }
743- return failure();
744- #endif
671+ return success ();
745672 }
746- #endif
747673
748674 bool isOperandA = (opIdx == 0 );
749675 SmallVector<unsigned > dpasInstShape = isOperandA
@@ -788,7 +714,6 @@ struct LoadOpConversion
788714
789715 Type packedDPASOperandType = LLVM::getFixedVectorType (
790716 loadResultElemType, packedElemsPerLanePerDPASInst);
791- llvm::errs () << " packed DPAS operand type: " << packedDPASOperandType << " \n " ;
792717
793718 // Outer dim: Dim M or N. Inner dim: Dim K.
794719 // Round the warp id fit into the tensor shape.
@@ -908,11 +833,9 @@ struct LoadOpConversion
908833 Value elemSizeInBytes = i32_val (originalElemBits / 8 );
909834
910835 ValueTable loadVals;
911- llvm::errs () << " Generating 2D block load for op: " << opIdx << " \n " ;
912836 for (int outer = 0 ; outer < numRepOuter; ++outer) {
913837 for (int rep = 0 ; rep < numLoadPerOutRepCluster; ++rep) {
914838 for (int k = 0 ; k < numRepInner; k += numOperandsInnerDimPerLoad) {
915- llvm::errs () << " outer, rep, k = " << outer << " , " << rep << " , " << k << " \n " ;
916839 Value offsetX, offsetY;
917840 if (opIdx == 0 ) {
918841 // A
0 commit comments