diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index 6ca8a73f..e35b86d4 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -17,6 +17,8 @@
     Union,
 )
 
+import os
+
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_reduce
@@ -640,6 +642,7 @@ def stream_generate(
     prompt: Union[str, mx.array, List[int]],
     max_tokens: int = 256,
     draft_model: Optional[nn.Module] = None,
+    max_captured_steps: int = 5,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -691,6 +694,18 @@
         token_generator = speculative_generate_step(
             prompt, model, draft_model, **kwargs
         )
+
+    # capture a Metal GPU profile when requested via environment variables
+    do_profile = os.environ.get("MTL_CAPTURE_ENABLED", "0") == "1"
+    include_prefill_stage = os.environ.get("MLX_PROFILE_PREFILL", "0") == "1"
+    if do_profile:
+        max_captured_steps = int(os.environ.get("MLX_MAX_CAPTURED_STEPS", max_captured_steps))
+        mlx_trace_file = os.environ.get("MLX_TRACE_FILE", "mlx_trace.gputrace")
+
+        # NOTE (yiakwy): profiling the prefill stage is very expensive for MLX
+        if include_prefill_stage:
+            mx.metal.start_capture(mlx_trace_file)
+
     with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         for n, (token, logprobs, from_draft) in enumerate(token_generator):
@@ -698,6 +713,14 @@
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
                 tic = time.perf_counter()
+
+                # start the capture at the first decode step when prefill is excluded
+                if do_profile and not include_prefill_stage:
+                    mx.metal.start_capture(mlx_trace_file)
+
+            if do_profile and n >= max_captured_steps:
+                mx.metal.stop_capture()
+                max_captured_steps = float("inf")  # only stop once
+
             if token in tokenizer.eos_token_ids:
                 break
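
Usage sketch (not part of the patch): with these changes a Metal GPU capture can be driven entirely by environment variables. The sketch below assumes the public mlx_lm load/stream_generate API and a hypothetical model id; the environment variable names come from the patch above, and they would normally be exported in the shell before the process starts so Metal sees MTL_CAPTURE_ENABLED.

import os

# MTL_CAPTURE_ENABLED is required by Metal for GPU captures and doubles as
# the on/off switch for the profiling path added in this patch.
os.environ["MTL_CAPTURE_ENABLED"] = "1"
os.environ["MLX_PROFILE_PREFILL"] = "0"         # skip the expensive prefill capture
os.environ["MLX_MAX_CAPTURED_STEPS"] = "5"      # stop capturing after 5 decode steps
os.environ["MLX_TRACE_FILE"] = "decode.gputrace"

from mlx_lm import load, stream_generate

# hypothetical model id, used only for illustration
model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
for response in stream_generate(model, tokenizer, "Hello", max_tokens=16):
    print(response.text, end="", flush=True)

The resulting .gputrace file can then be opened in Xcode's Metal debugger to inspect the captured decode steps.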