Skip to content

Commit 7654d08

Browse files
committed
fix: nanobind bindings
Signed-off-by: Robin Kobus <[email protected]>
1 parent 54a3bff commit 7654d08

File tree

6 files changed

+118
-56
lines changed

6 files changed

+118
-56
lines changed

cpp/tensorrt_llm/nanobind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(TRTLLM_NB_MODULE
66
set(SRCS
77
batch_manager/algorithms.cpp
88
batch_manager/bindings.cpp
9+
batch_manager/buffers.cpp
910
batch_manager/cacheTransceiver.cpp
1011
batch_manager/kvCacheManager.cpp
1112
batch_manager/llmRequest.cpp

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 8 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -384,39 +384,6 @@ void initBindings(nb::module_& m)
384384
.def(nb::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
385385
nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"));
386386

387-
nb::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
388-
.def(nb::init<runtime::SizeType32, runtime::SizeType32, tr::BufferManager>(), nb::arg("max_batch_size"),
389-
nb::arg("max_tokens_per_engine_step"), nb::arg("manager"))
390-
.def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
391-
.def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
392-
.def_rw("fill_values", &tb::DecoderInputBuffers::fillValues)
393-
.def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
394-
.def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
395-
.def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
396-
.def_rw("logits", &tb::DecoderInputBuffers::logits)
397-
.def_rw("decoder_requests", &tb::DecoderInputBuffers::decoderRequests);
398-
399-
nb::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
400-
.def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)
401-
.def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost)
402-
.def_prop_ro("new_output_tokens_host",
403-
[](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); })
404-
.def_rw("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost)
405-
.def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost)
406-
.def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost);
407-
408-
nb::class_<tb::SlotDecoderBuffers>(m, "SlotDecoderBuffers")
409-
.def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&>(),
410-
nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"))
411-
.def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds)
412-
.def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost)
413-
.def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost)
414-
.def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs)
415-
.def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost)
416-
.def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs)
417-
.def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
418-
.def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
419-
420387
m.def(
421388
"add_new_tokens_to_requests",
422389
[](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
@@ -435,10 +402,10 @@ void initBindings(nb::module_& m)
435402

436403
m.def(
437404
"make_decoding_batch_input",
438-
[](std::vector<std::shared_ptr<tb::LlmRequest>>& contextRequests,
439-
std::vector<std::shared_ptr<tb::LlmRequest>>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth,
440-
std::vector<int> const& numContextLogitsPrefixSum, tb::DecoderInputBuffers const& decoderInputBuffers,
441-
runtime::decoder::DecoderState& decoderState, tr::BufferManager const& manager)
405+
[](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState,
406+
std::vector<std::shared_ptr<tb::LlmRequest>> const& contextRequests,
407+
std::vector<std::shared_ptr<tb::LlmRequest>> const& genRequests, tr::ITensor::SharedPtr const& logits,
408+
int beamWidth, std::vector<int> const& numContextLogitsPrefixSum, tr::BufferManager const& manager)
442409
{
443410
std::vector<int> activeSlots;
444411
std::vector<int> generationSteps;
@@ -496,21 +463,18 @@ void initBindings(nb::module_& m)
496463
batchSlotsRange[i] = activeSlots[i];
497464
}
498465

499-
auto decodingInput = std::make_unique<tr::decoder_batch::Input>(logitsVec, 1);
500-
decodingInput->batchSlots = batchSlots;
466+
decoderInputBuffers.batchLogits = logitsVec;
501467

502468
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
503469
if (maxBeamWidth > 1)
504470
{
505471
// For Variable-Beam-Width-Search
506472
decoderState.getJointDecodingInput().generationSteps = generationSteps;
507473
}
508-
509-
return decodingInput;
510474
},
511-
nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"),
512-
nb::arg("num_context_logits_prefix_sum"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"),
513-
nb::arg("buffer_manager"), "Make decoding batch input.");
475+
nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), nb::arg("context_requests"),
476+
nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"),
477+
nb::arg("num_context_logits_prefix_sum"), nb::arg("buffer_manager"), "Make decoding batch input.");
514478
}
515479

516480
} // namespace tensorrt_llm::nanobind::batch_manager

cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "buffers.h"
19+
#include "tensorrt_llm/nanobind/common/customCasters.h"
20+
21+
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
22+
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
23+
24+
#include <ATen/ATen.h>
25+
#include <nanobind/nanobind.h>
26+
#include <nanobind/operators.h>
27+
#include <torch/extension.h>
28+
29+
namespace nb = nanobind;
30+
namespace tb = tensorrt_llm::batch_manager;
31+
namespace tr = tensorrt_llm::runtime;
32+
33+
using tr::SizeType32;
34+
35+
namespace tensorrt_llm::nanobind::batch_manager
36+
{
37+
38+
// Registers the nanobind classes for the decoder buffer types on module `m`:
// tb::DecoderInputBuffers, tb::DecoderOutputBuffers and tb::SlotDecoderBuffers.
// Each `def_rw` exposes the named C++ data member as a read-write Python attribute.
//
// NOTE(review): the `nb::init<...>` signatures below spell the namespace as `runtime::`
// while the rest of this file uses the `tr::` alias; presumably both name
// tensorrt_llm::runtime via lookup from the enclosing namespace -- confirm.
void Buffers::initBindings(nb::module_& m)
39+
{
40+
// DecoderInputBuffers: constructible from Python with
// (max_batch_size, max_tokens_per_engine_step, manager).
nb::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
41+
.def(nb::init<runtime::SizeType32, runtime::SizeType32, tr::BufferManager>(), nb::arg("max_batch_size"),
42+
nb::arg("max_tokens_per_engine_step"), nb::arg("manager"))
43+
.def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
44+
.def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
45+
.def_rw("fill_values", &tb::DecoderInputBuffers::fillValues)
46+
.def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
47+
.def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
48+
.def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
49+
.def_rw("decoder_logits", &tb::DecoderInputBuffers::decoderLogits)
50+
.def_rw("decoder_requests", &tb::DecoderInputBuffers::decoderRequests);
51+
52+
// DecoderOutputBuffers: no nb::init is registered for this class in this function;
// only its members are exposed.
nb::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
53+
.def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)
54+
.def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost)
55+
// Read-only property: the getter passes newOutputTokensHost through tr::Torch::tensor
// instead of exposing the member directly like the neighbouring def_rw entries.
.def_prop_ro("new_output_tokens_host",
56+
[](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); })
57+
.def_rw("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost)
58+
.def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost)
59+
.def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost);
60+
61+
// SlotDecoderBuffers: constructible from Python with
// (max_beam_width, max_seq_len, buffer_manager).
nb::class_<tb::SlotDecoderBuffers>(m, "SlotDecoderBuffers")
62+
.def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&>(),
63+
nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"))
64+
.def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds)
65+
.def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost)
66+
.def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost)
67+
.def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs)
68+
.def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost)
69+
.def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs)
70+
.def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
71+
.def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
72+
}
73+
} // namespace tensorrt_llm::nanobind::batch_manager

cpp/tensorrt_llm/nanobind/batch_manager/buffers.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#pragma once
19+
20+
#include <nanobind/nanobind.h>
21+
namespace nb = nanobind;
22+
23+
namespace tensorrt_llm::nanobind::batch_manager
24+
{
25+
// Holder for the nanobind registration entry point of the batch-manager buffer types.
// The class declares a single static member function and no data members.
class Buffers
26+
{
27+
public:
28+
// Registers the buffer class bindings on module `m` (declared here; the definition
// lives in the accompanying buffers.cpp).
static void initBindings(nb::module_& m);
29+
};
30+
} // namespace tensorrt_llm::nanobind::batch_manager

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "tensorrt_llm/common/quantization.h"
3434
#include "tensorrt_llm/nanobind/batch_manager/algorithms.h"
3535
#include "tensorrt_llm/nanobind/batch_manager/bindings.h"
36+
#include "tensorrt_llm/nanobind/batch_manager/buffers.h"
3637
#include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h"
3738
#include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h"
3839
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
@@ -469,6 +470,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
469470
.def_prop_ro("pinned", &tr::MemoryCounters::getPinned)
470471
.def_prop_ro("uvm", &tr::MemoryCounters::getUVM);
471472

473+
tpb::Buffers::initBindings(mInternalBatchManager);
472474
tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
473475
tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
474476
tpb::initBindings(mInternalBatchManager);

cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,6 @@ void initBindings(nb::module_& m)
116116
.def_rw("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer)
117117
.def(nb::self == nb::self);
118118

119-
nb::class_<tr::BufferManager>(m, "BufferManager")
120-
.def(nb::init<tr::BufferManager::CudaStreamPtr, bool>(), nb::arg("stream"), nb::arg("trim_pool") = false)
121-
.def_prop_ro("stream", &tr::BufferManager::getStream);
122-
123119
nb::class_<tr::TllmRuntime>(m, "TllmRuntime")
124120
.def(
125121
"__init__",
@@ -157,14 +153,6 @@ void initBindings(nb::module_& m)
157153
.def_prop_ro("logits_dtype_from_engine",
158154
[](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); });
159155

160-
nb::class_<tr::decoder_batch::Input>(m, "DecoderBatchInput")
161-
.def(nb::init<std::vector<std::vector<tr::ITensor::SharedConstPtr>>, tr::SizeType32>(), nb::arg("logits"),
162-
nb::arg("max_decoding_engine_tokens"))
163-
.def(nb::init<std::vector<tr::ITensor::SharedConstPtr>>(), nb::arg("logits"))
164-
.def_rw("logits", &tr::decoder_batch::Input::logits)
165-
.def_rw("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps)
166-
.def_rw("batch_slots", &tr::decoder_batch::Input::batchSlots);
167-
168156
nb::class_<tr::LookaheadDecodingBuffers>(m, "LookaheadDecodingBuffers")
169157
.def(nb::init<tr::SizeType32, tr::SizeType32, tr::BufferManager const&>(), nb::arg("max_num_sequences"),
170158
nb::arg("max_tokens_per_step"), nb::arg("buffer_manager"))
@@ -343,6 +331,10 @@ void initBindings(nb::module_& m)
343331

344332
void initBindingsEarly(nb::module_& m)
345333
{
334+
nb::class_<tr::BufferManager>(m, "BufferManager")
335+
.def(nb::init<tr::BufferManager::CudaStreamPtr, bool>(), nb::arg("stream"), nb::arg("trim_pool") = false)
336+
.def_prop_ro("stream", &tr::BufferManager::getStream);
337+
346338
nb::class_<tr::SpeculativeDecodingMode>(m, "SpeculativeDecodingMode")
347339
.def(nb::init<tr::SpeculativeDecodingMode::UnderlyingType>(), nb::arg("state"))
348340
.def_static("NoneType", &tr::SpeculativeDecodingMode::None)

0 commit comments

Comments
 (0)