Commit 01a07f3

transformerBuffers.cpp: fix setting cross attention mask based on attention prior
Signed-off-by: Viacheslav Klimkov <[email protected]>
1 parent 9827fc6 commit 01a07f3

File tree

1 file changed (+32, -37 lines)

cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp

Lines changed: 32 additions & 37 deletions
@@ -656,43 +656,6 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq
     numCopiedTokens++;
     numTokens++;
 }
-else if (useAttentionPrior && llmReq->hasAttentionPriorIdx() && maxEncoderInputLengthInBatch > 5)
-{
-    auto focusTimeIdx = llmReq->getAttentionPriorIdx();
-    // TODO: remove this debug print
-    TLLM_LOG_WARNING("At decoding step %d, focusing on encoder time %d",
-        (int)llmReq->getDecodingIter(), (int)focusTimeIdx
-    );
-    // Create a mask where only the token at focusTimeIdx is attended to
-    // Use std::unique_ptr with bool[] instead of vector<bool> to avoid bit packing
-    auto customMask = std::make_unique<bool[]>(maxEncoderInputLengthInBatch);
-    // Initialize all elements to false
-    std::fill_n(customMask.get(), maxEncoderInputLengthInBatch, false);
-    if (focusTimeIdx < maxEncoderInputLengthInBatch) {
-        // TODO: implement flooring, window, etc.
-        if (!llmReq->isAttentionPriorStuck()) {
-            customMask[focusTimeIdx] = true;
-        }
-        customMask[std::min(focusTimeIdx + 1, maxEncoderInputLengthInBatch - 1)] = true;
-        customMask[std::min(focusTimeIdx + 2, maxEncoderInputLengthInBatch - 1)] = true;
-    } else {
-        std::fill_n(customMask.get(), maxEncoderInputLengthInBatch, true);
-        TLLM_LOG_WARNING("Time index %d exceeds encoder length %d, no position will be attended to",
-            focusTimeIdx, maxEncoderInputLengthInBatch);
-    }
-
-    // Copy the custom mask to pinned memory
-    std::memcpy(pinnedMemPtr, customMask.get(), maxEncoderInputLengthInBatch);
-    pinnedMemPtr += maxEncoderInputLengthInBatch;
-
-    // Set up copy offsets the same way as before
-    batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast<SizeType64>(pinnedMemPtr - primarySrcPtr);
-    batchedCopyDstOffsets.begin()[numCopiedTokens] = numTokens * static_cast<SizeType64>(maxEncoderInputLengthInBatch);
-    batchedCopySizes.begin()[numCopiedTokens] = maxEncoderInputLengthInBatch;
-
-    numCopiedTokens++;
-    numTokens++;
-}
 else
 {
     numTokens++;
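
Editorial note: the block removed above built the attention-prior mask on the host and copied it into pinned memory with std::memcpy, which is why it used std::unique_ptr<bool[]> rather than std::vector<bool>; the vector<bool> specialization is bit-packed and exposes no byte-per-element buffer. A minimal standalone sketch of that distinction follows (illustrative only, not part of the commit; pinnedDst is a stand-in for the pinned-memory destination):

// Sketch: why a bool[] buffer can be memcpy'd while std::vector<bool> cannot.
#include <algorithm>
#include <cstring>
#include <memory>

int main()
{
    constexpr std::size_t len = 16;
    unsigned char pinnedDst[len] = {}; // stand-in for the pinned-memory destination

    // bool[] stores one addressable byte per element, so memcpy works as in the removed code.
    auto byteMask = std::make_unique<bool[]>(len);
    std::fill_n(byteMask.get(), len, false);
    byteMask[3] = true;
    std::memcpy(pinnedDst, byteMask.get(), len);

    // std::vector<bool> packs elements into bits and provides no data() pointer to bool,
    // so the same memcpy pattern is not available with it.
    return 0;
}
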
@@ -738,6 +701,38 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq
     sync_check_cuda_error(stream.get());
 }
 
+    // directly adjust crossAttentionMaskDevice according to the attention prior
+    if (useAttentionPrior) {
+        SizeType32 qSeqOffset = 0;
+        for (auto const& llmReq : contextRequests) {
+            qSeqOffset += llmReq->getContextChunkSize();
+        }
+        for (auto const& llmReq : genRequests) {
+            if (llmReq->hasAttentionPriorIdx() && llmReq->getEncoderOutputLen() > 5) {
+                // need to adjust the mask
+                // 1. don't attend to any of the encoder tokens
+                auto reqMaskSlice = ITensor::slice(crossAttentionMaskDevice, {qSeqOffset, 0}, maxEncoderInputLengthInBatch);
+                manager.setMem(*reqMaskSlice, 0);
+                // 2. except those around the focus time idx
+                auto focus = llmReq->getAttentionPriorIdx();
+                // TODO: this is a bigger window than the NeMo one [focus, focus + 3),
+                // but without attention_prior_floor a bigger window is needed
+                // to get any sort of reasonable synthesis
+                auto from = std::max(0, focus - 2);
+                auto to = std::min(focus + 6, llmReq->getEncoderOutputLen());
+                TLLM_LOG_WARNING("For gen request %d, at decoding step %d, focusing on encoder time %d. window: [%d, %d)",
+                    llmReq->mRequestId, (int)llmReq->getDecodingIter(), (int)focus, (int)from, (int)to
+                );
+                auto len = to - from;
+                auto reqFocusMaskSlice = ITensor::slice(crossAttentionMaskDevice, {qSeqOffset, from}, len);
+                manager.setMem(*reqFocusMaskSlice, 1);
+            }
+            // for gen requests, the mask has shape [1, maxEncoderInputLengthInBatch]
+            qSeqOffset += 1;
+        }
+        sync_check_cuda_error(stream.get());
+    }
+
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
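
Editorial note: per generation request, the added block reduces to clearing that request's cross-attention mask row and re-enabling only the window [max(0, focus - 2), min(focus + 6, encoderOutputLen)) around the attention-prior index. A minimal host-side sketch of that window logic follows (illustrative only; buildAttentionPriorMaskRow is a hypothetical name and a plain vector stands in for crossAttentionMaskDevice):

// Sketch of the attention-prior windowing applied to one gen request's mask row.
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<int> buildAttentionPriorMaskRow(int focus, int encoderOutputLen)
{
    // 1. don't attend to any encoder token by default
    std::vector<int> maskRow(encoderOutputLen, 0);
    // 2. except those around the focus time index: window [focus - 2, focus + 6),
    //    clamped to the valid encoder range (mirrors the from/to computation above)
    int from = std::max(0, focus - 2);
    int to = std::min(focus + 6, encoderOutputLen);
    std::fill(maskRow.begin() + from, maskRow.begin() + to, 1);
    return maskRow;
}

int main()
{
    // Example: encoder length 12, focus on time step 4 -> positions 2..9 attended
    auto row = buildAttentionPriorMaskRow(/*focus=*/4, /*encoderOutputLen=*/12);
    for (int v : row)
    {
        std::printf("%d ", v);
    }
    std::printf("\n"); // prints: 0 0 1 1 1 1 1 1 1 1 0 0
    return 0;
}
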
