diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index 1a5f718fb0..4396047896 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -125,25 +125,50 @@ def __init__(
         ), f"context must be a DynamicInferenceContext, got {type(context)}"
         assert isinstance(random_seed, int), f"random_seed must be an int, got {type(random_seed)}"
 
-        self.request_counter = Counter()
-        self.controller = controller
-        self.context = context
+        # Initialization options.
         self.random_seed = random_seed
         self.track_paused_request_events = track_paused_request_events
+        self.enable_chunked_prefill = enable_chunked_prefill
+
+        self.inference_logging_step_interval = inference_logging_step_interval
+
+        if enable_cuda_graph is not None:
+            self.cuda_graph_impl = "local" if enable_cuda_graph else "none"
+        else:
+            self.cuda_graph_impl = controller.inference_wrapped_model.model.config.cuda_graph_impl
+
+        # Objects which sit on a lower level of the abstraction stack.
+        self.controller = controller
+        self.context = context
+
+        # Runtime state.
+        self.paused = False
+        self.stopped = False
+        self._loop = get_asyncio_loop()
+        self._cond = asyncio.Condition()
+
+        # Coordinator state.
+        self.use_coordinator = False
+        self.is_tp0_and_pp0 = (
+            parallel_state.get_tensor_model_parallel_rank() == 0
+            and parallel_state.get_pipeline_model_parallel_rank() == 0
+        )
+
+        # Request state.
+        self.request_counter = Counter()
         self.step_count = 0
         self.finished_request_count = 0
+
+        self.requests: Dict[int, DynamicInferenceRequest] = {}
         self.waiting_request_ids = deque()
         self.failed_request_ids = []  # deque()
-        self.request_counter = Counter()
-        self.requests: Dict[int, DynamicInferenceRequest] = {}
         self.request_completion_futures: Dict[int, asyncio.Future] = {}
+
+        # Timing and logging variables.
         self.step_start_event = torch.cuda.Event(enable_timing=True)
         self.step_end_event = torch.cuda.Event(enable_timing=True)
-        self.paused = False
-        self.stopped = False
-        self.enable_chunked_prefill = enable_chunked_prefill
+        self.capture_stats = None
 
-        self.inference_logging_step_interval = inference_logging_step_interval
         # Configure wandb to use separate step counter for inference metrics (only once)
         if self.inference_logging_step_interval > 0 and self.context.metrics_writer is not None:
             logging.info(
@@ -169,19 +194,7 @@ def __init__(
                     max_step = int(val)
             self.inference_step_offset = int(max_step)
 
-        # Initialize the asyncio loop if it has not already been initialized.
-        # TODO: Start the engine loop here.
-        self._loop = get_asyncio_loop()
-        self._cond = asyncio.Condition()
 
-        # Capture cuda graph.
-        self.capture_stats = None
-
-        if enable_cuda_graph is not None:
-            self.cuda_graph_impl = "local" if enable_cuda_graph else "none"
-        else:
-            self.cuda_graph_impl = controller.inference_wrapped_model.model.config.cuda_graph_impl
-
         if self.cuda_graph_impl == "local":
             self.create_cuda_graphs()
 
@@ -476,6 +489,8 @@ def _add_request(
 
         if request.status != Status.FAILED:
            self.waiting_request_ids.append(request_id)
+        else:
+            self.failed_request_ids.append(request_id)
 
         # Create a new asyncio Future to notify the user when the request has completed.
         self.request_completion_futures[request_id] = self._loop.create_future()
@@ -497,7 +512,6 @@ def add_request(
         Return:
             Returns an asyncio `Future[DynamicInferenceRequest]` for the user to wait on.
         """
-        prompt_str = None
 
         # Tokenize prompt if text.
         if isinstance(prompt, str):
@@ -720,54 +734,103 @@ def schedule_chunked_prefill(self):
                 # chunked prefill request at the head of the waiting queue
                 # Note that we do not need to continue check the queue, as the tokens are full
 
-    async def async_step(
-        self, *, verbose: Optional[bool] = False
-    ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest], float]:
-        """
-        Wrapper for controller.generate_output_tokens_dynamic_batch(), to
-        match vLLM API. Uses `asyncio` for continuous generation which allows this
-        method to sleep and wake up when new requests are available.
-
-        Args:
-            sampling_params (SamplingParams): The sampling parameters.
-            verbose (bool): Whether to run in verbose mode.
+    async def async_forward(self) -> Tuple[Dict, Dict, float, int]:
+        """Uses `asyncio` for continuous generation.
+        Sleeps when no requests are available, until new requests have been added.
 
         Returns:
             A tuple comprised of:
-                1. Requests that ran in the last step and are still active.
-                2. Requests that ran in the last step and have now finished.
-                3. The step time in seconds.
+                step_result (Optional[Dict]): The result of the step.
+                context_state (Dict): The state of the context:
+                    is_decode_only, total/paused request count, active token count.
+                step_time (float): How long this step took.
         """
         # schedule requests
         self.schedule_waiting_requests()
 
-        # Previous context state, for printing output below.
-        prev_is_decode_only = self.context.is_decode_only()
-        prev_total_request_count = self.context.total_request_count
-        prev_paused_request_count = self.context.paused_request_count
-        prev_active_token_count = self.context.active_token_count
-
-        range_push("Prefill" if not prev_is_decode_only else "Decode")
+        # Save pre-step state, for printing output below.
+        is_decode_only = self.context.is_decode_only()
+        pre_step_context_state = {
+            "is_decode_only": is_decode_only,
+            "total_request_count": self.context.total_request_count,
+            "paused_request_count": self.context.paused_request_count,
+            "active_token_count": self.context.active_token_count,
+        }
 
         # Generate tokens.
-        is_decode_only = self.context.is_decode_only()
-        # save the is_decode_only AFTER scheduling, BEFORE update
+        range_push("Prefill" if not is_decode_only else "Decode")
+        # TODO @tde: Remember to account for this line when overlapping forward and bookkeep.
         self.is_decode_only = is_decode_only
+
         self.step_start_event.record()
         result = await self.controller.async_generate_output_tokens_dynamic_batch()
         self.step_end_event.record()
         self.step_end_event.synchronize()
         step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3
+        self.step_count += 1
+
+        range_pop()
+
+        if (
+            self.inference_logging_step_interval > 0
+            and self.step_count > 0
+            and self.step_count % self.inference_logging_step_interval == 0
+            and self.context.metrics_writer is not None
+        ):
+            kvcache_util_stats = self.context.get_kvcache_utilization_stats()
+        else:
+            kvcache_util_stats = None
+
+        post_step_context_state = {
+            "waiting_request_count": len(self.waiting_request_ids),
+            "finished_request_count": self.finished_request_count,
+            "kv_stats": kvcache_util_stats,
+            "padded_active_token_count": self.context.padded_active_token_count,
+            "using_cuda_graph_this_step": self.context.using_cuda_graph_this_step(),
+            "total_active_block_count": self.context.block_allocator.active_count,
+            "total_paused_block_count": self.context.block_allocator.paused_count,
+            "total_active_used_blocks": self.context.block_allocator.get_active_used(),
+            "total_paused_used_blocks": self.context.block_allocator.get_paused_used(),
+        }
+
+        context_state = {**pre_step_context_state, **post_step_context_state}
+
+        return result, context_state, step_time, self.step_count
 
+    async def async_bookkeep(
+        self,
+        step_result: Optional[Dict],
+        context_state: Dict,
+        step_time: float,
+        step_count: int,
+        *,
+        verbose: bool = False,
+    ):
+        """Uses `asyncio` for continuous bookkeeping.
+
+        Args:
+            step_result (Optional[Dict]): The result of the step.
+            context_state (Dict): is_decode_only, total/paused request count, active token count.
+            step_time (float): How long this step took.
+            step_count (int): The count of the step.
+            verbose (bool): Whether to run in verbose mode.
+
+        Returns:
+            A dictionary containing:
+                active_requests (List): Requests that ran in the last step and are still active.
+                finished_requests (List): Requests that ran in the last step and have now finished.
+                step_time (float): The step time in seconds.
+                cuda_graph_request_count (int): The CUDA graph batch size matching this step.
+        """
         # Increment finished_request_count.
         cuda_graph_request_count = None
-        if result is not None:
-            active_request_ids = result["active_request_ids"]
-            newly_paused_request_ids = result["newly_paused_request_ids"]
-            finished_request_ids = result["finished_request_ids"]
-            sample = result["sample"]
-            log_probs = result["log_probs"]
-            cuda_graph_request_count = result["cuda_graph_request_count"]
+        if step_result is not None:
+            active_request_ids = step_result["active_request_ids"]
+            newly_paused_request_ids = step_result["newly_paused_request_ids"]
+            finished_request_ids = step_result["finished_request_ids"]
+            sample = step_result["sample"]
+            log_probs = step_result["log_probs"]
+            cuda_graph_request_count = step_result["cuda_graph_request_count"]
 
             # Add paused events.
             if newly_paused_request_ids is not None and self.track_paused_request_events:
@@ -795,28 +858,27 @@ async def async_step(
                 self.request_completion_futures[failed_request_id].set_result(failed_request)
             self.failed_request_ids.clear()
 
-        # Log KV cache utilization stats to W&B
-        if (
-            self.inference_logging_step_interval > 0
-            and self.step_count > 0
-            and self.step_count % self.inference_logging_step_interval == 0
-            and self.context.metrics_writer is not None
-        ):
-
-            # Get KV cache utilization stats from dynamic context
-            kv_stats = self.context.get_kvcache_utilization_stats()
+        # Handle necessary ZMQ DP coordinator communication.
+        if self.use_coordinator and self.is_tp0_and_pp0 and finished_requests:
+            payload = msgpack.packb(
+                [Headers.ENGINE_REPLY.value, [r.serializable() for r in finished_requests]],
+                use_bin_type=True,
+            )
+            self.socket_for_receiving_requests.send(payload)
+        # Log KV cache utilization stats to W&B
+        if context_state["kv_stats"] is not None:
             # Prepare metrics dictionary with all stats
             # Use 'inference/' prefix for all metrics to separate from training metrics
             metrics = {
-                'inference/inference_step': int(self.inference_step_offset + int(self.step_count)),
+                'inference/inference_step': int(self.inference_step_offset + int(step_count)),
                 'inference/step_time_s': float(step_time),
                 'inference/waiting_queue_len': int(len(self.waiting_request_ids)),
                 'inference/total_requests_dict_size': int(len(self.requests)),
             }
 
             # Add KV stats with inference/ prefix
             # Convert utilization metrics from 0-1 range to 0-100 percentage range for better visualization
-            for key, value in kv_stats.items():
+            for key, value in context_state["kv_stats"].items():
                 if 'utilization' in key:
                     # Convert to percentage (0-100) and group under kvcache_utilization
                     metrics[f'inference/{key}'] = float(value * 100.0)
@@ -832,16 +894,15 @@ async def async_step(
 
         # Print context state.
         if verbose:
-            context = self.context
             mem = torch.cuda.memory_stats()
-            step_type = "decode" if is_decode_only else "non-decode"
+            step_type = "decode" if context_state["is_decode_only"] else "non-decode"
            output_str = (
                 "* step %d | %s ... time: %.3f%s ... "
                 "reqs: a %d/%d, p %d/%d, w %d, f %d ... "
                 "blocks: a %d/%d, p %d/%d ... "
                 "mem: tensors %d, alloc %.1f gb, res %.1f gb."
                 % (
-                    self.step_count,
+                    step_count,
                     datetime.now().strftime("%H:%M:%S"),
                     step_time,
                     (
@@ -850,34 +911,34 @@ async def async_step(
                             step_type,
                             (
                                 "DIM %d:%d"
-                                % (context.padded_active_token_count, prev_active_token_count)
-                                if self.context.using_cuda_graph_this_step()
+                                % (
+                                    context_state["padded_active_token_count"],
+                                    context_state["active_token_count"],
+                                )
+                                if context_state["using_cuda_graph_this_step"]
                                 else "OFF"
                             ),
                         )
                     ),
-                    prev_total_request_count - prev_paused_request_count,
-                    context.block_allocator.active_count,
-                    prev_paused_request_count,
-                    context.block_allocator.paused_count,
-                    len(self.waiting_request_ids),
-                    self.finished_request_count,
-                    context.block_allocator.get_active_used(),
-                    context.block_allocator.active_count,
-                    context.block_allocator.get_paused_used(),
-                    context.block_allocator.paused_count,
+                    context_state["total_request_count"] - context_state["paused_request_count"],
+                    context_state["total_active_block_count"],
+                    context_state["paused_request_count"],
+                    context_state["total_paused_block_count"],
+                    context_state["waiting_request_count"],
+                    context_state["finished_request_count"],
+                    context_state["total_active_used_blocks"],
+                    context_state["total_active_block_count"],
+                    context_state["total_paused_used_blocks"],
+                    context_state["total_paused_block_count"],
                     mem["allocation.all.current"],
                     mem["allocated_bytes.all.current"] / (1024**3),
                     mem["reserved_bytes.all.current"] / (1024**3),
                 )
             )
-            if prev_is_decode_only:
+            if context_state["is_decode_only"]:
                 output_str = f"\033[94m{output_str}\033[0m"
             logging.info(output_str)
 
-        self.step_count += 1
-
-        range_pop()
         return {
             "active_requests": active_requests,
             "finished_requests": finished_requests,
@@ -885,14 +946,36 @@ async def async_step(
             "cuda_graph_request_count": cuda_graph_request_count,
         }
 
+    async def async_step(
+        self, *, verbose: bool = False
+    ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest], float]:
+        """
+        Wrapper for controller.generate_output_tokens_dynamic_batch(), to
+        match vLLM API. Uses `asyncio` for continuous generation which allows this
+        method to sleep and wake up when new requests are available.
+
+        Args:
+            verbose (bool): Whether to run in verbose mode.
+
+        Returns:
+            A tuple comprised of:
+                1. Requests that ran in the last step and are still active.
+                2. Requests that ran in the last step and have now finished.
+                3. The step time in seconds.
+        """
+        last_step_data = await self.async_forward()
+        ret = await self.async_bookkeep(*last_step_data, verbose=verbose)
+        # Keep for compatibility with current test suite.
+        return ret
+
     def step_modern(
-        self, *, verbose: Optional[bool] = False
+        self, *, verbose: bool = False
     ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest], float]:
         """Synchronous wrapper for `self.async_step`."""
         return self._loop.run_until_complete(self.async_step(verbose=verbose))
 
     def step_legacy(
-        self, sampling_params: SamplingParams, *, verbose: Optional[bool] = False
+        self, sampling_params: SamplingParams, *, verbose: bool = False
     ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest], float]:
         """Synchronous wrapper for `self.async_step`."""
         warnings.warn(
@@ -900,9 +983,7 @@ def step_legacy(
             "0.16. Please use `step_modern()` going forward, which will eventually "
             "be renamed to `step()`."
         )
-        result = self._loop.run_until_complete(
-            self.async_step(sampling_params=sampling_params, verbose=verbose)
-        )
+        result = self._loop.run_until_complete(self.async_step(verbose=verbose))
         return (result["active_requests"], result["finished_requests"], result["step_time"])
 
     # For backwards compatibility, point `step()` to `step_legacy()`. Starting in
@@ -1048,6 +1129,7 @@ async def run_engine_with_coordinator(
     ):
         """Continually steps the engine asynchronously."""
         self._loop = get_asyncio_loop(loop)
+        self.use_coordinator = True
         try:
             while True:
                 self.schedule_requests()
@@ -1078,25 +1160,7 @@ async def run_engine_with_coordinator(
                     await asyncio.sleep(0.02)
                     continue
 
-                engine_output = await self.async_step(verbose=verbose)
-
-                is_tp0_and_pp0 = (
-                    parallel_state.get_tensor_model_parallel_rank() == 0
-                    and parallel_state.get_pipeline_model_parallel_rank() == 0
-                )
-                if (
-                    is_tp0_and_pp0
-                    and engine_output is not None
-                    and engine_output["finished_requests"]
-                ):
-                    payload = msgpack.packb(
-                        [
-                            Headers.ENGINE_REPLY.value,
-                            [r.serializable() for r in engine_output["finished_requests"]],
-                        ],
-                        use_bin_type=True,
-                    )
-                    self.socket_for_receiving_requests.send(payload)
+                await self.async_step(verbose=verbose)
 
         except asyncio.CancelledError:
             pass
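Usage note (not part of the patch): the change above splits the old monolithic async_step() into async_forward() (scheduling plus token generation) and async_bookkeep() (event tracking, W&B metrics, coordinator replies, request completion), keeping async_step() as a thin wrapper over the pair. The sketch below is a minimal, hypothetical driver loop based only on the signatures visible in this diff; drive, engine, prompts, and sampling_params are illustrative names, and the exact argument list of add_request() is not shown here, so treat it as an assumption.

async def drive(engine, prompts, sampling_params):
    # Submit prompts; add_request() returns an asyncio.Future per request
    # (argument names here are assumed, not the exact signature).
    futures = [engine.add_request(prompt, sampling_params) for prompt in prompts]

    # Step until every request's future resolves. Forward and bookkeeping are
    # separate awaits, mirroring the new async_step() wrapper, which leaves
    # room to later overlap the next step's forward with this step's bookkeeping.
    while not all(future.done() for future in futures):
        step_result, context_state, step_time, step_count = await engine.async_forward()
        await engine.async_bookkeep(step_result, context_state, step_time, step_count)

    return [future.result() for future in futures]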