Skip to content

Commit 4a4dd67

Browse files
committed
implemented swa mask for HMMA
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent a36b48b commit 4a4dd67

File tree

1 file changed

+64
-12
lines changed

1 file changed

+64
-12
lines changed

cpp/kernels/xqa/mha.cu

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -461,20 +461,53 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
461461
#define MMAS_N_PER_MASK 2
462462

463463
__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
464-
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
464+
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
465+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
466+
,
467+
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
468+
#endif
469+
)
465470
{
466471
uint32_t const idxInQuad = laneId() % 4;
467472
uint32_t const idxQuad = laneId() / 4;
468473
// Packed mask is aligned with 32 bits (2 uint16_t).
469474
uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
470475
uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
476+
constexpr uint64_t fullMask = ~uint64_t{0};
477+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
478+
Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
479+
Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
480+
bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
481+
assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange));
482+
int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
483+
uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
484+
bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
485+
#else
486+
constexpr bool ctaNeedBegMask = false;
487+
bool const ctaNeedSpecDecMask = true;
488+
int32_t const tok0NbMaskOut = -2147483648;
489+
#endif
490+
bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
491+
492+
if (!needMask)
493+
{
494+
return;
495+
}
471496
#pragma unroll
472497
for (uint32_t m = 0; m < acc.rows; m++)
473498
{
474499
#pragma unroll
475500
for (uint32_t i = 0; i < InstAcc::rows; i++)
476501
{
477-
uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
502+
uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8) / headGrpSize;
503+
uint32_t const tokenRow = min(idxQTokInCta, actualQSeqLen - 1);
504+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
505+
int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta);
506+
uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask);
507+
#else
508+
uint64_t const begMask = fullMask;
509+
#endif
510+
478511
#pragma unroll
479512
for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
480513
{
@@ -486,12 +519,15 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
486519
uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
487520
? 0u
488521
: min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
489-
uint32_t packedMask = 0u;
490522
uint32_t const maskPosStart = (maskPos0 / 16) * 16;
491-
reinterpret_cast<uint16_t*>(&packedMask)[0]
492-
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
493-
reinterpret_cast<uint16_t*>(&packedMask)[1]
494-
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
523+
uint32_t packedMask = ~uint32_t{0};
524+
if (ctaNeedSpecDecMask)
525+
{
526+
reinterpret_cast<uint16_t*>(&packedMask)[0]
527+
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
528+
reinterpret_cast<uint16_t*>(&packedMask)[1]
529+
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
530+
}
495531
#pragma unroll
496532
for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
497533
{
@@ -505,7 +541,11 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
505541
bool const maskFlag = col + actualQSeqLen < nbValidCols
506542
? true
507543
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
508-
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
544+
545+
bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true;
546+
547+
acc(m, n)(i, j)
548+
= maskFlag && begMaskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
509549
}
510550
}
511551
}
@@ -1606,8 +1646,14 @@ CUBIN_EXPORT __global__
16061646
#endif
16071647

16081648
uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
1609-
#if SLIDING_WINDOW
1649+
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1650+
uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
1651+
int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
1652+
uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
1653+
1654+
#elif SLIDING_WINDOW
16101655
bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
1656+
assert(!SPEC_DEC || !rtIsReallySliding);
16111657
uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
16121658
#else
16131659
constexpr bool rtIsReallySliding = false;
@@ -1621,7 +1667,9 @@ CUBIN_EXPORT __global__
16211667
#endif
16221668

16231669
uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
1624-
#if SPEC_DEC
1670+
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1671+
uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
1672+
#elif SPEC_DEC
16251673
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
16261674
#endif
16271675

@@ -1907,8 +1955,12 @@ CUBIN_EXPORT __global__
19071955
if (seqIter >= nbSeqItersWithoutMask)
19081956
{
19091957
uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
1910-
applyMaskFromInput(
1911-
warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
1958+
applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
1959+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
1960+
,
1961+
tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
1962+
#endif
1963+
);
19121964
}
19131965
#else
19141966
bool const isFirstIter = (seqIter == nbSkipLeadingTiles);

0 commit comments

Comments
 (0)