Skip to content

Commit 941a54c

Browse files
authored
[None][feat] Update the indexer topK (#9255)
Signed-off-by: Christina Zhang <[email protected]>
1 parent 286ace2 commit 941a54c

File tree

7 files changed

+145
-167
lines changed

7 files changed

+145
-167
lines changed

cpp/tensorrt_llm/kernels/IndexerTopK.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424

2525
namespace tensorrt_llm::kernels
2626
{
27-
void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* outIndices, float* auxLogits,
28-
int* auxIndices, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
29-
int const stride1, int const next_n, int const index_topk = 2048, cudaStream_t const stream = 0);
27+
void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux,
28+
int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
29+
int const stride1, int const next_n, int const topK = 2048, cudaStream_t const stream = 0);
3030

31-
void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* outIndices,
32-
int const numRows, int const numColumns, int const stride0, int const stride1, int const index_topk = 2048,
31+
void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices,
32+
int const numRows, int const numColumns, int const stride0, int const stride1, int const topK = 2048,
3333
cudaStream_t const stream = 0);
3434

3535
} // namespace tensorrt_llm::kernels

0 commit comments

Comments
 (0)