Open
Labels: enhancement (New feature or request)
Description
This RFC proposes the following plan for optimizing the dynamic inference step function.
There are several interconnected issues at play:
- The dynamic sampling code is currently very unoptimized. There is a draft PR that reimplements it.
- `async_generate_output_tokens_dynamic_batch` mixes CPU and GPU operations indiscriminately.
- `async_generate_output_tokens_dynamic_batch` may be declared `async`, but it has no good way of yielding the event loop. A lot of CPU time is wasted waiting for the GPU and can be reclaimed.
The ideal solution appears to be:
- Fix dynamic sampling code.
- Clearly separate CPU and GPU operations.
- Provide a place to yield the event loop.
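The shape of that solution can be sketched in pure Python (a minimal illustration, not the actual engine code: `fake_gpu_kernel` and the `threading.Event` stand in for launched CUDA kernels and `torch.cuda.Event.query()` polling):

```python
import asyncio
import threading
import time

def fake_gpu_kernel(done_event):
    # Stand-in for asynchronously launched GPU work (hypothetical).
    time.sleep(0.01)  # pretend the device is busy
    done_event.set()

async def step():
    # 1) Launch "GPU" work without blocking the host.
    done = threading.Event()
    threading.Thread(target=fake_gpu_kernel, args=(done,)).start()

    # 2) Do the CPU-side bookkeeping while the device runs.
    cpu_result = sum(range(1000))

    # 3) Yield the event loop until the device finishes
    #    (real code would poll torch.cuda.Event.query() instead).
    while not done.is_set():
        await asyncio.sleep(0)

    return cpu_result

result = asyncio.run(step())
print(result)  # 499500
```

The key property is that other coroutines can make progress during step 3, instead of the host spinning uselessly while it waits for the GPU.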
The PR series suggested by this RFC is:
- Break `async_generate_output_tokens_dynamic_batch` apart into multiple sub-methods, which are clearly labeled as "CPU compute" vs "GPU compute".
  - Achieved by Clean up dynamic inference step #1992.
- Implement barebones unoptimized dynamic sampling code.
- Tensorize the dynamic sampling bookkeeping.
  - Achieved by Tensorize dynamic inference mixed sampling #2105.
- Reorganize the step function to allow for an async step.
  - Achieved by Break apart dynamic inference step into 2 methods #2192. No new functionality, just moving code around.
- Reorder the sub-methods from point 1) so that CPU and GPU compute form separate contiguous blocks of code, and yield the event loop after the CPU compute via torch polling.
  - Due to all the prep work, this will be a tiny, extremely readable PR.
  - Achieved by Make the dynamic engine step async #2193.
- Optimize dynamic sampling code via graphed FlashInfer sampling.
- Refactor dynamic logprobs computation to follow the same style as the new sampling code.
  - A draft has been written by @tdene.
- Reconcile with main's implementation of `top_n_logprobs`.
- Wait for a torch update, or brainstorm a way to yield the event loop without polling in the current version of PyTorch.
  - Maybe by sampling on a single rank, instead of the current sampling on every rank?
- Will discuss further in comments.
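For step 2, "barebones unoptimized" sampling might look roughly like the following (an illustrative pure-Python sketch, not the actual implementation; `sample_token` and its parameters are hypothetical names):

```python
import math
import random

def sample_token(logits, temperature=1.0, top_k=0, rng=random):
    # Barebones per-request sampling: temperature scaling, optional
    # top-k filtering, then a multinomial draw. Deliberately unoptimized:
    # pure Python, one request at a time.
    if temperature == 0.0:  # greedy decoding
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    if top_k > 0:
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= cutoff else float("-inf") for s in scaled]
    m = max(scaled)
    probs = [math.exp(s - m) for s in scaled]
    total = sum(probs)
    probs = [p / total for p in probs]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

# One Python-level call per request in the batch -- exactly the kind of
# loop that later tensorization removes.
batch_logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]
tokens = [sample_token(l, temperature=0.0) for l in batch_logits]
print(tokens)  # greedy picks: [1, 0]
```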
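The "tensorize the bookkeeping" step (step 3) amounts to grouping requests by sampling configuration so each group can be handled with one batched tensor op instead of a Python-level loop. A minimal sketch of that bucketing, with hypothetical per-request fields:

```python
from collections import defaultdict

# Hypothetical per-request sampling params (temperature == 0 means greedy).
requests = [
    {"id": 0, "temperature": 0.0, "top_k": 0},
    {"id": 1, "temperature": 0.7, "top_k": 50},
    {"id": 2, "temperature": 0.0, "top_k": 0},
    {"id": 3, "temperature": 0.7, "top_k": 50},
]

# Bucket request indices by sampling configuration; each bucket can then
# index into the batched logits tensor and be sampled in one shot.
buckets = defaultdict(list)
for i, req in enumerate(requests):
    buckets[(req["temperature"], req["top_k"])].append(i)

print(dict(buckets))  # {(0.0, 0): [0, 2], (0.7, 50): [1, 3]}
```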
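The single-rank-sampling idea could take roughly this shape (a hypothetical simulation without `torch.distributed`; in real code the hand-off would be a `torch.distributed.broadcast` from rank 0):

```python
import random

WORLD_SIZE = 4  # simulated number of ranks

def sample_on_single_rank(logits, seed=1234):
    # Rank 0 does the sampling once...
    rng = random.Random(seed)
    token = rng.choices(range(len(logits)), weights=[1] * len(logits), k=1)[0]
    # ...and the result is "broadcast" so every rank holds the same token,
    # instead of every rank redundantly running the same sampling kernel.
    return [token for _rank in range(WORLD_SIZE)]

tokens = sample_on_single_rank([0.1, 0.2, 0.3])
assert len(set(tokens)) == 1  # all ranks agree
```

The trade-off is a broadcast per step versus duplicated sampling work (and RNG-state synchronization) on every rank; which wins would need to be measured.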