@@ -529,7 +529,6 @@ struct LoadOpConversion
529529 " Only row_major or column_major is allowed" );
530530 const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
531531
532- // auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
533532 auto getOpIdx = [&]() -> unsigned {
534533 if (hasDpasLayout) {
535534 return 2 ;
@@ -541,6 +540,8 @@ struct LoadOpConversion
541540
542541 const unsigned opIdx = getOpIdx ();
543542 Type eltTy = tensorType.getElementType ();
543+ unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
544+
544545 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
545546 unsigned numElems = getTotalElemsPerThread (resultType);
546547 SmallVector<int64_t > numReps =
@@ -556,30 +557,30 @@ struct LoadOpConversion
556557 SmallVector<Value> multiDimWarpId =
557558 delinearize (rewriter, loc, warpId, warpsPerCTA, dpasOrder);
558559
559- // TODO: de-duplicate with above and below code
560560 if (hasDpasLayout) {
561- // llvm::errs() << "rewriting tensor pointer load for dpas user!\n";
561+ if (isTransposeRequired) {
562+ // TODO: this would likely require a shuffle to match the expected
563+ // ordering coming out of the DPAS layout and requires more
564+ // investigation
565+ return failure ();
566+ }
567+
562568 MLIRContext *ctx = rewriter.getContext ();
563569
564- Type eltTy = tensorType.getElementType ();
565- // llvm::errs() << "Element type: " << eltTy << "\n";
566- unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
567570 Value elemSizeInBytes = i32_val (elemSizeInBits / 8 );
568571
569572 SmallVector<unsigned > elemsPerInstr = dpasLayout.getDPASInstShapeC ();
570573 int64_t elemsPerLane = product<unsigned >(elemsPerInstr) / threadsPerWarp;
571574 Type load2DGenXType =
572575 LLVM::getFixedVectorType (IntegerType::get (ctx, elemSizeInBits),
573576 elemsPerLane); // make it opaque type.
574- // llvm::errs() << "load 2d gen x type: " << load2DGenXType << "\n";
575577
576578 auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
577579 offsetBaseY] =
578580 getValuesFromBlockPointerStruct (adaptor.getPtr (), rewriter);
579581 baseWidth = trunc (i32_ty, baseWidth);
580582 baseHeight = trunc (i32_ty, baseHeight);
581583
582- // always row order coming out of DPAS
583584 auto pitch = trunc (i32_ty, rowStride);
584585
585586 SmallVector<unsigned > repClusterShape = dpasLayout.getShapeC ();
@@ -626,7 +627,6 @@ struct LoadOpConversion
626627 add (warpId1Offset,
627628 i32_val (n * replicaStride[1 ] + repN * elemsPerInstr[1 ]));
628629
629- assert (!isTransposeRequired);
630630 auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
631631 loc, load2DGenXType,
632632 /* ptr*/ base,
@@ -636,13 +636,11 @@ struct LoadOpConversion
636636 /* x*/ trunc (i32_ty, offsetX),
637637 /* y*/ trunc (i32_ty, offsetY),
638638 /* elem_size_in_bits*/ elemSizeInBits,
639- /* tile_width*/ elemsPerInstr[1 ],
640- /* tile_height*/ elemsPerInstr[0 ],
639+ /* tile_width*/ elemsPerInstr[1 ],
640+ /* tile_height*/ elemsPerInstr[0 ],
641641 /* v_blocks*/ 1 ,
642- /* transpose*/ isTransposeRequired,
643- /* vnni_transform*/ false /*
644- (usePackedType && !isOperandA && !isTransposeRequired &&
645- eltTy.getIntOrFloatBitWidth() != 32)*/ );
642+ /* transpose*/ false ,
643+ /* vnni_transform*/ false );
646644 if (failed (load2dOp.verify ())) {
647645 // Explicitly invoke verifier because `triton_gen` ops are
648646 // immediately lowered further to a builtin call.
@@ -651,8 +649,7 @@ struct LoadOpConversion
651649
652650 Value ret = bitcast (
653651 load2dOp, LLVM::getFixedVectorType (eltTy, elemsPerLane));
654- // llvm::errs() << "ret: " << ret << "\n";
655- // each load should give us one column
652+
656653 for (size_t i = 0 ; i < elemsPerLane; i++) {
657654 Value loaded = extract_element (eltTy, ret, i32_val (i));
658655 unpackedLoadedVals.push_back (loaded);
@@ -701,11 +698,11 @@ struct LoadOpConversion
701698 // input operands to DPAS.
702699 // TODO: add support for int4 and int2.
703700 unsigned opsPerChannel = dpasLayout.getOpsPerChannel ();
704- unsigned elemBits = eltTy. getIntOrFloatBitWidth ();
705- if (( opsPerChannel == 4 && elemBits == 8 ) ||
706- (opsPerChannel == 2 && elemBits == 16 ) ||
707- (opsPerChannel == 1 && elemBits == 32 )) {
708- loadResultElemType = (isOperandA && elemBits != 32 ) ? i16_ty : i32_ty;
701+ if ((opsPerChannel == 4 && elemSizeInBits == 8 ) ||
702+ ( opsPerChannel == 2 && elemSizeInBits == 16 ) ||
703+ (opsPerChannel == 1 && elemSizeInBits == 32 )) {
704+ loadResultElemType =
705+ (isOperandA && elemSizeInBits != 32 ) ? i16_ty : i32_ty;
709706 packedElemsPerLanePerDPASInst =
710707 isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1 )
711708 : elemsPerLanePerDPASInst / opsPerChannel;
@@ -779,7 +776,7 @@ struct LoadOpConversion
779776
780777 // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
781778 // by enlarging the vBlocks.
782- unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8 ;
779+ unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8 ;
783780 numOperandsPer2DloadN =
784781 std::min (numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
785782 vBlocks = numOperandsPer2DloadN;
@@ -823,12 +820,12 @@ struct LoadOpConversion
823820 baseWidth = trunc (i32_ty, baseWidth);
824821 baseHeight = trunc (i32_ty, baseHeight);
825822
826- unsigned originalElemBits = elemBits ;
823+ const unsigned originalElemBits = elemSizeInBits ;
827824 if (isTransposeRequired) {
828825 // adjust the block io parameter to align HW's limitations on
829826 // transposing load.
830827 tileWidth = tileWidth / (32 / originalElemBits);
831- elemBits = 32 ;
828+ elemSizeInBits = 32 ;
832829 }
833830 Value elemSizeInBytes = i32_val (originalElemBits / 8 );
834831
@@ -872,14 +869,14 @@ struct LoadOpConversion
872869 /* base_pitch*/ mul (pitch, elemSizeInBytes),
873870 /* x*/ trunc (i32_ty, offsetX),
874871 /* y*/ trunc (i32_ty, offsetY),
875- /* elem_size_in_bits*/ elemBits ,
872+ /* elem_size_in_bits*/ elemSizeInBits ,
876873 /* tile_width*/ tileWidth,
877874 /* tile_height*/ tileHeight,
878875 /* v_blocks*/ vBlocks,
879876 /* transpose*/ isTransposeRequired,
880877 /* vnni_transform*/
881878 (usePackedType && !isOperandA && !isTransposeRequired &&
882- eltTy. getIntOrFloatBitWidth () != 32 ));
879+ originalElemBits != 32 ));
883880 if (failed (load2dOp.verify ())) {
884881 // Explicitly invoke verifier because `triton_gen` ops are
885882 // immediately lowered further to a builtin call.
0 commit comments