@@ -461,20 +461,53 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
461461#define MMAS_N_PER_MASK 2
462462
463463__device__ inline void applyMaskFromInput (Warp const & warp, WarpAcc& acc, MaskType const * mask, uint32_t rowOffset,
464- uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
464+ uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
465+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
466+ ,
467+ int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
468+ #endif
469+ )
465470{
466471 uint32_t const idxInQuad = laneId () % 4 ;
467472 uint32_t const idxQuad = laneId () / 4 ;
468473 // Packed mask is aligned with 32 bits (2 uint16_t).
469474 uint32_t const nbPackedMasksPerRow = divUp (qSeqLen, 32u ) * 2u ;
470475 uint16_t const * uint16Mask = reinterpret_cast <uint16_t const *>(mask);
476+ constexpr uint64_t fullMask = ~uint64_t {0 };
477+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
478+ Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x };
479+ Range const maxMaskOutRange = {0 , mha::max (0 , tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1 )};
480+ bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end ;
481+ assert (ctaNeedBegMask == overlap (tileRange, maxMaskOutRange));
482+ int32_t const tok0NbMaskOut = int32_t (tok0WinBeg) - int32_t (warpTileTokenBeg);
483+ uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
484+ bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
485+ #else
486+ constexpr bool ctaNeedBegMask = false ;
487+ bool const ctaNeedSpecDecMask = true ;
488+ int32_t const tok0NbMaskOut = -2147483648 ;
489+ #endif
490+ bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
491+
492+ if (!needMask)
493+ {
494+ return ;
495+ }
471496#pragma unroll
472497 for (uint32_t m = 0 ; m < acc.rows ; m++)
473498 {
474499#pragma unroll
475500 for (uint32_t i = 0 ; i < InstAcc::rows; i++)
476501 {
477- uint32_t const tokenRow = min ((rowOffset + instM * m + idxQuad + i * 8 ) / headGrpSize, actualQSeqLen - 1 );
502+ uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8 ) / headGrpSize;
503+ uint32_t const tokenRow = min (idxQTokInCta, actualQSeqLen - 1 );
504+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
505+ int32_t const begNbMaskOut = tok0NbMaskOut + int32_t (idxQTokInCta);
506+ uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask);
507+ #else
508+ uint64_t const begMask = fullMask;
509+ #endif
510+
478511#pragma unroll
479512 for (uint32_t mask_n = 0 ; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
480513 {
@@ -486,12 +519,15 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
486519 uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
487520 ? 0u
488521 : min (lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1 );
489- uint32_t packedMask = 0u ;
490522 uint32_t const maskPosStart = (maskPos0 / 16 ) * 16 ;
491- reinterpret_cast <uint16_t *>(&packedMask)[0 ]
492- = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16 )];
493- reinterpret_cast <uint16_t *>(&packedMask)[1 ]
494- = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16 )];
523+ uint32_t packedMask = ~uint32_t {0 };
524+ if (ctaNeedSpecDecMask)
525+ {
526+ reinterpret_cast <uint16_t *>(&packedMask)[0 ]
527+ = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16 )];
528+ reinterpret_cast <uint16_t *>(&packedMask)[1 ]
529+ = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16 )];
530+ }
495531#pragma unroll
496532 for (uint32_t nj = 0 ; nj < MMAS_N_PER_MASK; nj++)
497533 {
@@ -505,7 +541,11 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
505541 bool const maskFlag = col + actualQSeqLen < nbValidCols
506542 ? true
507543 : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
508- acc (m, n)(i, j) = maskFlag && col < nbValidCols ? acc (m, n)(i, j) : safeInitRowMax;
544+
545+ bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true ;
546+
547+ acc (m, n)(i, j)
548+ = maskFlag && begMaskFlag && col < nbValidCols ? acc (m, n)(i, j) : safeInitRowMax;
509549 }
510550 }
511551 }
@@ -1606,8 +1646,14 @@ CUBIN_EXPORT __global__
16061646#endif
16071647
16081648 uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
1609- #if SLIDING_WINDOW
1649+ #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1650+ uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
1651+ int32_t const tok0WinBeg = int32_t (tok0SeqLen) - int32_t (slidingWinSize);
1652+ uint32_t const nbTotalSkipTokens = mha::max (0 , tok0WinBeg);
1653+
1654+ #elif SLIDING_WINDOW
16101655 bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
1656+ assert (!SPEC_DEC || !rtIsReallySliding);
16111657 uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0 ;
16121658#else
16131659 constexpr bool rtIsReallySliding = false ;
@@ -1621,7 +1667,9 @@ CUBIN_EXPORT __global__
16211667#endif
16221668
16231669 uint32_t const nbSeqIters = useKVCache ? divUp (cacheSeqLen, ctaTile.x ) : 0 ;
1624- #if SPEC_DEC
1670+ #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1671+ uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
1672+ #elif SPEC_DEC
16251673 uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
16261674#endif
16271675
@@ -1907,8 +1955,12 @@ CUBIN_EXPORT __global__
19071955 if (seqIter >= nbSeqItersWithoutMask)
19081956 {
19091957 uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U );
1910- applyMaskFromInput (
1911- warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
1958+ applyMaskFromInput (warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
1959+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
1960+ ,
1961+ tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
1962+ #endif
1963+ );
19121964 }
19131965#else
19141966 bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
0 commit comments