@@ -656,43 +656,6 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq
                 numCopiedTokens++;
                 numTokens++;
             }
-            else if (useAttentionPrior && llmReq->hasAttentionPriorIdx() && maxEncoderInputLengthInBatch > 5)
-            {
-                auto focusTimeIdx = llmReq->getAttentionPriorIdx();
-                // TODO: remove this debug print
-                TLLM_LOG_WARNING("At decoding step %d, focusing on encoder time %d",
-                    (int) llmReq->getDecodingIter(), (int) focusTimeIdx
-                );
-                // Create a mask where only the token at focusTimeIdx is attended to
-                // Use std::unique_ptr with bool[] instead of vector<bool> to avoid bit packing
-                auto customMask = std::make_unique<bool[]>(maxEncoderInputLengthInBatch);
-                // Initialize all elements to false
-                std::fill_n(customMask.get(), maxEncoderInputLengthInBatch, false);
-                if (focusTimeIdx < maxEncoderInputLengthInBatch) {
-                    // TODO: implement flooring, window, etc.
-                    if (!llmReq->isAttentionPriorStuck()) {
-                        customMask[focusTimeIdx] = true;
-                    }
-                    customMask[std::min(focusTimeIdx + 1, maxEncoderInputLengthInBatch - 1)] = true;
-                    customMask[std::min(focusTimeIdx + 2, maxEncoderInputLengthInBatch - 1)] = true;
-                } else {
-                    std::fill_n(customMask.get(), maxEncoderInputLengthInBatch, true);
-                    TLLM_LOG_WARNING("Time index %d exceeds encoder length %d, no position will be attended to",
-                        focusTimeIdx, maxEncoderInputLengthInBatch);
-                }
-
-                // Copy the custom mask to pinned memory
-                std::memcpy(pinnedMemPtr, customMask.get(), maxEncoderInputLengthInBatch);
-                pinnedMemPtr += maxEncoderInputLengthInBatch;
-
-                // Set up copy offsets the same way as before
-                batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast<SizeType64>(pinnedMemPtr - primarySrcPtr);
-                batchedCopyDstOffsets.begin()[numCopiedTokens] = numTokens * static_cast<SizeType64>(maxEncoderInputLengthInBatch);
-                batchedCopySizes.begin()[numCopiedTokens] = maxEncoderInputLengthInBatch;
-
-                numCopiedTokens++;
-                numTokens++;
-            }
             else
             {
                 numTokens++;
@@ -738,6 +701,38 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq
         sync_check_cuda_error(stream.get());
     }

+    // Directly adjust crossAttentionMaskDevice according to the attention prior
+    if (useAttentionPrior) {
+        SizeType32 qSeqOffset = 0;
+        for (auto const& llmReq : contextRequests) {
+            qSeqOffset += llmReq->getContextChunkSize();
+        }
+        for (auto const& llmReq : genRequests) {
+            if (llmReq->hasAttentionPriorIdx() && llmReq->getEncoderOutputLen() > 5) {
+                // need to adjust the mask:
+                // 1. don't attend to any of the encoder tokens
+                auto reqMaskSlice = ITensor::slice(crossAttentionMaskDevice, {qSeqOffset, 0}, maxEncoderInputLengthInBatch);
+                manager.setMem(*reqMaskSlice, 0);
+                // 2. except those around the focus time idx
+                auto focus = llmReq->getAttentionPriorIdx();
+                // TODO: this is a bigger window than the NeMo one [focus, focus + 3),
+                // but without attention_prior_floor a bigger window is needed
+                // to get any sort of reasonable synthesis
+                auto from = std::max(0, focus - 2);
+                auto to = std::min(focus + 6, llmReq->getEncoderOutputLen());
+                TLLM_LOG_WARNING("For gen request %d, at decoding step %d, focusing on encoder time %d. window: [%d, %d)",
+                    llmReq->mRequestId, (int) llmReq->getDecodingIter(), (int) focus, (int) from, (int) to
+                );
+                auto len = to - from;
+                auto reqFocusMaskSlice = ITensor::slice(crossAttentionMaskDevice, {qSeqOffset, from}, len);
+                manager.setMem(*reqFocusMaskSlice, 1);
+            }
+            // for gen requests, the mask has shape [1, maxEncoderInputLengthInBatch]
+            qSeqOffset += 1;
+        }
+        sync_check_cuda_error(stream.get());
+    }
+
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

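As a minimal standalone sketch (not part of this commit, hypothetical values only), the windowing rule applied to each gen request's mask row can be illustrated on a host-side buffer; the commit performs the equivalent zero/one fills directly on crossAttentionMaskDevice via ITensor::slice and manager.setMem:

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    int const encLen = 20; // hypothetical encoder output length (> 5, so the prior path is taken)
    int const focus = 7;   // hypothetical attention prior index for the current decoding step

    // Step 1: attend to no encoder position (the setMem(..., 0) in the commit).
    std::vector<int> maskRow(encLen, 0);

    // Step 2: re-enable a [focus - 2, focus + 6) window clamped to [0, encLen),
    // mirroring the from/to computation in the commit (the setMem(..., 1)).
    int const from = std::max(0, focus - 2);
    int const to = std::min(focus + 6, encLen);
    std::fill(maskRow.begin() + from, maskRow.begin() + to, 1);

    for (int v : maskRow)
    {
        std::printf("%d", v);
    }
    std::printf("\n"); // prints 00000111111110000000 for focus = 7
    return 0;
}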