[TRTLLM-6756][feat] Add Beam Search to TorchSampler #8509
Conversation
/bot run --disable-fail-fast

PR_Github #21928 [ run ] triggered by Bot.

PR_Github #21928 [ run ] completed.
Force-pushed from 80f31f7 to 990321c.
/bot run

PR_Github #22046 [ run ] triggered by Bot.

PR_Github #22046 [ run ] completed.
ixlmar left a comment:
Mostly nits; stronger feelings only about where it touches the batched sampling, also in anticipation of future work (currently mainly #8581).
Force-pushed from 1c7b4be to 26fac04.
/bot run

PR_Github #22787 [ run ] triggered by Bot.

PR_Github #22787 [ run ] completed.
Force-pushed from f754360 to 9be3a85.
/bot run --disable-fail-fast

PR_Github #23043 [ run ] triggered by Bot.

PR_Github #23043 [ run ] completed.
Force-pushed from 9be3a85 to eb265ac.
/bot run --disable-fail-fast
Force-pushed from eb265ac to f7c63c3.
/bot run --disable-fail-fast

PR_Github #24210 [ run ] triggered by Bot.
/bot run

PR_Github #25701 [ run ] triggered by Bot.
Force-pushed from 8dfa36d to 6b4f97e.
/bot run --disable-fail-fast

PR_Github #25759 [ run ] triggered by Bot.
PR_Github #25701 [ run ] completed.

PR_Github #25759 [ run ] completed.
Force-pushed from 6b4f97e to 326305d.
…ing.
- Added BeamSearchArgs class and updated methods to handle beam search logic, including cache indirection updates and beam score management.
- Modified create_torch_sampler_args to include a use_overlap_scheduler parameter.
- Updated sampling strategy to accommodate beam search requests.

Signed-off-by: Stefan Niebler <[email protected]>
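The beam-score and cache-indirection bookkeeping described in this commit can be sketched in plain Python. This is an illustrative model of a single beam-search step, not the actual TorchSampler code (which operates on batched torch tensors); the function name and list-based layout are assumptions for the sketch.

```python
def beam_search_step(cum_log_probs, step_log_probs, beam_width):
    """One illustrative beam-search step.

    cum_log_probs:  list[float], cumulative log-prob per live beam.
    step_log_probs: list[list[float]], per-beam log-probs over the vocab.
    Returns (new_cum_log_probs, parent_beams, chosen_tokens); parent_beams
    is the information a cache-indirection update would record.
    """
    # Score every (beam, token) continuation by adding the token's
    # log-prob to the beam's cumulative score.
    candidates = []
    for beam, beam_score in enumerate(cum_log_probs):
        for token, token_lp in enumerate(step_log_probs[beam]):
            candidates.append((beam_score + token_lp, beam, token))
    # Keep the beam_width best continuations across all beams.
    candidates.sort(key=lambda c: c[0], reverse=True)
    top = candidates[:beam_width]
    new_scores = [c[0] for c in top]
    parents = [c[1] for c in top]
    tokens = [c[2] for c in top]
    return new_scores, parents, tokens
```

In a tensorized implementation the same selection is typically a single top-k over the flattened (beam, vocab) score matrix.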
Signed-off-by: Stefan Niebler <[email protected]>
Unsqueeze buffer returned from sampling to always contain the beam_width dimension.

Signed-off-by: Stefan Niebler <[email protected]>
…features
- Add a metadata object to grouped_request to pass additional data that is not part of the SamplingStrategy definition.
- Add several buffers to the TorchSampler Store for beam search features, allocated only when beam search is used.
- Add support for beam search with streaming enabled.
- Beam search no longer requires all beams to finish in the same iteration.
- gather_generation_logits can now be used together with beam search.
- Logprob generation is now possible with beam search enabled; top-k logprobs are not supported.
- Updated test_beam_search.py to also cover TorchSampler.
- General changes for formatting and readability.

Signed-off-by: Stefan Niebler <[email protected]>
- Updated create_torch_sampler_args and related methods to replace use_overlap_scheduler with disable_overlap_scheduler.
- Added the missing disable_overlap_scheduler parameter to TorchSampler.Args in ad_executor.py.
- Introduced a BeamHistory data class to encapsulate beam search history, including tokens and logprobs.
- Refactored methods to create and finalize beam history, improving clarity and functionality.

Signed-off-by: Stefan Niebler <[email protected]>
…onality
- Introduced a new test file `test_beam_search_util.py` containing a dummy model and utility functions for beam search testing.
- Refactored existing tests in `test_beam_search.py` to use the new utility functions, improving test organization and clarity.
- Added comprehensive tests for beam search sampling, including validation of output shapes, cache indirection, and cumulative log probabilities.
- Added new unit tests for beam search sampling, updating of beams, and finalization of requests.

Signed-off-by: Stefan Niebler <[email protected]>
- Updated beam search sampling logic to ensure asynchronous handling of finished beams and cache indirection.

Signed-off-by: Stefan Niebler <[email protected]>
…e_async
- Added a new buffer to handle beam search sampling asynchronously where possible.
- Split beam history creation and beam finalization; beam finalization occurs in update_requests.
- Adjusted write_finish_reason to support beam search.

Signed-off-by: Stefan Niebler <[email protected]>
… the updated beam search in TorchSampler
- Enhanced logprob testing to verify sum(logprobs) == cum_log_probs.
- Added testing for stop tokens.

Signed-off-by: Stefan Niebler <[email protected]>
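The logprob invariant this commit tests can be stated as a one-liner. This is a sketch of the property only; the actual tests operate on tensors returned by the sampler, and the function name here is hypothetical.

```python
import math

def cum_log_prob_matches(per_token_logprobs, cum_log_prob, tol=1e-6):
    # The cumulative log-probability reported for a beam should equal the
    # sum of the per-token log-probabilities emitted along that beam.
    return math.isclose(sum(per_token_logprobs), cum_log_prob, abs_tol=tol)
```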
Signed-off-by: Stefan Niebler <[email protected]>
…h sampling
- Fixed several bugs that caused non-beam-search test cases to fail.
- Adjusted test_torch_sampler.py to conform with changes in the TorchSampler.
- Improved beam search sampling with async torch operations.

Signed-off-by: Stefan Niebler <[email protected]>
Signed-off-by: Stefan Niebler <[email protected]>
- Converted RequestGroupKey from a NamedTuple to a dataclass with frozen and kw_only attributes.
- Added __iter__ and __len__ methods for improved usability.

Signed-off-by: Stefan Niebler <[email protected]>
…Sampler
- Modified the _handle_finish_reasons method to accept an additional parameter, finish_reasons_list, for correct handling of finish reasons in beam search.
- Updated calls to _handle_finish_reasons throughout the TorchSampler class to accommodate the new parameter.

Signed-off-by: Stefan Niebler <[email protected]>
… method and fix bugs
- Introduced a setup_sampler_step method to enable the setup process for disaggregated serving in beam search.
- Updated cache indirection initialization to use torch.zeros to prevent reading invalid values from cache_indirection.
- Updated mtpSampler to correctly call TorchSampler functions.
- Fixed handle_finish_reasons by wrapping finish reasons in the FinishReason class.
- Adjusted the max_lengths_tensor calculation to account for the original prompt length.

Signed-off-by: Stefan Niebler <[email protected]>
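The cache-indirection fix above can be illustrated with a small pure-Python model. This is a sketch under stated assumptions, not the TorchSampler implementation: the real buffers are torch tensors, and the update semantics shown (each position maps to the beam whose KV-cache slot should be read) are the standard beam-search indirection scheme, assumed here for illustration.

```python
def make_cache_indirection(num_beams, max_len):
    # Zero-initialized, so positions that were never written still index
    # a valid beam (mirrors the switch to torch.zeros described above).
    return [[0] * max_len for _ in range(num_beams)]

def update_cache_indirection(cache_indirection, parent_beams, step):
    # After a beam step, each surviving beam inherits its parent's
    # KV-cache path for positions before `step` and records its own
    # slot at `step`, where its new token was written.
    old = [row[:] for row in cache_indirection]
    for beam, parent in enumerate(parent_beams):
        cache_indirection[beam][:step] = old[parent][:step]
        cache_indirection[beam][step] = beam
    return cache_indirection
```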
…and sampling utilities
- Introduced new functions to retrieve beam width parameters for input and output, improving clarity and modularity.
- Updated UtilsSamplingParams to include separate beam width parameters and a flag for beam search usage.
- Refactored beam search sampling logic to accommodate changes in beam width handling, ensuring compatibility with the new parameters.
- Unified beam search sampling for context and generation requests.
- Simplified code for beam history creation.
- Adjusted test cases to reflect changes in beam width handling and improved logprob validation.

Signed-off-by: Stefan Niebler <[email protected]>
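A hedged sketch of the separate input/output beam-width idea that unifies the context and generation paths. The function name and signature are hypothetical, not the actual API introduced by the commit:

```python
def beam_widths(beam_width: int, is_context_request: bool) -> tuple[int, int]:
    # A context (prefill) request enters sampling with a single beam,
    # while a generation request already carries beam_width live beams;
    # both paths emit beam_width beams, which lets one sampling routine
    # serve both kinds of request.
    beam_width_in = 1 if is_context_request else beam_width
    beam_width_out = beam_width
    return beam_width_in, beam_width_out
```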
… merge
- Modified `model_kwargs` to include `sampler_type` for improved test configuration.
- Adjusted the `llm_cuda_graph` fixture to remove the unnecessary `sampler_type` parameter.
- Enhanced clarity in `test_torch_sampler.py` by adding comments regarding the `is_context_init_state` attribute.

Signed-off-by: Stefan Niebler <[email protected]>
Force-pushed from 326305d to 224da51.
/bot run
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
- PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- PR follows the TRT-LLM CODING GUIDELINES to the best of your knowledge.
- Test cases are provided for new code paths (see test instructions).
- Any new dependencies have been scanned for licenses and vulnerabilities.
- CODEOWNERS updated if ownership changes.
- Documentation updated as needed.
- The reviewers assigned automatically/manually are appropriate for the PR.

Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server. Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run

/bot run [--reuse-test (optional)pipeline-id] [--disable-fail-fast] [--skip-test] [--stage-list "A10-PyTorch-1, xxx"] [--gpu-type "A30, H100_PCIe"] [--test-backend "pytorch, cpp"] [--add-multi-gpu-test] [--only-multi-gpu-test] [--disable-multi-gpu-test] [--post-merge] [--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"] [--detailed-log] [--debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail-fast on build/test/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages, and sanity check stages. Note: does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

/bot kill: Kill all running builds associated with the pull request.

skip

/bot skip --comment COMMENT: Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause top of tree to break.

reuse-pipeline

/bot reuse-pipeline: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause top of tree to break.