@@ -499,7 +499,9 @@ struct LoadOpConversion
499499 auto tensorType = cast<RankedTensorType>(resultType);
500500
501501 // Only lower loadOp with dpas layout encoding.
502- if (!hasDotDpasEncoding (tensorType))
502+ auto encoding = tensorType.getEncoding ();
503+ const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
504+ if (!hasDpasLayout && !hasDotDpasEncoding (tensorType))
503505 return failure ();
504506
505507 Attribute blockIOAttr =
@@ -514,8 +516,8 @@ struct LoadOpConversion
514516 " Only row_major or column_major is supported" );
515517 const bool memoryRowMajor = (memoryLayoutInfo == " row_major" );
516518
517- DotOperandEncodingAttr dotLayout = getDotEncoding (tensorType).value ();
518- auto dotOrder = dotLayout .getThreadOrder ();
519+ auto dpasLayout = hasDpasLayout ? cast<DpasEncodingAttr>(encoding) : cast<DpasEncodingAttr>( getDotEncoding (tensorType).value (). getParent () );
520+ auto dotOrder = dpasLayout .getThreadOrder ();
519521 size_t rank = dotOrder.size ();
520522 const bool valueRowMajor =
521523 (dotOrder[rank - 2 ] == 1 && dotOrder[rank - 1 ] == 0 );
@@ -524,9 +526,17 @@ struct LoadOpConversion
524526 " Only row_major or column_major is allowed" );
525527 const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
526528
527- auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent ());
529+ // auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
530+ auto getOpIdx = [&]() -> unsigned {
531+ if (hasDpasLayout) {
532+ return 2 ;
533+ } else {
534+ auto dotLayout = getDotEncoding (tensorType).value ();
535+ return dotLayout.getOpIdx ();
536+ }
537+ };
528538
529- const unsigned opIdx = dotLayout. getOpIdx ();
539+ const unsigned opIdx = getOpIdx ();
530540 Type eltTy = tensorType.getElementType ();
531541 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
532542 unsigned numElems = getTotalElemsPerThread (resultType);
@@ -543,6 +553,198 @@ struct LoadOpConversion
543553 SmallVector<Value> multiDimWarpId =
544554 delinearize (rewriter, loc, warpId, warpsPerCTA, dpasOrder);
545555
556+ #if 1
557+ if (hasDpasLayout) {
558+ llvm::errs () << " rewriting tensor pointer load for dpas user!\n " ;
559+ #if 1
560+ MLIRContext *ctx = rewriter.getContext ();
561+
562+ Type eltTy = tensorType.getElementType ();
563+ llvm::errs () << " Element type: " << eltTy << " \n " ;
564+ unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
565+ Value elemSizeInBytes = i32_val (elemSizeInBits / 8 );
566+
567+ SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
568+ int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
569+ Type load2DGenXType =
570+ LLVM::getFixedVectorType (IntegerType::get (ctx, elemSizeInBits),
571+ elemsPerLane); // make it opaque type.
572+ llvm::errs () << " load 2d gen x type: " << load2DGenXType << " \n " ;
573+
574+ auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
575+ offsetBaseY] =
576+ getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
577+ baseWidth = trunc (i32_ty, baseWidth);
578+ baseHeight = trunc (i32_ty, baseHeight);
579+
580+ // always row order coming out of DPAS
581+ auto pitch = trunc (i32_ty, rowStride);
582+
583+ SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
584+ unsigned outerDimWarpNum = std::min<unsigned >(
585+ warpsPerCTA[rank - 2 ],
586+ mlir::ceil<unsigned >(tensorShape[rank - 2 ], repClusterShape[rank - 2 ]));
587+ unsigned innerDimWarpNum = std::min<unsigned >(
588+ warpsPerCTA[rank - 1 ],
589+ mlir::ceil<unsigned >(tensorShape[rank - 1 ], repClusterShape[rank - 1 ]));
590+ Value outerDimWarpId =
591+ urem (multiDimWarpId[rank - 2 ], i32_val (outerDimWarpNum));
592+ Value innerDimWarpId =
593+ urem (multiDimWarpId[rank - 1 ], i32_val (innerDimWarpNum));
594+ int64_t numRepOuter = numReps[1 ];
595+ int64_t numRepInner = numReps[2 ];
596+
597+
598+ std::array<unsigned , 2 > replicaStride = {
599+ outerDimWarpNum * repClusterShape[rank - 2 ],
600+ innerDimWarpNum * repClusterShape[rank - 1 ]};
601+ std::array<unsigned , 2 > warpStride = {repClusterShape[rank - 2 ],
602+ repClusterShape[rank - 1 ]};
603+
604+ Value dimWarpId0 = mul (outerDimWarpId, i32_val (warpStride[0 ]));
605+ Value dimWarpId1 = mul (innerDimWarpId, i32_val (warpStride[1 ]));
606+ Value warpId0Offset = add (dimWarpId0, offsetBaseY);
607+ Value warpId1Offset = add (dimWarpId1, offsetBaseX);
608+
609+ llvm::errs () << " elemsPerInstr: " << elemsPerInstr[0 ] << " , " << elemsPerInstr[1 ] << " \n " ;
610+ ArrayRef<unsigned > repCluster = dpasLayout.getRepCluster ();
611+ unsigned valOffset = 0 ;
612+
613+ SmallVector<Value> unpackedLoadedVals;
614+
615+ for (int m = 0 ; m < numRepOuter; ++m) {
616+ for (int n = 0 ; n < numRepInner; ++n) {
617+ for (int repM = 0 ; repM < repCluster[0 ]; ++repM) {
618+
619+ Value offsetY = add (warpId0Offset, i32_val (m * replicaStride[0 ] +
620+ repM * elemsPerInstr[0 ]));
621+ for (int repN = 0 ; repN < repCluster[1 ]; ++repN) {
622+ llvm::errs () << " m, n, repM, repN: " << m << " , " << n << " , " << repM << " , " << repN << " \n " ;
623+ Value offsetX =
624+ add (warpId1Offset,
625+ i32_val (n * replicaStride[1 ] + repN * elemsPerInstr[1 ]));
626+
627+ assert (!isTransposeRequired);
628+ auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
629+ loc, load2DGenXType,
630+ /* ptr*/ base,
631+ /* base_width*/ mul (baseWidth, elemSizeInBytes),
632+ /* base_height*/ baseHeight,
633+ /* base_pitch*/ mul (pitch, elemSizeInBytes),
634+ /* x*/ trunc (i32_ty, offsetX),
635+ /* y*/ trunc (i32_ty, offsetY),
636+ /* elem_size_in_bits*/ elemSizeInBits,
637+ /* tile_width*/ elemsPerInstr[1 ],
638+ /* tile_height*/ elemsPerInstr[0 ],
639+ /* v_blocks*/ 1 ,
640+ /* transpose*/ isTransposeRequired,
641+ /* vnni_transform*/ false /*
642+ (usePackedType && !isOperandA && !isTransposeRequired &&
643+ eltTy.getIntOrFloatBitWidth() != 32)*/ );
644+ if (failed (load2dOp.verify ())) {
645+ // Explicitly invoke verifier because `triton_gen` ops are
646+ // immediately lowered further to a builtin call.
647+ return failure ();
648+ }
649+
650+
651+ #if 0
652+ llvm::errs() << "elemsPerLane: " << elemsPerLane << "\n";
653+ SmallVector<int32_t> indices(elemsPerLane);
654+ for (int elemIdx = 0; elemIdx < elemsPerLane;
655+ ++elemIdx) {
656+ indices[elemIdx] = elemIdx * n;
657+ }
658+
659+ #if 0
660+ llvm::errs() << "indices: ";
661+ for (size_t i = 0; i < indices.size(); i++) {
662+ llvm::errs() << " " << i;
663+ }
664+ llvm::errs() << "\n";
665+ #endif
666+ DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
667+ Value loadVal = rewriter.create<LLVM::ShuffleVectorOp>(
668+ loc, load2DGenXType, load2dOp, load2dOp, attr);
669+ #endif
670+ Value ret = bitcast (load2dOp, LLVM::getFixedVectorType (eltTy,
671+ elemsPerLane));
672+ llvm::errs () << " ret: " << ret << " \n " ;
673+ // each load should give us one column
674+ for (size_t i = 0 ; i < elemsPerLane; i++) {
675+ Value loaded =
676+ extract_element (eltTy, ret, i32_val (i));
677+ unpackedLoadedVals.push_back (loaded);
678+ }
679+
680+ // loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
681+ // rep * packedRowNum + row,
682+ // k + vblk * packedColNumPerVBlock + col}] =
683+ // bitcast(loadVal, unpackedDPASOperandType);
684+ #if 0
685+ Value storeVal = rewriter.create<LLVM::UndefOp>(
686+ loc, LLVM::getFixedVectorType(typeConverter->convertType(eltTy),
687+ elemsPerLane));
688+ for (size_t i = 0; i < elemsPerLane; ++i) {
689+ storeVal = insert_element(storeVal, vals[valOffset], i32_val(i));
690+ ++valOffset;
691+ }
692+
693+ auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
694+ loc,
695+ /*ptr*/ base,
696+ /*base_width*/ baseWidth,
697+ /*base_height*/ height,
698+ /*base_pitch*/ basePitch,
699+ /*x*/ trunc(i32_ty, offsetX),
700+ /*y*/ trunc(i32_ty, offsetY),
701+ /*elem_size_in_bits*/ elemSizeInBits,
702+ /*tile_width*/ elemsPerInstr[1],
703+ /*tile_height*/ elemsPerInstr[0],
704+ /*v_blocks*/ 1,
705+ /*stored_val*/ bitcast(storeVal, store2DGenXType));
706+
707+ if (failed(newOp.verify())) {
708+ // Explicitly invoke verifier because `triton_gen` ops are
709+ // immediately lowered further to a builtin call.
710+ return failure();
711+ }
712+ #endif
713+ }
714+ }
715+ }
716+ }
717+ #if 0
718+ return failure();
719+ #else
720+ TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
721+ Type llvmResultStructTy = typeConverter->convertType (op.getType ());
722+ Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
723+ rewriter, llvmResultStructTy);
724+ rewriter.replaceOp (op, {resultStruct});
725+
726+ return success ();
727+ #endif
728+ #else
729+
730+ ValueTable loadVals;
731+
732+ unsigned numRepOuter = numReps[2];
733+ // TODO: calculate this instead of guessing
734+ numOperandsPer2DLoadM = 1;
735+ numOperandsPer2DloadN = 1;
736+ for (int outer = 0; outer < numRepOuter; ++outer) {
737+ for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
738+ for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
739+ llvm::errs() << "load: " << outer << ", " << rep << ", " << k << "\n";
740+ }
741+ }
742+ }
743+ return failure();
744+ #endif
745+ }
746+ #endif
747+
546748 bool isOperandA = (opIdx == 0 );
547749 SmallVector<unsigned > dpasInstShape = isOperandA
548750 ? dpasLayout.getDPASInstShapeA ()
@@ -586,6 +788,7 @@ struct LoadOpConversion
586788
587789 Type packedDPASOperandType = LLVM::getFixedVectorType (
588790 loadResultElemType, packedElemsPerLanePerDPASInst);
791+ llvm::errs () << " packed DPAS operand type: " << packedDPASOperandType << " \n " ;
589792
590793 // Outer dim: Dim M or N. Inner dim: Dim K.
591794 // Round the warp id fit into the tensor shape.
@@ -705,9 +908,11 @@ struct LoadOpConversion
705908 Value elemSizeInBytes = i32_val (originalElemBits / 8 );
706909
707910 ValueTable loadVals;
911+ llvm::errs () << " Generating 2D block load for op: " << opIdx << " \n " ;
708912 for (int outer = 0 ; outer < numRepOuter; ++outer) {
709913 for (int rep = 0 ; rep < numLoadPerOutRepCluster; ++rep) {
710914 for (int k = 0 ; k < numRepInner; k += numOperandsInnerDimPerLoad) {
915+ llvm::errs () << " outer, rep, k = " << outer << " , " << rep << " , " << k << " \n " ;
711916 Value offsetX, offsetY;
712917 if (opIdx == 0 ) {
713918 // A
0 commit comments