Skip to content

Commit 21f785c

Browse files
committed
implemented swa mask for HMMA
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent 9b2abb8 commit 21f785c

File tree

1 file changed

+64
-12
lines changed

1 file changed

+64
-12
lines changed

cpp/kernels/xqa/mha.cu

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -466,20 +466,53 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
466466
#define MMAS_N_PER_MASK 2
467467

468468
__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
469-
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
469+
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
470+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
471+
,
472+
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
473+
#endif
474+
)
470475
{
471476
uint32_t const idxInQuad = laneId() % 4;
472477
uint32_t const idxQuad = laneId() / 4;
473478
// Packed mask is aligned with 32 bits (2 uint16_t).
474479
uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
475480
uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
481+
constexpr uint64_t fullMask = ~uint64_t{0};
482+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
483+
Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
484+
Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
485+
bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
486+
assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange));
487+
int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
488+
uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
489+
bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
490+
#else
491+
constexpr bool ctaNeedBegMask = false;
492+
bool const ctaNeedSpecDecMask = true;
493+
int32_t const tok0NbMaskOut = -2147483648;
494+
#endif
495+
bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
496+
497+
if (!needMask)
498+
{
499+
return;
500+
}
476501
#pragma unroll
477502
for (uint32_t m = 0; m < acc.rows; m++)
478503
{
479504
#pragma unroll
480505
for (uint32_t i = 0; i < InstAcc::rows; i++)
481506
{
482-
uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
507+
uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8) / headGrpSize;
508+
uint32_t const tokenRow = min(idxQTokInCta, actualQSeqLen - 1);
509+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
510+
int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta);
511+
uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask);
512+
#else
513+
uint64_t const begMask = fullMask;
514+
#endif
515+
483516
#pragma unroll
484517
for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
485518
{
@@ -491,12 +524,15 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
491524
uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
492525
? 0u
493526
: min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
494-
uint32_t packedMask = 0u;
495527
uint32_t const maskPosStart = (maskPos0 / 16) * 16;
496-
reinterpret_cast<uint16_t*>(&packedMask)[0]
497-
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
498-
reinterpret_cast<uint16_t*>(&packedMask)[1]
499-
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
528+
uint32_t packedMask = ~uint32_t{0};
529+
if (ctaNeedSpecDecMask)
530+
{
531+
reinterpret_cast<uint16_t*>(&packedMask)[0]
532+
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
533+
reinterpret_cast<uint16_t*>(&packedMask)[1]
534+
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
535+
}
500536
#pragma unroll
501537
for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
502538
{
@@ -510,7 +546,11 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
510546
bool const maskFlag = col + actualQSeqLen < nbValidCols
511547
? true
512548
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
513-
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
549+
550+
bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true;
551+
552+
acc(m, n)(i, j)
553+
= maskFlag && begMaskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
514554
}
515555
}
516556
}
@@ -1611,8 +1651,14 @@ CUBIN_EXPORT __global__
16111651
#endif
16121652

16131653
uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
1614-
#if SLIDING_WINDOW
1654+
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1655+
uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
1656+
int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
1657+
uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
1658+
1659+
#elif SLIDING_WINDOW
16151660
bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
1661+
assert(!SPEC_DEC || !rtIsReallySliding);
16161662
uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
16171663
#else
16181664
constexpr bool rtIsReallySliding = false;
@@ -1626,7 +1672,9 @@ CUBIN_EXPORT __global__
16261672
#endif
16271673

16281674
uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
1629-
#if SPEC_DEC
1675+
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
1676+
uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
1677+
#elif SPEC_DEC
16301678
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
16311679
#endif
16321680

@@ -1912,8 +1960,12 @@ CUBIN_EXPORT __global__
19121960
if (seqIter >= nbSeqItersWithoutMask)
19131961
{
19141962
uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
1915-
applyMaskFromInput(
1916-
warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
1963+
applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
1964+
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
1965+
,
1966+
tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
1967+
#endif
1968+
);
19171969
}
19181970
#else
19191971
bool const isFirstIter = (seqIter == nbSkipLeadingTiles);

0 commit comments

Comments
 (0)