|
24 | 24 |
|
25 | 25 | namespace tensorrt_llm::kernels |
26 | 26 | { |
27 | | -void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* outIndices, float* auxLogits, |
28 | | - int* auxIndices, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, |
29 | | - int const stride1, int const next_n, int const index_topk = 2048, cudaStream_t const stream = 0); |
| 27 | +void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux, |
| 28 | + int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, |
| 29 | + int const stride1, int const next_n, int const topK = 2048, cudaStream_t const stream = 0); |
30 | 30 |
|
31 | | -void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* outIndices, |
32 | | - int const numRows, int const numColumns, int const stride0, int const stride1, int const index_topk = 2048, |
| 31 | +void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices, |
| 32 | + int const numRows, int const numColumns, int const stride0, int const stride1, int const topK = 2048, |
33 | 33 | cudaStream_t const stream = 0); |
34 | 34 |
|
35 | 35 | } // namespace tensorrt_llm::kernels |
0 commit comments