Merged
Changes from 7 commits
2 changes: 1 addition & 1 deletion examples/inference/gpt/gpt_dynamic_inference.py
@@ -160,7 +160,7 @@ def get_inference_context(
else None
),
block_size_tokens=args.inference_dynamic_batching_block_size,
active_buffer_size_gb=args.inference_dynamic_batching_active_buffer_size_gb,
buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
max_tokens=args.inference_dynamic_batching_max_tokens,
tensor_model_parallel_size=args.tensor_model_parallel_size,
materialize_only_last_token_logits=not args.return_log_probs,
4 changes: 2 additions & 2 deletions examples/inference/gpt/gpt_dynamic_inference_12b.sh
@@ -23,7 +23,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
: ${INCOMING_REQUESTS_PER_SEC=100.}

# Dynamic context.
: ${ACTIVE_BUFFER_SIZE_GB=50.}
: ${BUFFER_SIZE_GB=50.}

# Cuda graphs.
: ${CUDA_GRAPH_IMPL=local}
@@ -76,7 +76,7 @@ ARGS=" \
--inference-rng-tracker \
\
--inference-dynamic-batching \
--inference-dynamic-batching-active-buffer-size-gb ${ACTIVE_BUFFER_SIZE_GB} \
--inference-dynamic-batching-buffer-size-gb ${BUFFER_SIZE_GB} \
\
${EXTRA_ARGS} \
"
4 changes: 2 additions & 2 deletions examples/inference/gpt/gpt_dynamic_inference_357m.sh
@@ -24,7 +24,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
: ${INCOMING_REQUESTS_PER_SEC=100.}

# Dynamic context.
: ${ACTIVE_BUFFER_SIZE_GB=50.}
: ${BUFFER_SIZE_GB=50.}

# Cuda graphs.
: ${CUDA_GRAPH_IMPL=local}
@@ -62,7 +62,7 @@ ARGS=" \
--inference-rng-tracker \
\
--inference-dynamic-batching \
--inference-dynamic-batching-active-buffer-size-gb ${ACTIVE_BUFFER_SIZE_GB} \
--inference-dynamic-batching-buffer-size-gb ${BUFFER_SIZE_GB} \
\
${EXTRA_ARGS} \
"
2 changes: 1 addition & 1 deletion examples/inference/gpt/utils.py
@@ -398,7 +398,7 @@ def build_dynamic_engine_setup_prefix(

# Buffer limits config
buffer_limits_str = (
f"bf: {get_mem_size_str(args.inference_dynamic_batching_active_buffer_size_gb*1024**3)}, "
f"bf: {get_mem_size_str(args.inference_dynamic_batching_buffer_size_gb*1024**3)}, "
f"{context.block_allocator.active_count} chunks "
f"[r {context.max_active_requests}, t {context.max_tokens}]"
)
4 changes: 2 additions & 2 deletions megatron/core/inference/contexts/dynamic_block_allocator.py
@@ -21,11 +21,11 @@ class BlockAllocator:
space for paused requests that live on the CPU.
"""

def __init__(self, context: "DynamicInferenceContext", active_count: int):
def __init__(self, context: "DynamicInferenceContext", total_count: int):

self.context = context

active_count -= 1 # -1 for dummy_block_idx (see below)
active_count = (total_count - 1) // 2 # -1 for dummy_block_idx (see below)
active_count = max(1, active_count) # need at least one block
self.total_count = 2 * active_count + 1 # +1 for dummy_block_idx
Contributor
Here self.total_count may not equal the input total_count.

Contributor Author
Yep, that's correct. total_count needs to be rounded down to an odd number so that we have one extra block index for dummy_block_idx. That was the case prior to this PR as well when considering active_count, which wasn't necessarily the same as its input value.

self.total_avail = self.total_count - 1 # -1 for dummy_block_idx
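
As the exchange above notes, the stored total_count is rounded down to an odd number. A minimal sketch of just this count bookkeeping (illustrative only; the real `BlockAllocator` also takes the context and manages the block pool itself):

```python
# Minimal sketch of the count bookkeeping shown in this hunk (illustrative only).
def split_block_counts(total_count: int) -> tuple[int, int]:
    active_count = (total_count - 1) // 2  # -1 for dummy_block_idx
    active_count = max(1, active_count)    # need at least one block
    stored_total = 2 * active_count + 1    # +1 for dummy_block_idx; always odd
    return active_count, stored_total

# The stored total equals the input only when the input is already odd (and >= 3):
print(split_block_counts(11))    # (5, 11)
print(split_block_counts(2560))  # (1279, 2559) -- rounded down to an odd count
```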
76 changes: 34 additions & 42 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -195,37 +195,23 @@ class DynamicInferenceContext(BaseInferenceContext):
arbitrary sequence length may be added, paused, or removed from the context
at any step. The only constraint is the maximum number of requests or tokens
that the context is defined to support. For the block-level KV cache, a memory
buffer is allocated up front (size `2 * active_buffer_size_gb`), that is
buffer is allocated up front (size `buffer_size_gb` if `unified_memory_level`
== 0, or `2 * buffer_size_gb` if `unified_memory_level` == 1), that is
divided into blocks and dynamically assigned to requests. At any given step,
any unassigned blocks equate to unused space.

Additionally, a fraction of the memory buffer (`gtd_request_fraction`, i.e.,
the 'guaranteed' request fraction) is reserved for guaranteeing that a
minimum number of active requests may continue to generate tokens on any step.
The reason for this is that the context manages two pools of requests: 1)
active requests, and 2) paused requests. Paused requests are requests where
insufficient memory blocks remain for future assignment, and these requests
are set aside until enough memory blocks are available. Active requests are
requests that have sufficient memory blocks to proceed with their generations.

The situation can arise where all requests eventually become paused due to all
memory blocks being assigned. In this case, there are no active requests and
thus no progress can be made. To handle this case, a fraction of the memory
buffer is reserved that only allows active requests, and no paused requests.
This fraction must be carefully tuned, as it can have an order of magnitude
impact on overall latency.

Args:
params_dtype (torch.dtype): Dtype used for KV cache.
num_layers (int): Number of layers.
kv_channels (int): Hidden dimension per attention head.
num_attention_heads (int): Number of attention heads.
max_sequence_length (int): Max possible sequence length (prompt + output)
that will occur.
active_buffer_size_gb (float): Buffer size reserved for active requests
that live on the GPU. The total buffer size (stored in unified memory)
is 2x this value, with the other half of the buffer reserved for
paused requests that live on the CPU.
buffer_size_gb (float): Buffer size reserved on the GPU for the KV cache.
If `unified_memory_level` >= 1, then CPU memory is additionally
utilized, resulting in a total buffer size of `2 * buffer_size_gb`.
Regardless of total buffer size, the KV cache is conceptually divided
into 50% active requests and 50% paused requests.
max_tokens (int): Max number of tokens to use for forward passes. This is
primarily limited by prefill activation memory usage. (Defaults to
16384).
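
A small sketch of the sizing rule described in the updated docstring, i.e. how `buffer_size_gb` and `unified_memory_level` determine the total KV-cache footprint (the function name is illustrative, not part of the API):

```python
# Illustrative only: total KV-cache memory implied by the docstring above.
def total_kv_cache_gb(buffer_size_gb: float, unified_memory_level: int) -> float:
    if unified_memory_level == 0:
        return buffer_size_gb      # GPU memory only
    return 2 * buffer_size_gb      # GPU half plus a same-size CPU half

print(total_kv_cache_gb(50.0, 0))  # 50.0  -> whole cache on the GPU
print(total_kv_cache_gb(50.0, 1))  # 100.0 -> 50 GB GPU + 50 GB CPU
```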
@@ -272,7 +258,7 @@ def __init__(
kv_channels: int,
num_attention_heads: int,
max_sequence_length: int,
active_buffer_size_gb: float,
buffer_size_gb: float,
max_tokens: int = DEFAULT_MAX_TOKENS,
block_size_tokens: int = 256,
tensor_model_parallel_size: Optional[int] = None,
@@ -371,14 +357,29 @@ def __init__(
mamba_states_memory_per_request *= self.num_mamba_layers
mamba_states_memory_per_request *= dtype_size_bytes

# Unified memory.
self.unified_memory_level = unified_memory_level
if unified_memory_level > 0:
try:
self.unified_memory_mempool = create_unified_mempool()
except UnifiedMemoryUnsupportedError:
if torch.distributed.get_rank() == 0:
warnings.warn(
"Unified memory requested but not available; defaulting to GPU memory."
)
self.unified_memory_level = 0

# Initialize block allocator.
active_buffer_size_bytes = int(active_buffer_size_gb * 1024**3)
active_block_count_total = active_buffer_size_bytes // (
buffer_size_bytes = int(buffer_size_gb * 1024**3)
block_count_total = buffer_size_bytes // (
self.block_size_bytes + mamba_states_memory_per_request
)
self.block_allocator = BlockAllocator(context=self, active_count=active_block_count_total)
del active_block_count_total # use self.block_allocator.active_count
active_buffer_size_bytes = self.block_allocator.active_count * self.block_size_bytes
self.block_allocator = BlockAllocator(
context=self,
total_count=(
block_count_total if self.unified_memory_level == 0 else 2 * block_count_total
),
)

# Set max_total_requests, max_active_requests, max_tokens.
self.max_total_requests = self.block_allocator.total_count - 1 # -1 for dummy block
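
For intuition, a worked version of the arithmetic in this hunk, with per-block byte sizes chosen purely for illustration (in the real code they are derived from the model configuration):

```python
# Worked example of the buffer -> block-count math above; byte sizes are
# arbitrary illustrative values, not derived from a real model config.
buffer_size_gb = 40.0
block_size_bytes = 32 * 1024**2          # assume 32 MiB of KV cache per block
mamba_states_memory_per_request = 0      # assume a pure-attention model
unified_memory_level = 1                 # GPU buffer plus a same-size CPU half

buffer_size_bytes = int(buffer_size_gb * 1024**3)
block_count_total = buffer_size_bytes // (
    block_size_bytes + mamba_states_memory_per_request
)                                        # 40 GiB / 32 MiB = 1280 GPU blocks

total_count = (
    block_count_total if unified_memory_level == 0 else 2 * block_count_total
)                                        # handed to BlockAllocator: 2560 blocks
# The allocator then rounds this down to an odd count (2559) and reserves one
# index as the dummy block, so max_total_requests works out to 2558 here.
print(block_count_total, total_count)    # 1280 2560
```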
@@ -395,18 +396,6 @@ def __init__(
self.params_dtype = params_dtype
self.max_sequence_length = max_sequence_length

# Unified memory.
self.unified_memory_level = unified_memory_level
if unified_memory_level > 0:
try:
self.unified_memory_mempool = create_unified_mempool()
except UnifiedMemoryUnsupportedError:
if torch.distributed.get_rank() == 0:
warnings.warn(
"Unified memory requested but not available; defaulting to GPU memory."
)
self.unified_memory_level = 0

# Request and token counts.
self.total_request_count = 0
self.active_token_count = 0
@@ -606,7 +595,10 @@ def allocate_mamba_states():
# Print info.
logging.info(
"DynamicInferenceContext: allocated context with active buffer size %s (%d blocks)."
% (get_mem_size_str(active_buffer_size_bytes), self.block_allocator.active_count)
% (
get_mem_size_str(self.block_allocator.active_count * self.block_size_bytes),
self.block_allocator.active_count,
)
)

@classmethod
@@ -636,7 +628,7 @@ def from_config(
inference_config: InferenceWrapperConfig,
model,
max_batch_size: int,
active_buffer_size_gb: float = 40,
buffer_size_gb: float = 40,
num_cuda_graphs: int = None,
):
"""
@@ -655,7 +647,7 @@
kv_channels=model_config.kv_channels,
num_attention_heads=model_config.num_query_groups,
max_sequence_length=inference_config.inference_max_seq_length,
active_buffer_size_gb=active_buffer_size_gb,
buffer_size_gb=buffer_size_gb,
materialize_only_last_token_logits=False,
num_cuda_graphs=num_cuda_graphs,
use_flashinfer_fused_rope=None,
2 changes: 1 addition & 1 deletion megatron/core/inference/engines/static_engine.py
@@ -99,7 +99,7 @@ def __init__(
inference_config=inference_wrapper_config,
model=text_generation_controller.inference_wrapped_model.model,
max_batch_size=max_batch_size,
active_buffer_size_gb=buffer_size_gb,
buffer_size_gb=buffer_size_gb,
num_cuda_graphs=1,
)
self.controller.inference_wrapped_model.inference_context = dynamic_context
17 changes: 11 additions & 6 deletions megatron/training/arguments.py
@@ -1136,7 +1136,7 @@ def validate_args(args, defaults={}):
), "Pipeline-parallel microbatched inference is incompatible with CUDA graphs"

if args.inference_dynamic_batching:
assert args.inference_dynamic_batching_active_buffer_size_gb is not None
assert args.inference_dynamic_batching_buffer_size_gb is not None
assert args.inference_dynamic_batching_block_size % 256 == 0, "block size should be a multiple of 256"

# MoE upcycling check
@@ -1442,12 +1442,17 @@ def _add_inference_args(parser):
group.add_argument('--inference-dynamic-batching',
action='store_true', default=False,
help='Enable dynamic batching mode.')
group.add_argument('--inference-dynamic-batching-active-buffer-size-gb',
group.add_argument('--inference-dynamic-batching-buffer-size-gb',
type=float, default=40.,
help='Buffer size (GB) allocated for the active (on-GPU) '
'portion of the chunked KV memory. The total buffer size '
'is 2x this value, which includes the same-size on-CPU '
'paused buffer.')
help='Amount of on-GPU memory allocated for the KV cache. '
'The total amount of memory allocated for the KV cache '
'(CPU + GPU memory) depends on the value set for the '
'unified virtual memory (UVM) level (via '
'`--inference-dynamic-batching-unified-memory-level`). '
'If the UVM level is 0, then only GPU memory is used and '
'the total memory equals `buffer_size_gb`. If the UVM '
'level is 1, then additional memory is utilized on the '
'CPU and the total memory equals `2 * buffer_size_gb`.')
group.add_argument('--inference-dynamic-batching-block-size',
type=int, default=256,
help='KV cache block size. '
@@ -46,7 +46,7 @@ MODEL_ARGS:
--return-log-probs: true
--num-tokens-to-generate: 30
--enable-cuda-graph: true
--inference-dynamic-batching-active-buffer-size-gb: 20
--inference-dynamic-batching-buffer-size-gb: 20
--dist-ckpt-strictness: log_unexpected
--inference-ckpt-non-strict: true # To handle the extra_state errors
--output-path: ${TENSORBOARD_PATH}
@@ -47,7 +47,7 @@ MODEL_ARGS:
--num-tokens-to-generate: 30
--enable-cuda-graph: true
--decode-only-cuda-graphs: true
--inference-dynamic-batching-active-buffer-size-gb: 20
--inference-dynamic-batching-buffer-size-gb: 20
--dist-ckpt-strictness: log_unexpected
--inference-ckpt-non-strict: true # To handle the extra_state errors
--output-path: ${TENSORBOARD_PATH}
@@ -41,7 +41,7 @@ MODEL_ARGS:
--top_k: 1
--return-log-probs: true
--num-tokens-to-generate: 30
--inference-dynamic-batching-active-buffer-size-gb: 20
--inference-dynamic-batching-buffer-size-gb: 20
--dist-ckpt-strictness: log_unexpected
--inference-ckpt-non-strict: true # To handle the extra_state errors
--output-path: ${TENSORBOARD_PATH}
@@ -41,7 +41,7 @@ MODEL_ARGS:
--top_k: 1
--return-log-probs: true
--num-tokens-to-generate: 30
--inference-dynamic-batching-active-buffer-size-gb: 10
--inference-dynamic-batching-buffer-size-gb: 10
--dist-ckpt-strictness: log_unexpected
--inference-ckpt-non-strict: true # To handle the extra_state errors
--output-path: ${TENSORBOARD_PATH}
@@ -80,7 +80,7 @@ MODEL_ARGS:
--prompts: "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies."
--incoming-requests-per-sec: -1
--inference-repeat-n: 8
--inference-dynamic-batching-active-buffer-size-gb: 20
--inference-dynamic-batching-buffer-size-gb: 20
METRICS:
- "generated_tokens"
- "logprobs"
@@ -76,7 +76,7 @@ MODEL_ARGS:
--prompts: "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies."
--incoming-requests-per-sec: -1 # all requests arrive up front.
--inference-repeat-n: 8
--inference-dynamic-batching-active-buffer-size-gb: 20
--inference-dynamic-batching-buffer-size-gb: 20
METRICS:
- "generated_tokens"
- "logprobs"