Commit 1c7f927

Optimize attention prior computation: reuse qkv and cross_kv, don't copy scores to CPU

Signed-off-by: Viacheslav Klimkov <[email protected]>

1 parent 01a07f3 commit 1c7f927

File tree

5 files changed: +62 -71 lines changed


cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h

Lines changed: 0 additions & 1 deletion

@@ -287,7 +287,6 @@ class RuntimeBuffers
         DecoderBuffers& decoderBuffers, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig);
 
-    std::vector<float> getScoresHost(runtime::TllmRuntime const& runtime);
     void setAttentionPriorIdx(RequestVector const& contextRequests, RequestVector const& genRequests,
         runtime::TllmRuntime const& runtime);

cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp

Lines changed: 42 additions & 35 deletions

@@ -261,7 +261,7 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
 
     inputsIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
     if (useAttentionPrior) {
-        scores = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
+        scores = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
     }
     if (worldConfig.isPipelineParallel())
     {
@@ -919,23 +919,19 @@ void RuntimeBuffers::prepareEagleBuffers(RequestVector const& contextRequests, R
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
-std::vector<float> RuntimeBuffers::getScoresHost(runtime::TllmRuntime const& runtime)
-{
-    auto const& manager = runtime.getBufferManager();
-    auto const& stream = runtime.getStream();
-    std::vector<float> scoresHost;
-    if (!useAttentionPrior) {
-        TLLM_LOG_WARNING("Getting scores, when attention prior is disabled");
-        return scoresHost;
-    }
-    auto scoresShape = scores->getShape();
-    auto scoresSize = ITensor::volume(scoresShape);
-    if (scoresSize > 0) {
-        scoresHost.resize(scoresSize);
-        manager.copy(*scores, scoresHost.data());
-        stream.synchronize(); // Ensure copy completes
+template<typename T>
+static SizeType32 processScoresWithType(ITensor* scoresHost, SizeType32 prevPriorIdxLen) {
+    auto* scoresHostPtr = bufferCast<T>(*scoresHost);
+    T maxScore = scoresHostPtr[0];
+    SizeType32 maxScoreIdx = 0;
+    // Find the index with maximum score in the current subsection
+    for (SizeType32 k = 1; k < prevPriorIdxLen; ++k) {
+        if (scoresHostPtr[k] > maxScore) {
+            maxScore = scoresHostPtr[k];
+            maxScoreIdx = k;
+        }
     }
-    return scoresHost;
+    return maxScoreIdx;
 }
 
 void RuntimeBuffers::setAttentionPriorIdx(
@@ -961,47 +957,58 @@ void RuntimeBuffers::setAttentionPriorIdx(
         totalEncoderOutputLen += llmReq->getEncoderOutputLen();
     }
 
-    SizeType32 offset = 0;
+    SizeType32 qOffset = 0;
     // we skip all context requests
     for (auto const& llmReq : contextRequests) {
-        offset += llmReq->getContextChunkSize() * totalEncoderOutputLen;
+        qOffset += llmReq->getContextChunkSize();
         // for context we just focusing at the beginning of the encoder sequence
         llmReq->setAttentionPriorIdx(0);
     }
 
-    std::vector<float> scoresHost = getScoresHost(runtime);
+    // create a cpu buffer for scores to find max score in
+    SizeType32 searchLength = 10;
+    auto const& manager = runtime.getBufferManager();
+    auto const& stream = runtime.getStream();
+    auto scoresHost = manager.cpu(ITensor::makeShape({searchLength}), scores->getDataType());
 
     // for generation requests, there is no context,
     // but we need to find correct section in (b * encoder_output_len)
     for (SizeType32 i = 0; i < (SizeType32)genRequests.size(); ++i) {
         // skip the context
-        offset += totalContextEncoderOutputLen;
+        SizeType32 kvOffset = totalContextEncoderOutputLen;
         for (SizeType32 j = 0; j < (SizeType32)genRequests.size(); ++j) {
             auto const& llmReq = genRequests[j];
             SizeType32 encoderOutputLen = llmReq->getEncoderOutputLen();
             if (i == j) {
                 // find attnetion prior idx in range [prev_prior_idx; prev_prior_idx + 10]
                 SizeType32 prevPriorIdx = llmReq->getAttentionPriorIdx();
                 // ignore last 3 tokens, move strictly forward, look up to 10 tokens forward
-                SizeType32 prevPriorIdxEnd = std::min(prevPriorIdx + 10, encoderOutputLen - 3);
-
-                // find maximum score and it's index in current subsection of scores buffer
-                SizeType32 maxScoreIdx = prevPriorIdx;
-                SizeType32 maxScore = scoresHost[offset + prevPriorIdx];
-
-                // Find the index with maximum score in the current subsection
-                for (SizeType32 k = prevPriorIdx + 1; k < prevPriorIdxEnd; ++k) {
-                    if (scoresHost[offset + k] > maxScore) {
-                        maxScore = scoresHost[offset + k];
-                        maxScoreIdx = k;
-                    }
+                SizeType32 prevPriorIdxEnd = std::min(prevPriorIdx + searchLength, encoderOutputLen);
+                SizeType32 prevPriorIdxLen = prevPriorIdxEnd - prevPriorIdx;
+
+                // slice relevant section of scores
+                auto scoresSlice = ITensor::slice(scores, {qOffset, kvOffset + prevPriorIdx}, prevPriorIdxLen);
+                // copies and converts to float
+                scoresHost->reshape(ITensor::makeShape({prevPriorIdxLen}));
+                manager.copy(*scoresSlice, *scoresHost);
+                stream.synchronize();
+
+                // find index of maximum score in the window
+                SizeType32 maxScoreIdx = 0;
+                if (scores->getDataType() == nvinfer1::DataType::kFLOAT) {
+                    maxScoreIdx = processScoresWithType<float>(scoresHost.get(), prevPriorIdxLen);
+                } else if (scores->getDataType() == nvinfer1::DataType::kHALF) {
+                    maxScoreIdx = processScoresWithType<half>(scoresHost.get(), prevPriorIdxLen);
+                } else {
+                    TLLM_LOG_WARNING("Unsupported scores data type");
                 }
 
                 // Set the attention prior index to the position with maximum score
-                llmReq->setAttentionPriorIdx(maxScoreIdx);
+                llmReq->setAttentionPriorIdx(prevPriorIdx + maxScoreIdx);
             }
-            offset += encoderOutputLen;
+            kvOffset += encoderOutputLen;
         }
+        qOffset += 1;
     }
 }
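For reference, a minimal standalone NumPy sketch of the window search that the rewritten setAttentionPriorIdx performs per generation request: only a slice of up to searchLength (10) scores starting at the previous prior index is brought to the host, the argmax is taken inside that window, and the result is shifted back to an absolute encoder position. The function name, the score-row layout, and the toy data below are illustrative assumptions, not the TensorRT-LLM API.

import numpy as np

def next_attention_prior_idx(scores_row, prev_prior_idx, encoder_output_len, search_length=10):
    """Pick the encoder position with the highest cross-attention score
    inside a small window starting at the previous prior index."""
    window_end = min(prev_prior_idx + search_length, encoder_output_len)
    window = scores_row[prev_prior_idx:window_end]  # only this slice would be copied to the host
    max_score_idx = int(np.argmax(window))          # index relative to the window start
    return prev_prior_idx + max_score_idx           # convert back to an absolute position

# toy usage: one decoder step attending over a 20-frame encoder output
scores_row = np.random.rand(20).astype(np.float32)
print(next_attention_prior_idx(scores_row, prev_prior_idx=5, encoder_output_len=20))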

examples/models/contrib/t5tts/convert_checkpoint.py

Lines changed: 0 additions & 6 deletions

@@ -305,12 +305,6 @@ def convert_t5tts_decoder(
             model_dict[f't5_decoder.layers.{i}.cross_attention.q_net.weight'],
             model_dict[f't5_decoder.layers.{i}.cross_attention.kv_net.weight']
         ], dim=0).contiguous()
-        # projections to compute attention scores
-        weights[f'decoder_layers.{i}.q_proj.weight'] = model_dict[
-            f't5_decoder.layers.{i}.cross_attention.q_net.weight'].contiguous()
-        kv_weight = model_dict[f't5_decoder.layers.{i}.cross_attention.kv_net.weight']
-        dim = kv_weight.shape[0] // 2
-        weights[f'decoder_layers.{i}.k_proj.weight'] = kv_weight[:dim, :]
 
         weights[f'decoder_layers.{i}.cross_attention.qkv.weight'] = qkv_weight
         weights[f'decoder_layers.{i}.cross_attention.dense.weight'] = model_dict[
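The q_proj / k_proj weight exports removed above are redundant: the fused cross_attention.qkv.weight built a few lines earlier already stacks the q_net rows on top of the kv_net rows, and the k rows are the first half of kv_net. A small PyTorch sketch of that layout, with hidden_size and the random tensors standing in for the real checkpoint weights:

import torch

hidden_size = 8
q_net = torch.randn(hidden_size, hidden_size)        # stand-in for ...cross_attention.q_net.weight
kv_net = torch.randn(2 * hidden_size, hidden_size)   # stand-in for ...cross_attention.kv_net.weight

# what the converter still builds for decoder_layers.{i}.cross_attention.qkv.weight
qkv_weight = torch.cat([q_net, kv_net], dim=0).contiguous()

# the removed q_proj / k_proj weights are just slices of the fused tensor
assert torch.equal(qkv_weight[:hidden_size], q_net)
assert torch.equal(qkv_weight[hidden_size:2 * hidden_size], kv_net[:hidden_size])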

tensorrt_llm/layers/attention.py

Lines changed: 2 additions & 2 deletions

@@ -1555,9 +1555,9 @@ def transpose_for_scores(x,
             context = dense_conditional.add_output(skip_case, context)
 
         if use_cache:
-            return (context, past_key_value)
+            return (context, qkv, cross_kv, past_key_value)
         else:
-            return context
+            return (context, qkv, cross_kv)
 
     def set_rel_attn_table(self, max_seq_len, precomputed_relative_attention):
         self.rel_attn_table = Parameter(shape=(self.num_attention_heads,
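Since Attention.forward now also returns the fused qkv and cross_kv tensors, every caller has to unpack a wider tuple: four values with use_cache, three without. A toy sketch of the new calling convention, with plain strings standing in for the TensorRT-LLM tensors:

def attention_forward(use_cache):
    # stand-in for the real layer; mirrors the return tuples in the diff above
    context, qkv, cross_kv, past_key_value = "ctx", "qkv", "cross_kv", "cache"
    if use_cache:
        return (context, qkv, cross_kv, past_key_value)
    return (context, qkv, cross_kv)

# callers now unpack the extra projections (and discard them if unused)
context, qkv, cross_kv, presents = attention_forward(use_cache=True)
context, qkv, cross_kv = attention_forward(use_cache=False)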

tensorrt_llm/models/t5tts/model.py

Lines changed: 18 additions & 27 deletions

@@ -25,7 +25,7 @@
                                      LayerNormType, MLPType,
                                      PositionEmbeddingType, Tensor, assertion,
                                      concat, gather_last_token_logits, maximum,
-                                     minimum, recv, select, send, shape, view, mean, add,
+                                     minimum, recv, select, send, shape, view, mean, add, slice,
                                      squeeze, unsqueeze, transpose, matmul, stack, cast)
 from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  AttentionMaskType, AttentionParams,
@@ -428,6 +428,7 @@ def __init__(self,
 
         # e.g. BART post, T5 pre
         self.layernorm_position = layernorm_position
+        self.hidden_size = hidden_size
 
         # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size)
         self.self_attention = Attention(
@@ -455,26 +456,6 @@ def __init__(self,
             eps=layernorm_eps,
             dtype=dtype, bias=False)
 
-        # to compute cross attention scores
-        self.q_proj = ColumnLinear(
-            hidden_size,
-            hidden_size,
-            bias=False,
-            dtype=dtype,
-            tp_group=mapping.tp_group,
-            tp_size=mapping.tp_size,
-            gather_output=True,
-        )
-        self.k_proj = ColumnLinear(
-            hidden_size,
-            hidden_size,
-            bias=False,
-            dtype=dtype,
-            tp_group=mapping.tp_group,
-            tp_size=mapping.tp_size,
-            gather_output=True,
-        )
-
         # Note: self attn uses MMHA, mask is always causal triangular
         # cross attn has two scenarios:
         # - in context phase, all ones mask, same as padding type
@@ -558,18 +539,16 @@ def forward(self,
             kv_cache_params=kv_cache_params,
             attention_params=attention_params)
         if use_cache:
-            attention_output, presents_self = attention_output
+            attention_output, _, _, presents_self = attention_output
+        else:
+            attention_output, _, _ = attention_output
         hidden_states = residual + attention_output
 
         # cross attention
         residual = hidden_states
 
         hidden_states = self.cross_attention_layernorm(hidden_states)
         encoder_output = self.cross_attention_memory_layernorm(encoder_output)
-        # compute attention scores
-        q = cast(self.q_proj(hidden_states), "float32")  # b * context x hidden
-        k = cast(self.k_proj(encoder_output), "float32")  # b * enc x hidden
-        scores = matmul(q, k, transb=True)  # b * context x b * enc
         attention_output = self.cross_attention(
             hidden_states=hidden_states,
             attention_mask=attention_mask_params.cross_attention_mask,
@@ -582,9 +561,21 @@ def forward(self,
             cross_kv_cache_gen=cross_kv_cache_gen,
             cross_kv_reuse=cross_kv_reuse)
         if use_cache:
-            attention_output, presents_cross = attention_output
+            attention_output, qkv, cross_kv, presents_cross = attention_output
+        else:
+            attention_output, qkv, cross_kv = attention_output
         hidden_states = residual + attention_output
 
+        # compute attention scores
+        # TODO: assumes padding disabled
+        q = slice(qkv, concat([0, 0]), concat([shape(qkv, 0), self.hidden_size]))
+        k = slice(cross_kv, concat([0, 0]), concat([shape(cross_kv, 0), self.hidden_size]))
+        scores = matmul(
+            q,
+            k,
+            transb=True
+        )
+
         # conv ff (norm -> conv -> residual)
         residual = hidden_states
         hidden_states = self.pos_ff_layernorm(hidden_states)
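A standalone NumPy sketch of the score computation the decoder layer now performs by reusing the fused projections: q is the first hidden_size columns of qkv, k is the first hidden_size columns of cross_kv, and the scores are q @ k^T. The shapes below and the no-padding assumption mirror the TODO in the diff and are purely illustrative.

import numpy as np

hidden_size = 8
num_decoder_tokens = 3     # b * decoder tokens, padding assumed disabled
num_encoder_tokens = 20    # b * encoder_output_len, padding assumed disabled

# stand-ins for the fused projection outputs returned by cross attention
qkv = np.random.rand(num_decoder_tokens, 3 * hidden_size).astype(np.float32)
cross_kv = np.random.rand(num_encoder_tokens, 2 * hidden_size).astype(np.float32)

# equivalent of the slice(...) calls in the diff
q = qkv[:, :hidden_size]
k = cross_kv[:, :hidden_size]

# equivalent of matmul(q, k, transb=True): one score per (decoder token, encoder frame)
scores = q @ k.T
assert scores.shape == (num_decoder_tokens, num_encoder_tokens)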
