@@ -466,20 +466,53 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
466466#define MMAS_N_PER_MASK 2
467467
468468__device__ inline void applyMaskFromInput (Warp const & warp, WarpAcc& acc, MaskType const * mask, uint32_t rowOffset,
469- uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
469+ uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
470+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
471+ ,
472+ int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
473+ #endif
474+ )
470475{
471476 uint32_t const idxInQuad = laneId () % 4 ;
472477 uint32_t const idxQuad = laneId () / 4 ;
473478 // Packed mask is aligned with 32 bits (2 uint16_t).
474479 uint32_t const nbPackedMasksPerRow = divUp (qSeqLen, 32u ) * 2u ;
475480 uint16_t const * uint16Mask = reinterpret_cast <uint16_t const *>(mask);
481+ constexpr uint64_t fullMask = ~uint64_t {0 };
482+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
483+ Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x };
484+ Range const maxMaskOutRange = {0 , mha::max (0 , tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1 )};
485+ bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end ;
486+ assert (ctaNeedBegMask == overlap (tileRange, maxMaskOutRange));
487+ int32_t const tok0NbMaskOut = int32_t (tok0WinBeg) - int32_t (warpTileTokenBeg);
488+ uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
489+ bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
490+ #else
491+ constexpr bool ctaNeedBegMask = false ;
492+ bool const ctaNeedSpecDecMask = true ;
493+ int32_t const tok0NbMaskOut = -2147483648 ;
494+ #endif
495+ bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
496+
497+ if (!needMask)
498+ {
499+ return ;
500+ }
476501#pragma unroll
477502 for (uint32_t m = 0 ; m < acc.rows ; m++)
478503 {
479504#pragma unroll
480505 for (uint32_t i = 0 ; i < InstAcc::rows; i++)
481506 {
482- uint32_t const tokenRow = min ((rowOffset + instM * m + idxQuad + i * 8 ) / headGrpSize, actualQSeqLen - 1 );
507+ uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8 ) / headGrpSize;
508+ uint32_t const tokenRow = min (idxQTokInCta, actualQSeqLen - 1 );
509+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
510+ int32_t const begNbMaskOut = tok0NbMaskOut + int32_t (idxQTokInCta);
511+ uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask);
512+ #else
513+ uint64_t const begMask = fullMask;
514+ #endif
515+
483516#pragma unroll
484517 for (uint32_t mask_n = 0 ; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
485518 {
@@ -491,12 +524,15 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
491524 uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
492525 ? 0u
493526 : min (lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1 );
494- uint32_t packedMask = 0u ;
495527 uint32_t const maskPosStart = (maskPos0 / 16 ) * 16 ;
496- reinterpret_cast <uint16_t *>(&packedMask)[0 ]
497- = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16 )];
498- reinterpret_cast <uint16_t *>(&packedMask)[1 ]
499- = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16 )];
528+ uint32_t packedMask = ~uint32_t {0 };
529+ if (ctaNeedSpecDecMask)
530+ {
531+ reinterpret_cast <uint16_t *>(&packedMask)[0 ]
532+ = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16 )];
533+ reinterpret_cast <uint16_t *>(&packedMask)[1 ]
534+ = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16 )];
535+ }
500536#pragma unroll
501537 for (uint32_t nj = 0 ; nj < MMAS_N_PER_MASK; nj++)
502538 {
@@ -510,7 +546,11 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
510546 bool const maskFlag = col + actualQSeqLen < nbValidCols
511547 ? true
512548 : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
513- acc (m, n)(i, j) = maskFlag && col < nbValidCols ? acc (m, n)(i, j) : safeInitRowMax;
549+
550+ bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true ;
551+
552+ acc (m, n)(i, j)
553+ = maskFlag && begMaskFlag && col < nbValidCols ? acc (m, n)(i, j) : safeInitRowMax;
514554 }
515555 }
516556 }
@@ -1611,8 +1651,14 @@ CUBIN_EXPORT __global__
16111651#endif
16121652
16131653 uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
1614- #if SLIDING_WINDOW
1654+ #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1655+ uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
1656+ int32_t const tok0WinBeg = int32_t (tok0SeqLen) - int32_t (slidingWinSize);
1657+ uint32_t const nbTotalSkipTokens = mha::max (0 , tok0WinBeg);
1658+
1659+ #elif SLIDING_WINDOW
16151660 bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
1661+ assert (!SPEC_DEC || !rtIsReallySliding);
16161662 uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0 ;
16171663#else
16181664 constexpr bool rtIsReallySliding = false ;
@@ -1626,7 +1672,9 @@ CUBIN_EXPORT __global__
16261672#endif
16271673
16281674 uint32_t const nbSeqIters = useKVCache ? divUp (cacheSeqLen, ctaTile.x ) : 0 ;
1629- #if SPEC_DEC
1675+ #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1676+ uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
1677+ #elif SPEC_DEC
16301678 uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x ;
16311679#endif
16321680
@@ -1912,8 +1960,12 @@ CUBIN_EXPORT __global__
19121960 if (seqIter >= nbSeqItersWithoutMask)
19131961 {
19141962 uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U );
1915- applyMaskFromInput (
1916- warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
1963+ applyMaskFromInput (warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
1964+ #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
1965+ ,
1966+ tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
1967+ #endif
1968+ );
19171969 }
19181970#else
19191971 bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
0 commit comments