@@ -261,7 +261,7 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
 
     inputsIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
     if (useAttentionPrior) {
-        scores = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
+        scores = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
     }
     if (worldConfig.isPipelineParallel())
     {
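
Because the buffer's element type now follows `modelConfig.getDataType()` rather than hard-coded FP32 (an FP16 engine writes half-precision scores), host-side readers can no longer assume `float` and must dispatch on the runtime dtype, which is what the new `processScoresWithType<T>` helper in the next hunk is for. Below is a minimal self-contained sketch of that dispatch pattern; the `DType` enum and `argmaxAs`/`argmaxDispatch` names are hypothetical stand-ins, with `double` playing the role of `half` so the example compiles without CUDA headers:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for nvinfer1::DataType; the real buffer may be kHALF,
// and double stands in for half here so the sketch builds without CUDA headers.
enum class DType { kFLOAT, kDOUBLE };

// Typed scan over type-erased memory, loosely mimicking bufferCast<T>(*tensor).
template <typename T>
std::size_t argmaxAs(void const* data, std::size_t n) {
    auto const* p = static_cast<T const*>(data);
    std::size_t best = 0;
    for (std::size_t k = 1; k < n; ++k) {
        if (p[k] > p[best]) { best = k; }
    }
    return best;
}

// Host-side readers must branch on the runtime dtype before touching raw bytes.
std::size_t argmaxDispatch(DType dtype, void const* data, std::size_t n) {
    switch (dtype) {
        case DType::kFLOAT:  return argmaxAs<float>(data, n);
        case DType::kDOUBLE: return argmaxAs<double>(data, n);
    }
    return 0;
}

int main() {
    std::vector<float> scores{0.1f, 0.7f, 0.4f};
    std::cout << argmaxDispatch(DType::kFLOAT, scores.data(), scores.size()) << '\n'; // 1
}
```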
@@ -919,23 +919,19 @@ void RuntimeBuffers::prepareEagleBuffers(RequestVector const& contextRequests, R
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
-std::vector<float> RuntimeBuffers::getScoresHost(runtime::TllmRuntime const& runtime)
-{
-    auto const& manager = runtime.getBufferManager();
-    auto const& stream = runtime.getStream();
-    std::vector<float> scoresHost;
-    if (!useAttentionPrior) {
-        TLLM_LOG_WARNING("Getting scores, when attention prior is disabled");
-        return scoresHost;
-    }
-    auto scoresShape = scores->getShape();
-    auto scoresSize = ITensor::volume(scoresShape);
-    if (scoresSize > 0) {
-        scoresHost.resize(scoresSize);
-        manager.copy(*scores, scoresHost.data());
-        stream.synchronize(); // Ensure copy completes
+template <typename T>
+static SizeType32 processScoresWithType(ITensor* scoresHost, SizeType32 prevPriorIdxLen) {
+    auto* scoresHostPtr = bufferCast<T>(*scoresHost);
+    T maxScore = scoresHostPtr[0];
+    SizeType32 maxScoreIdx = 0;
+    // Find the index with maximum score in the current subsection
+    for (SizeType32 k = 1; k < prevPriorIdxLen; ++k) {
+        if (scoresHostPtr[k] > maxScore) {
+            maxScore = scoresHostPtr[k];
+            maxScoreIdx = k;
+        }
     }
-    return scoresHost;
+    return maxScoreIdx;
 }
 
 void RuntimeBuffers::setAttentionPriorIdx(
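
One subtlety of the refactor: the helper returns an index relative to the start of the copied window, not an absolute encoder position, so the caller has to re-base it by `prevPriorIdx` (see the `setAttentionPriorIdx(prevPriorIdx + maxScoreIdx)` change in the next hunk). A self-contained illustration of that rebasing, with plain pointers in place of `ITensor`/`bufferCast`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using SizeType32 = std::int32_t; // matches the alias used in TensorRT-LLM

// Window-relative argmax, mirroring processScoresWithType<T> above.
template <typename T>
SizeType32 windowArgmax(T const* window, SizeType32 len) {
    SizeType32 best = 0;
    for (SizeType32 k = 1; k < len; ++k) {
        if (window[k] > window[best]) { best = k; }
    }
    return best;
}

int main() {
    // Toy scores over one request's encoder sequence.
    std::vector<float> scores{0.0f, 0.1f, 0.2f, 0.9f, 0.3f, 0.1f};
    SizeType32 prevPriorIdx = 2;  // previous attention prior position
    SizeType32 searchLength = 3;  // look at most this many tokens ahead
    SizeType32 len = std::min<SizeType32>(
        searchLength, static_cast<SizeType32>(scores.size()) - prevPriorIdx);

    // The helper's result is relative to the window start...
    SizeType32 rel = windowArgmax(scores.data() + prevPriorIdx, len);
    // ...so it must be re-based before being stored on the request.
    std::cout << "new prior idx = " << prevPriorIdx + rel << '\n'; // prints 3
}
```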
@@ -961,47 +957,58 @@ void RuntimeBuffers::setAttentionPriorIdx(
         totalEncoderOutputLen += llmReq->getEncoderOutputLen();
     }
 
-    SizeType32 offset = 0;
+    SizeType32 qOffset = 0;
     // we skip all context requests
     for (auto const& llmReq : contextRequests) {
-        offset += llmReq->getContextChunkSize() * totalEncoderOutputLen;
+        qOffset += llmReq->getContextChunkSize();
         // for context requests we just focus on the beginning of the encoder sequence
         llmReq->setAttentionPriorIdx(0);
     }
 
-    std::vector<float> scoresHost = getScoresHost(runtime);
+    // create a CPU buffer to copy score windows into and search on the host
+    SizeType32 searchLength = 10;
+    auto const& manager = runtime.getBufferManager();
+    auto const& stream = runtime.getStream();
+    auto scoresHost = manager.cpu(ITensor::makeShape({searchLength}), scores->getDataType());
 
     // for generation requests, there is no context,
     // but we need to find the correct section in (b * encoder_output_len)
     for (SizeType32 i = 0; i < (SizeType32)genRequests.size(); ++i) {
         // skip the context
-        offset += totalContextEncoderOutputLen;
+        SizeType32 kvOffset = totalContextEncoderOutputLen;
         for (SizeType32 j = 0; j < (SizeType32)genRequests.size(); ++j) {
             auto const& llmReq = genRequests[j];
             SizeType32 encoderOutputLen = llmReq->getEncoderOutputLen();
             if (i == j) {
                 // find attention prior idx in range [prev_prior_idx; prev_prior_idx + 10]
                 SizeType32 prevPriorIdx = llmReq->getAttentionPriorIdx();
-                // ignore last 3 tokens, move strictly forward, look up to 10 tokens forward
-                SizeType32 prevPriorIdxEnd = std::min(prevPriorIdx + 10, encoderOutputLen - 3);
-
-                // find maximum score and it's index in current subsection of scores buffer
-                SizeType32 maxScoreIdx = prevPriorIdx;
-                SizeType32 maxScore = scoresHost[offset + prevPriorIdx];
-
-                // Find the index with maximum score in the current subsection
-                for (SizeType32 k = prevPriorIdx + 1; k < prevPriorIdxEnd; ++k) {
-                    if (scoresHost[offset + k] > maxScore) {
-                        maxScore = scoresHost[offset + k];
-                        maxScoreIdx = k;
-                    }
+                // move strictly forward, look up to searchLength tokens ahead
+                SizeType32 prevPriorIdxEnd = std::min(prevPriorIdx + searchLength, encoderOutputLen);
+                SizeType32 prevPriorIdxLen = prevPriorIdxEnd - prevPriorIdx;
+
+                // slice the relevant section of scores
+                auto scoresSlice = ITensor::slice(scores, {qOffset, kvOffset + prevPriorIdx}, prevPriorIdxLen);
+                // copy the score window to a host buffer of the same dtype
+                scoresHost->reshape(ITensor::makeShape({prevPriorIdxLen}));
+                manager.copy(*scoresSlice, *scoresHost);
+                stream.synchronize();
+
+                // find the index of the maximum score in the window
+                SizeType32 maxScoreIdx = 0;
+                if (scores->getDataType() == nvinfer1::DataType::kFLOAT) {
+                    maxScoreIdx = processScoresWithType<float>(scoresHost.get(), prevPriorIdxLen);
+                } else if (scores->getDataType() == nvinfer1::DataType::kHALF) {
+                    maxScoreIdx = processScoresWithType<half>(scoresHost.get(), prevPriorIdxLen);
+                } else {
+                    TLLM_LOG_WARNING("Unsupported scores data type");
                 }
 
                 // Set the attention prior index to the position with maximum score
-                llmReq->setAttentionPriorIdx(maxScoreIdx);
+                llmReq->setAttentionPriorIdx(prevPriorIdx + maxScoreIdx);
             }
-            offset += encoderOutputLen;
+            kvOffset += encoderOutputLen;
         }
+        qOffset += 1;
     }
 }
 
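For orientation, the new indexing treats `scores` as a two-dimensional [query_tokens, total_kv_len] view: `qOffset` advances along the query rows (context chunks first, then one row per generation request), while `kvOffset` advances along the concatenated encoder outputs (context requests' segments first). The toy, self-contained program below recomputes each generation request's slice origin under those layout assumptions, which are inferred from the diff rather than stated in it:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using SizeType32 = std::int32_t;

int main() {
    // Toy batch (all lengths made up for illustration).
    std::vector<SizeType32> ctxEncoderLens{7, 5}; // context requests' encoder outputs
    std::vector<SizeType32> ctxChunkSizes{7, 5};  // query rows consumed by context chunks
    std::vector<SizeType32> genEncoderLens{6, 4}; // generation requests' encoder outputs

    // Skip the context requests' kv segments, as totalContextEncoderOutputLen does.
    SizeType32 ctxKvTotal = 0;
    for (auto len : ctxEncoderLens) { ctxKvTotal += len; }

    // Skip the context requests' query rows, as the first loop over qOffset does.
    SizeType32 qOffset = 0;
    for (auto chunk : ctxChunkSizes) { qOffset += chunk; }

    for (std::size_t i = 0; i < genEncoderLens.size(); ++i) {
        SizeType32 kvOffset = ctxKvTotal;
        for (std::size_t j = 0; j < i; ++j) { kvOffset += genEncoderLens[j]; }
        std::cout << "gen request " << i << ": query row " << qOffset
                  << ", kv segment starts at column " << kvOffset << '\n';
        qOffset += 1; // one query token per generation request per step
    }
}
```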