@@ -384,39 +384,6 @@ void initBindings(nb::module_& m)
384384 .def (nb::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
385385 nb::arg (" max_num_sequences" ), nb::arg (" model_config" ), nb::arg (" world_config" ), nb::arg (" buffer_manager" ));
386386
387- nb::class_<tb::DecoderInputBuffers>(m, " DecoderInputBuffers" )
388- .def (nb::init<runtime::SizeType32, runtime::SizeType32, tr::BufferManager>(), nb::arg (" max_batch_size" ),
389- nb::arg (" max_tokens_per_engine_step" ), nb::arg (" manager" ))
390- .def_rw (" setup_batch_slots" , &tb::DecoderInputBuffers::setupBatchSlots)
391- .def_rw (" setup_batch_slots_device" , &tb::DecoderInputBuffers::setupBatchSlotsDevice)
392- .def_rw (" fill_values" , &tb::DecoderInputBuffers::fillValues)
393- .def_rw (" fill_values_device" , &tb::DecoderInputBuffers::fillValuesDevice)
394- .def_rw (" inputs_ids" , &tb::DecoderInputBuffers::inputsIds)
395- .def_rw (" forward_batch_slots" , &tb::DecoderInputBuffers::forwardBatchSlots)
396- .def_rw (" logits" , &tb::DecoderInputBuffers::logits)
397- .def_rw (" decoder_requests" , &tb::DecoderInputBuffers::decoderRequests);
398-
399- nb::class_<tb::DecoderOutputBuffers>(m, " DecoderOutputBuffers" )
400- .def_rw (" sequence_lengths_host" , &tb::DecoderOutputBuffers::sequenceLengthsHost)
401- .def_rw (" finished_sum_host" , &tb::DecoderOutputBuffers::finishedSumHost)
402- .def_prop_ro (" new_output_tokens_host" ,
403- [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor (self.newOutputTokensHost ); })
404- .def_rw (" cum_log_probs_host" , &tb::DecoderOutputBuffers::cumLogProbsHost)
405- .def_rw (" log_probs_host" , &tb::DecoderOutputBuffers::logProbsHost)
406- .def_rw (" finish_reasons_host" , &tb::DecoderOutputBuffers::finishReasonsHost);
407-
408- nb::class_<tb::SlotDecoderBuffers>(m, " SlotDecoderBuffers" )
409- .def (nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const &>(),
410- nb::arg (" max_beam_width" ), nb::arg (" max_seq_len" ), nb::arg (" buffer_manager" ))
411- .def_rw (" output_ids" , &tb::SlotDecoderBuffers::outputIds)
412- .def_rw (" output_ids_host" , &tb::SlotDecoderBuffers::outputIdsHost)
413- .def_rw (" sequence_lengths_host" , &tb::SlotDecoderBuffers::sequenceLengthsHost)
414- .def_rw (" cum_log_probs" , &tb::SlotDecoderBuffers::cumLogProbs)
415- .def_rw (" cum_log_probs_host" , &tb::SlotDecoderBuffers::cumLogProbsHost)
416- .def_rw (" log_probs" , &tb::SlotDecoderBuffers::logProbs)
417- .def_rw (" log_probs_host" , &tb::SlotDecoderBuffers::logProbsHost)
418- .def_rw (" finish_reasons_host" , &tb::SlotDecoderBuffers::finishReasonsHost);
419-
420387 m.def (
421388 " add_new_tokens_to_requests" ,
422389 [](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
@@ -435,10 +402,10 @@ void initBindings(nb::module_& m)
435402
436403 m.def (
437404 " make_decoding_batch_input" ,
438- [](std::vector<std::shared_ptr<tb::LlmRequest>>& contextRequests ,
439- std::vector<std::shared_ptr<tb::LlmRequest>>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth ,
440- std::vector<int > const & numContextLogitsPrefixSum, tb::DecoderInputBuffers const & decoderInputBuffers ,
441- runtime::decoder::DecoderState& decoderState , tr::BufferManager const & manager)
405+ [](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState ,
406+ std::vector<std::shared_ptr<tb::LlmRequest>> const & contextRequests ,
407+ std::vector<std::shared_ptr<tb::LlmRequest>> const & genRequests, tr::ITensor::SharedPtr const & logits ,
408+ int beamWidth, std::vector< int > const & numContextLogitsPrefixSum , tr::BufferManager const & manager)
442409 {
443410 std::vector<int > activeSlots;
444411 std::vector<int > generationSteps;
@@ -496,21 +463,18 @@ void initBindings(nb::module_& m)
496463 batchSlotsRange[i] = activeSlots[i];
497464 }
498465
499- auto decodingInput = std::make_unique<tr::decoder_batch::Input>(logitsVec, 1 );
500- decodingInput->batchSlots = batchSlots;
466+ decoderInputBuffers.batchLogits = logitsVec;
501467
502468 auto const maxBeamWidth = decoderState.getMaxBeamWidth ();
503469 if (maxBeamWidth > 1 )
504470 {
505471 // For Variable-Beam-Width-Search
506472 decoderState.getJointDecodingInput ().generationSteps = generationSteps;
507473 }
508-
509- return decodingInput;
510474 },
511- nb::arg (" context_requests " ), nb::arg (" generation_requests " ), nb::arg (" logits " ), nb::arg ( " beam_width " ),
512- nb::arg (" num_context_logits_prefix_sum " ), nb::arg (" decoder_input_buffers " ), nb::arg (" decoder_state " ),
513- nb::arg (" buffer_manager" ), " Make decoding batch input." );
475+ nb::arg (" decoder_input_buffers " ), nb::arg (" decoder_state " ), nb::arg (" context_requests " ),
476+ nb::arg (" generation_requests " ), nb::arg (" logits " ), nb::arg (" beam_width " ),
477+ nb::arg (" num_context_logits_prefix_sum " ), nb::arg ( " buffer_manager" ), " Make decoding batch input." );
514478}
515479
516480} // namespace tensorrt_llm::nanobind::batch_manager
0 commit comments