diff --git a/mlx_lm/chat.py b/mlx_lm/chat.py
index d22673f9..140efce6 100644
--- a/mlx_lm/chat.py
+++ b/mlx_lm/chat.py
@@ -7,7 +7,7 @@
 from .generate import stream_generate
 from .models.cache import make_prompt_cache
 from .sample_utils import make_sampler
-from .utils import load
+from .utils import does_model_support_prompt_cache, load

 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
@@ -16,6 +16,9 @@
 DEFAULT_SEED = None
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
+DEFAULT_BLOCK_LENGTH = 32
+DEFAULT_STEPS = 32
+DEFAULT_THRESHOLD = 0.95


 def setup_arg_parser():
@@ -79,6 +82,24 @@
         default=None,
         help="System prompt to be used for the chat template",
     )
+    parser.add_argument(
+        "--block-length",
+        type=int,
+        default=DEFAULT_BLOCK_LENGTH,
+        help="[Diffusion models only] Number of tokens per block",
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=DEFAULT_STEPS,
+        help="[Diffusion models only] Number of denoising iterations per block",
+    )
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        default=DEFAULT_THRESHOLD,
+        help="[Diffusion models only] Confidence threshold for token acceptance",
+    )
     return parser
@@ -97,36 +118,43 @@ def main():
         },
     )

+    use_cache = does_model_support_prompt_cache(model)
+
     def print_help():
         print("The command list:")
         print("- 'q' to exit")
         print("- 'r' to reset the chat")
         print("- 'h' to display these commands")

+    def reset_conversation():
+        """Reset conversation history and prompt cache."""
+        cache = make_prompt_cache(model, args.max_kv_size) if use_cache else None
+        msgs = []
+        if args.system_prompt is not None:
+            msgs.append({"role": "system", "content": args.system_prompt})
+        return cache, msgs
+
     print(f"[INFO] Starting chat session with {args.model}.")
     print_help()
-    prompt_cache = make_prompt_cache(model, args.max_kv_size)
+    prompt_cache, messages = reset_conversation()
+
     while True:
         query = input(">> ")
         if query == "q":
             break
         if query == "r":
-            prompt_cache = make_prompt_cache(model, args.max_kv_size)
+            prompt_cache, messages = reset_conversation()
             continue
         if query == "h":
             print_help()
             continue
-        messages = []
-        if args.system_prompt is not None:
-            messages.append({"role": "system", "content": args.system_prompt})
+
         messages.append({"role": "user", "content": query})
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-        for response in stream_generate(
-            model,
-            tokenizer,
-            prompt,
-            max_tokens=args.max_tokens,
-            sampler=make_sampler(
+
+        gen_kwargs = {
+            "max_tokens": args.max_tokens,
+            "sampler": make_sampler(
                 args.temp,
                 args.top_p,
                 xtc_threshold=args.xtc_threshold,
@@ -135,11 +163,20 @@ def print_help():
                     tokenizer.encode("\n") + list(tokenizer.eos_token_ids)
                 ),
             ),
-            prompt_cache=prompt_cache,
-        ):
+            "prompt_cache": prompt_cache,
+            "block_length": args.block_length,
+            "steps": args.steps,
+            "threshold": args.threshold,
+        }
+
+        assistant_response = ""
+        for response in stream_generate(model, tokenizer, prompt, **gen_kwargs):
             print(response.text, flush=True, end="")
+            assistant_response += response.text
         print()
+        messages.append({"role": "assistant", "content": assistant_response})
+


 if __name__ == "__main__":
     print(
diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index 293de863..df0a7bcd 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -50,6 +50,9 @@
 DEFAULT_SEED = None
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
+DEFAULT_BLOCK_LENGTH = 32
+DEFAULT_STEPS = 32
+DEFAULT_THRESHOLD = 0.95


 def str2bool(string):
@@ -210,6 +213,24 @@ def setup_arg_parser():
         help="Number of tokens to draft when using speculative decoding.",
         default=3,
     )
+    parser.add_argument(
+        "--block-length",
+        type=int,
+        default=DEFAULT_BLOCK_LENGTH,
+        help="[Diffusion models only] Number of tokens per block",
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=DEFAULT_STEPS,
+        help="[Diffusion models only] Number of denoising iterations per block",
+    )
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        default=DEFAULT_THRESHOLD,
+        help="[Diffusion models only] Confidence threshold for token acceptance",
+    )
     return parser
@@ -678,60 +699,137 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer

+    complete_eos_token_ids = set(tokenizer.eos_token_ids)
+
     kwargs["max_tokens"] = max_tokens
-    if draft_model is None:
-        kwargs.pop("num_draft_tokens", None)
-        token_generator = generate_step(prompt, model, **kwargs)
-        # from_draft always false for non-speculative generation
-        token_generator = (
-            (token, logprobs, False) for token, logprobs in token_generator
-        )
+    if callable(getattr(model, "generate_step", None)) and not draft_model:
+        # Build EOS token set
+        if complete_eos_token_ids:
+            if hasattr(model, "args") and hasattr(model.args, "eos_token_id"):
+                complete_eos_token_ids.add(model.args.eos_token_id)
+        if hasattr(model, "EXTRA_EOS_TOKENS"):
+            for token in model.EXTRA_EOS_TOKENS:
+                try:
+                    token_ids = tokenizer.encode(token, add_special_tokens=False)
+                    if len(token_ids) == 1:
+                        complete_eos_token_ids.add(token_ids[0])
+                except Exception:
+                    pass
+
+        # Build kwargs for custom generate_step
+        custom_kwargs = {
+            "max_tokens": kwargs.get("max_tokens", max_tokens),
+            "sampler": kwargs.get("sampler"),
+            "eos_token_ids": (
+                list(complete_eos_token_ids) if complete_eos_token_ids else None
+            ),
+        }
+        # Add diffusion-specific CLI parameters if present
+        for key in ["block_length", "steps", "threshold"]:
+            if key in kwargs:
+                custom_kwargs[key] = kwargs[key]
+
+        batched_prompt = prompt[None] if len(prompt.shape) == 1 else prompt
+
+        def wrap_custom_generator():
+            with mx.stream(generation_stream):
+                for i, (tokens, logprobs) in enumerate(
+                    model.generate_step(inputs=batched_prompt, **custom_kwargs)
+                ):
+                    tokens_list = (
+                        tokens.flatten().tolist()
+                        if isinstance(tokens, mx.array)
+                        else list(tokens)
+                    )
+                    yield (tokens_list, logprobs, False)
+                    if i % 256 == 0:
+                        mx.clear_cache()
+
+        token_generator = wrap_custom_generator()
     else:
-        kwargs.pop("max_kv_size", None)
-        kwargs.pop("prompt_progress_callback", None)
-        token_generator = speculative_generate_step(
-            prompt, model, draft_model, **kwargs
-        )
+        for key in ["block_length", "steps", "threshold"]:
+            kwargs.pop(key, None)
+
+        if draft_model is None:
+            kwargs.pop("num_draft_tokens", None)
+            token_generator = (
+                ([token], logprobs, False)
+                for token, logprobs in generate_step(prompt, model, **kwargs)
+            )
+        else:
+            kwargs.pop("max_kv_size", None)
+            kwargs.pop("prompt_progress_callback", None)
+            token_generator = (
+                ([token], logprobs, from_draft)
+                for token, logprobs, from_draft in speculative_generate_step(
+                    prompt, model, draft_model, **kwargs
+                )
+            )
+
     with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
-        for n, (token, logprobs, from_draft) in enumerate(token_generator):
+        total_tokens = 0
+        prompt_tps = 0.0
+        last_token = -1
+        last_logprobs = mx.array([])
+        last_from_draft = False
+        finish_reason = None
+
+        for n, (tokens_list, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
                 tic = time.perf_counter()
-            if token in tokenizer.eos_token_ids:
-                break
-            detokenizer.add_token(token)
-            if (n + 1) == max_tokens:
-                break
+            for token in tokens_list:
+                if token in complete_eos_token_ids:
+                    finish_reason = "stop"
+                    break
+
+                detokenizer.add_token(token)
+                total_tokens += 1
+                last_token = token
+                last_logprobs = logprobs
+                last_from_draft = from_draft
+                if total_tokens >= max_tokens:
+                    finish_reason = "length"
+                    break
+
+            elapsed = time.perf_counter() - tic
             yield GenerationResponse(
                 text=detokenizer.last_segment,
-                token=token,
-                logprobs=logprobs,
+                token=last_token,
+                logprobs=last_logprobs,
                 from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
-                generation_tokens=n + 1,
-                generation_tps=(n + 1) / (time.perf_counter() - tic),
+                generation_tokens=total_tokens,
+                generation_tps=total_tokens / elapsed if elapsed > 0 else 0.0,
                 peak_memory=mx.get_peak_memory() / 1e9,
                 finish_reason=None,
             )
+            if finish_reason:
+                break
+
+        if not finish_reason:
+            finish_reason = "length"
+
         detokenizer.finalize()
+        elapsed = time.perf_counter() - tic
         yield GenerationResponse(
             text=detokenizer.last_segment,
-            token=token,
-            logprobs=logprobs,
-            from_draft=from_draft,
+            token=last_token,
+            logprobs=last_logprobs,
+            from_draft=last_from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
-            generation_tokens=n + 1,
-            generation_tps=(n + 1) / (time.perf_counter() - tic),
+            generation_tokens=total_tokens,
+            generation_tps=total_tokens / elapsed if elapsed > 0 else 0.0,
             peak_memory=mx.get_peak_memory() / 1e9,
-            finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
+            finish_reason=finish_reason,
         )
@@ -752,7 +850,9 @@ def generate(
         verbose (bool): If ``True``, print tokens and timing information.
             Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
-            See :func:`stream_generate` for more details.
+            See :func:`stream_generate` for more details. For diffusion models
+            (e.g., LLaDA2), additional options include ``block_length``, ``steps``,
+            and ``threshold``.
     """
     if verbose:
         print("=" * 10)
@@ -768,7 +868,7 @@
         print("=" * 10)
         if len(text) == 0:
             print("No text generated for this prompt")
-            return
+            return text
         print(
             f"Prompt: {response.prompt_tokens} tokens, "
             f"{response.prompt_tps:.3f} tokens-per-sec"
         )
@@ -1156,6 +1256,13 @@ def batch_generate(
         See :obj:`BatchGenerator` for more details.
     """
+    # Check if model uses custom generation (e.g., diffusion models)
+    if callable(getattr(model, "generate_step", None)):
+        raise NotImplementedError(
+            f"{model.__class__.__name__} uses custom generation and does not support "
+            "batch_generate(). Use generate() or stream_generate() instead."
+        )
+
     gen = BatchGenerator(model, stop_tokens=tokenizer.eos_token_ids, **kwargs)
     num_samples = len(prompts)
     fin = 0
@@ -1299,6 +1406,25 @@ def main():
             raise ValueError("Draft model tokenizer does not match model tokenizer.")
     else:
         draft_model = None
+
+    # Prepare generation kwargs
+    gen_kwargs = {
+        "max_tokens": args.max_tokens,
+        "verbose": args.verbose,
+        "max_kv_size": args.max_kv_size,
+        "prompt_cache": prompt_cache if using_cache else None,
+        "kv_bits": args.kv_bits,
+        "kv_group_size": args.kv_group_size,
+        "quantized_kv_start": args.quantized_kv_start,
+        "draft_model": draft_model,
+        "num_draft_tokens": args.num_draft_tokens,
+        # Diffusion-specific parameters
+        "block_length": args.block_length,
+        "steps": args.steps,
+        "threshold": args.threshold,
+    }
+
+    # Create sampler for all models
     sampler = make_sampler(
         args.temp,
         args.top_p,
@@ -1309,21 +1435,13 @@ def main():
         xtc_threshold=args.xtc_threshold,
         xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
     )
-    response = generate(
-        model,
-        tokenizer,
-        prompt,
-        max_tokens=args.max_tokens,
-        verbose=args.verbose,
-        sampler=sampler,
-        max_kv_size=args.max_kv_size,
-        prompt_cache=prompt_cache if using_cache else None,
-        kv_bits=args.kv_bits,
-        kv_group_size=args.kv_group_size,
-        quantized_kv_start=args.quantized_kv_start,
-        draft_model=draft_model,
-        num_draft_tokens=args.num_draft_tokens,
-    )
+
+    gen_kwargs["sampler"] = sampler
+
+    # Generate
+    response = generate(model, tokenizer, prompt, **gen_kwargs)
+
+    # Print response if not verbose (verbose mode prints during generation)
     if not args.verbose:
         print(response)
diff --git a/mlx_lm/models/llada2_moe.py b/mlx_lm/models/llada2_moe.py
new file mode 100644
index 00000000..0f1d3d4b
--- /dev/null
+++ b/mlx_lm/models/llada2_moe.py
@@ -0,0 +1,538 @@
+# Copyright © 2025 Apple Inc.
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    intermediate_size: int
+    max_position_embeddings: int
+    moe_intermediate_size: int
+    num_experts: int
+    num_shared_experts: int
+    norm_topk_prob: bool
+    num_attention_heads: int
+    num_experts_per_tok: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    rms_norm_eps: float
+    rope_theta: float
+    vocab_size: int
+    first_k_dense_replace: int
+    head_dim: Optional[int] = None
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    use_bias: bool = False
+    use_qkv_bias: bool = False
+    norm_head: bool = False
+    norm_softmax: bool = False
+    tie_word_embeddings: bool = False
+    partial_rotary_factor: float = 1.0
+    rotary_dim: Optional[int] = None
+    moe_router_enable_expert_bias: bool = False
+    routed_scaling_factor: float = 1.0
+    score_function: str = "sigmoid"
+    n_group: int = 1
+    topk_group: int = 4
+    router_dtype: Optional[str] = None
+    mask_token_id: int = 156895
+    eos_token_id: int = 156892
+
+
+@partial(mx.compile, shapeless=True)
+def swiglu(gate, up):
+    return nn.silu(gate) * up
+
+
+def is_eos_token(tokens: mx.array, eos_token_ids: set) -> mx.array:
+    """Check if tokens match any EOS token ID."""
+    return (tokens[:, None] == mx.array(list(eos_token_ids))).any(axis=-1)
+
+
+class LLaDA2MoeMLP(nn.Module):
+    def __init__(self, args: ModelArgs, intermediate_size: Optional[int] = None):
+        super().__init__()
+        self.intermediate_size = (
+            intermediate_size
+            if intermediate_size is not None
+            else args.intermediate_size
+        )
+
+        self.gate_proj = nn.Linear(
+            args.hidden_size, self.intermediate_size, bias=args.use_bias
+        )
+        self.down_proj = nn.Linear(
+            self.intermediate_size, args.hidden_size, bias=args.use_bias
+        )
+        self.up_proj = nn.Linear(
+            args.hidden_size, self.intermediate_size, bias=args.use_bias
+        )
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x)))
+
+
+class LLaDA2MoeAttention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.num_key_value_heads = args.num_key_value_heads
+        self.head_dim = args.head_dim or (args.hidden_size // self.num_attention_heads)
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = nn.Linear(
+            args.hidden_size,
+            (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim,
+            bias=args.use_qkv_bias,
+        )
+        self.dense = nn.Linear(
+            self.num_attention_heads * self.head_dim,
+            args.hidden_size,
+            bias=args.use_bias,
+        )
+        self.query_layernorm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps)
+        self.key_layernorm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps)
+
+        rope_dim = args.rotary_dim or int(self.head_dim * args.partial_rotary_factor)
+        self.rope = initialize_rope(
+            rope_dim,
+            args.rope_theta,
+            traditional=False,
+            scaling_config=args.rope_scaling,
+            max_position_embeddings=args.max_position_embeddings,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ):
+        B, L, _ = x.shape
+
+        qkv = self.query_key_value(x)
+        q_size = self.num_attention_heads * self.head_dim
+        kv_size = self.num_key_value_heads * self.head_dim
+        q, k, v = mx.split(qkv, [q_size, q_size + kv_size], axis=-1)
+
+        queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        keys = k.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = v.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+
+        queries = self.query_layernorm(queries)
+        keys = self.key_layernorm(keys)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.dense(output)
+
+
+def group_expert_select(
+    gates,
+    expert_bias,
+    top_k,
+    n_group,
+    topk_group,
+    routed_scaling_factor,
+    norm_topk_prob,
+    score_function,
+):
+    scores = (
+        mx.sigmoid(gates)
+        if score_function == "sigmoid"
+        else mx.softmax(gates, axis=-1, precise=True)
+    )
+    orig_scores = scores
+
+    if expert_bias is not None:
+        scores = scores + expert_bias
+
+    if n_group > 1:
+        scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
+        group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
+        k = n_group - topk_group
+        group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+        scores = mx.put_along_axis(
+            scores, mx.stop_gradient(group_idx), mx.array(0.0, scores.dtype), axis=-2
+        )
+        scores = mx.flatten(scores, -2, -1)
+
+    inds = mx.argpartition(scores, kth=-top_k, axis=-1)[..., -top_k:]
+    scores = mx.take_along_axis(orig_scores, inds, axis=-1)
+
+    if top_k > 1 and norm_topk_prob:
+        scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20)
+
+    scores = scores * routed_scaling_factor
+    return inds, scores
+
+
+class LLaDA2MoeGate(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.norm_topk_prob = args.norm_topk_prob
+        self.top_k = args.num_experts_per_tok
+        self.n_group = args.n_group
+        self.topk_group = args.topk_group
+        self.routed_scaling_factor = args.routed_scaling_factor
+        self.score_function = args.score_function
+
+        if args.router_dtype == "fp32":
+            router_dtype = mx.float32
+        else:
+            router_dtype = None
+
+        self.weight = mx.zeros((args.num_experts, args.hidden_size), dtype=router_dtype)
+        self.expert_bias = (
+            mx.zeros((args.num_experts,), dtype=router_dtype)
+            if args.moe_router_enable_expert_bias
+            else None
+        )
+
+    def __call__(self, x):
+        orig_shape = x.shape
+        x = x.reshape(-1, x.shape[-1])
+        gates = mx.matmul(x, self.weight.T)
+
+        indices, scores = group_expert_select(
+            gates,
+            self.expert_bias,
+            self.top_k,
+            self.n_group,
+            self.topk_group,
+            self.routed_scaling_factor,
+            self.norm_topk_prob,
+            self.score_function,
+        )
+
+        return indices.reshape(*orig_shape[:-1], -1), scores.reshape(
+            *orig_shape[:-1], -1
+        )
+
+
+class LLaDA2MoeSparseMoeBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_experts_per_tok = args.num_experts_per_tok
+        self.switch_mlp = SwitchGLU(
+            args.hidden_size,
+            args.moe_intermediate_size,
+            args.num_experts,
+            bias=args.use_bias,
+        )
+        self.gate = LLaDA2MoeGate(args)
+        self.shared_experts = (
+            LLaDA2MoeMLP(
+                args=args,
+                intermediate_size=args.moe_intermediate_size * args.num_shared_experts,
+            )
+            if args.num_shared_experts > 0
+            else None
+        )
+
+    def __call__(self, x):
+        inds, scores = self.gate(x)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+        if self.shared_experts is not None:
+            y = y + self.shared_experts(x)
+        return y
+
+
+class LLaDA2MoeDecoderLayer(nn.Module):
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__()
+        self.attention = LLaDA2MoeAttention(args)
+        self.mlp = (
+            LLaDA2MoeSparseMoeBlock(args)
+            if layer_idx >= args.first_k_dense_replace
+            else LLaDA2MoeMLP(args)
+        )
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache=None):
+        r = self.attention(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        return h + r
+
+
+class LLaDA2MoeModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.word_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            LLaDA2MoeDecoderLayer(args, i) for i in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(self, inputs, cache=None, inputs_embeds=None, mask=None):
+        h = inputs_embeds if inputs_embeds is not None else self.word_embeddings(inputs)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+        if mask is None:
+            mask = create_attention_mask(h, cache[0])
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    """LLaDA2 MoE model with diffusion-based generation."""
+
+    EXTRA_EOS_TOKENS = ["<|role_end|>"]
+
+    # As per original paper, LLaDA does not support kv caching
+    supports_prompt_cache = False
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = LLaDA2MoeModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(self, inputs, cache=None, inputs_embeds=None, mask=None):
+        out = self.model(inputs, cache, inputs_embeds, mask)
+        if self.args.tie_word_embeddings:
+            return self.model.word_embeddings.as_linear(out)
+        return self.lm_head(out)
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return (
+            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+        )
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
+
+    def _create_block_diagonal_mask(
+        self, num_blocks: int, block_length: int, dtype=mx.float32
+    ):
+        """Create block-diagonal attention mask for diffusion generation."""
+        mask = mx.tril(mx.ones((num_blocks, num_blocks)))
+        mask = mx.repeat(mx.repeat(mask, block_length, axis=0), block_length, axis=1)
+        mask = mask[None, None, :, :]
+        return mx.where(mask, 0.0, float("-inf")).astype(dtype)
+
+    def _select_tokens_to_update(
+        self, confidence: mx.array, mask: mx.array, num_tokens: int, threshold: float
+    ):
+        """Select which tokens to update based on confidence scores."""
+        conf = mx.where(mask, confidence, float("-inf"))[0]
+
+        high_conf = conf > threshold
+        if high_conf.sum().item() >= num_tokens:
+            return high_conf
+
+        k = min(num_tokens, mask.sum().item())
+        idx = mx.argpartition(-conf, kth=k - 1)[:k]
+        positions = mx.arange(len(conf))
+        return (positions[:, None] == idx[None, :]).any(axis=1)
+
+    def _find_stop_position(self, tokens: mx.array, mask_id: int, eos_ids: set):
+        """Find first mask or EOS position in token sequence."""
+        is_mask = tokens == mask_id
+        is_eos = is_eos_token(tokens, eos_ids)
+        stop_mask = is_mask | is_eos
+
+        if not stop_mask.any():
+            return len(tokens), False
+
+        stop_idx = mx.argmax(stop_mask.astype(mx.int32)).item()
+        return stop_idx, is_eos[stop_idx].item() if stop_idx < len(is_eos) else False
+
+    def generate_step(
+        self,
+        inputs: mx.array,
+        max_tokens: int = 2048,
+        sampler: Optional[callable] = None,
+        block_length: int = 32,
+        steps: int = 32,
+        minimal_topk: int = 1,
+        threshold: float = 0.95,
+        eos_token_ids: Optional[Union[int, list, set]] = None,
+        mask_id: Optional[int] = None,
+    ):
+        """
+        Diffusion-based text generation using block-wise iterative denoising.
+
+        Args:
+            inputs: Input token IDs (prompt).
+            max_tokens: Maximum tokens to generate.
+            sampler: Sampling function from make_sampler().
+            block_length: Size of each generation block.
+            steps: Number of denoising iterations per block.
+            minimal_topk: Minimum tokens to keep (caps effective steps).
+            threshold: Confidence threshold for token acceptance.
+            eos_token_ids: EOS token ID(s) (int, list, or set).
+            mask_id: Mask token ID for ungenerated positions.
+
+        Yields:
+            (tokens, logprobs): Generated tokens and empty logprobs array.
+        """
+        sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+        if eos_token_ids is None:
+            eos_token_ids = {self.args.eos_token_id}
+        elif not isinstance(eos_token_ids, set):
+            eos_token_ids = (
+                {eos_token_ids}
+                if isinstance(eos_token_ids, int)
+                else set(eos_token_ids)
+            )
+
+        mask_id = mask_id or self.args.mask_token_id
+        steps = min(steps, max_tokens // minimal_topk)
+
+        batch_size, prompt_length = inputs.shape
+        if batch_size != 1:
+            raise ValueError(
+                f"Diffusion generation only supports batch_size=1, got {batch_size}"
+            )
+
+        num_blocks = (prompt_length + max_tokens + block_length - 1) // block_length
+        total_length = num_blocks * block_length
+
+        mask = self._create_block_diagonal_mask(
+            num_blocks, block_length, dtype=self.model.word_embeddings.weight.dtype
+        )
+        transfer_schedule = self._get_num_transfer_tokens(block_length, steps)
+
+        x = mx.full((1, total_length), mask_id, dtype=mx.int32)
+        x[:, :prompt_length] = inputs
+
+        last_yield_pos = prompt_length
+        prefill_blocks = prompt_length // block_length
+
+        for block_idx in range(prefill_blocks, num_blocks):
+            window_end = (block_idx + 1) * block_length
+            cur_x = x[:, :window_end]
+            cur_mask = mask[:, :, :window_end, :window_end]
+            block_start = block_idx * block_length
+
+            for step in range(steps):
+                active_mask = cur_x[:, -block_length:] == mask_id
+                if not active_mask.any():
+                    break
+
+                logits = self(cur_x, cache=None, mask=cur_mask)
+                tokens, confidence = self._sample_with_sampler(
+                    logits[:, -block_length:, :], sampler
+                )
+
+                num_transfer = int(transfer_schedule[step])
+                update_mask = self._select_tokens_to_update(
+                    confidence, active_mask, num_transfer, threshold
+                )
+
+                if not update_mask.any():
+                    continue
+
+                new_block = mx.where(update_mask, tokens[0], cur_x[0, -block_length:])
+                cur_x = mx.concatenate(
+                    [cur_x[:, :-block_length], new_block[None, :]], axis=1
+                )
+                x[:, :window_end] = cur_x
+
+                start = max(last_yield_pos - block_start, 0)
+                if start >= block_length:
+                    continue
+
+                remaining = cur_x[0, -block_length:][start:]
+                stop_idx, hit_eos = self._find_stop_position(
+                    remaining, mask_id, eos_token_ids
+                )
+
+                if stop_idx > 0:
+                    end_idx = stop_idx + 1 if hit_eos else stop_idx
+                    yield (remaining[:end_idx], mx.array([]))
+                    last_yield_pos = block_start + start + end_idx
+
+                if hit_eos:
+                    return
+
+            gen_end = min(window_end, prompt_length + max_tokens)
+            if gen_end > last_yield_pos:
+                remaining = x[0, last_yield_pos:gen_end]
+                stop_idx, hit_eos = self._find_stop_position(
+                    remaining, mask_id, eos_token_ids
+                )
+
+                if stop_idx > 0:
+                    end_idx = stop_idx + 1 if hit_eos else stop_idx
+                    yield (remaining[:end_idx], mx.array([]))
+                    last_yield_pos += end_idx
+
+                if hit_eos:
+                    return
+
+    @staticmethod
+    def _get_num_transfer_tokens(block_length: int, steps: int):
+        """Calculate token transfer schedule for denoising steps."""
+        if steps == 0:
+            return mx.array([], dtype=mx.int32)
+
+        base = block_length // steps
+        remainder = block_length % steps
+        schedule = mx.full((steps,), base, dtype=mx.int32)
+        schedule[:remainder] += 1
+        return schedule
+
+    def _sample_with_sampler(self, logits: mx.array, sampler: callable):
+        """Sample tokens and return confidence scores."""
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+        tokens = sampler(logprobs)
+        probs = mx.exp(logprobs)
+        confidence = mx.take_along_axis(probs, tokens[..., None], axis=-1).squeeze(-1)
+        return tokens, confidence
+
+    def sanitize(self, weights):
+        """Convert HuggingFace weights to MLX format by stacking MoE expert weights."""
+        for l in range(self.args.first_k_dense_replace, self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}.mlp"
+            for proj in ["gate_proj", "up_proj", "down_proj"]:
+                if f"{prefix}.experts.0.{proj}.weight" in weights:
+                    stacked = mx.stack(
+                        [
+                            weights.pop(f"{prefix}.experts.{e}.{proj}.weight")
+                            for e in range(self.args.num_experts)
+                        ]
+                    )
+                    weights[f"{prefix}.switch_mlp.{proj}.weight"] = stacked
+        return weights
diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index 410c83a0..340a5990 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -30,7 +30,7 @@
 from .generate import stream_generate
 from .models.cache import can_trim_prompt_cache, make_prompt_cache, trim_prompt_cache
 from .sample_utils import make_logits_processors, make_sampler
-from .utils import common_prefix_len, load
+from .utils import common_prefix_len, does_model_support_prompt_cache, load


 def get_system_fingerprint():
@@ -333,7 +333,18 @@ def do_POST(self):
         self.logit_bias = self.body.get("logit_bias", None)
         self.logprobs = self.body.get("logprobs", -1)
         self.seed = self.body.get("seed", None)
-        self.validate_model_parameters()
+        # Diffusion-specific parameters
+        self.block_length = self.body.get("block_length", 32)
+        self.steps = self.body.get("steps", 32)
+        self.threshold = self.body.get("threshold", 0.95)
+        try:
+            self.validate_model_parameters()
+        except ValueError as e:
+            self._set_completion_headers(400)
+            self.end_headers()
+            self.wfile.write(json.dumps({"error": str(e)}).encode())
+            return
+
         if self.seed is not None:
             mx.random.seed(self.seed)
         # Load the model if needed
@@ -435,6 +446,15 @@ def validate_model_parameters(self):
         if self.seed is not None and not isinstance(self.seed, int):
             raise ValueError("seed must be an integer")

+        if not isinstance(self.block_length, int) or self.block_length < 1:
+            raise ValueError("block_length must be a positive integer")
+        if not isinstance(self.steps, int) or self.steps < 1:
+            raise ValueError("steps must be a positive integer")
+        if not isinstance(self.threshold, (float, int)) or not (
+            0.0 <= self.threshold <= 1.0
+        ):
+            raise ValueError("threshold must be a float between 0.0 and 1.0")
+
     def generate_response(
         self,
         text: str,
@@ -645,7 +665,9 @@ def handle_completion(
         token_logprobs = []
         top_tokens = []

-        prompt = self.get_prompt_cache(prompt)
+        use_cache = does_model_support_prompt_cache(self.model)
+        if use_cache:
+            prompt = self.get_prompt_cache(prompt)

         text = ""
         tic = time.perf_counter()
@@ -695,10 +717,13 @@ def keepalive_callback(processed_tokens, total_tokens):
             max_tokens=self.max_tokens,
             sampler=sampler,
             logits_processors=logits_processors,
-            prompt_cache=self.prompt_cache.cache,
+            prompt_cache=self.prompt_cache.cache if use_cache else None,
            draft_model=self.model_provider.draft_model,
            num_draft_tokens=self.num_draft_tokens,
            prompt_progress_callback=keepalive_callback,
+            block_length=self.block_length,
+            steps=self.steps,
+            threshold=self.threshold,
         ):
             logging.debug(gen_response.text)
@@ -720,16 +745,19 @@ def keepalive_callback(processed_tokens, total_tokens):
             token = gen_response.token
             logprobs = gen_response.logprobs
             tokens.append(token)
-            self.prompt_cache.tokens.append(token)
-            if self.logprobs > 0:
-                sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
-                top_indices = sorted_indices[: self.logprobs]
-                top_logprobs = logprobs[top_indices]
-                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
-                top_tokens.append(tuple(top_token_info))
+            if use_cache:
+                self.prompt_cache.tokens.append(token)

-            token_logprobs.append(logprobs[token].item())
+            if logprobs.size > 0:
+                if self.logprobs > 0:
+                    sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
+                    top_indices = sorted_indices[: self.logprobs]
+                    top_logprobs = logprobs[top_indices]
+                    top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
+                    top_tokens.append(tuple(top_token_info))
+
+                token_logprobs.append(logprobs[token].item())

             stop_condition = stopping_criteria(
                 tokens, stop_id_sequences, self.tokenizer.eos_token_id
@@ -777,11 +805,8 @@ def keepalive_callback(processed_tokens, total_tokens):
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
         if self.stream_options is not None and self.stream_options["include_usage"]:
-            original_prompt_length = (
-                len(self.prompt_cache.tokens) - len(tokens) + len(prompt)
-            )
             response = self.completion_usage_response(
-                original_prompt_length, len(tokens)
+                gen_response.prompt_tokens, gen_response.generation_tokens
             )
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
@@ -791,8 +816,8 @@ def keepalive_callback(processed_tokens, total_tokens):
             response = self.generate_response(
                 text,
                 finish_reason,
-                len(prompt),
-                len(tokens),
+                gen_response.prompt_tokens,
+                gen_response.generation_tokens,
                 token_logprobs=token_logprobs,
                 top_tokens=top_tokens,
                 tokens=tokens,
diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py
index 73fdcc97..cd014924 100644
--- a/mlx_lm/utils.py
+++ b/mlx_lm/utils.py
@@ -758,6 +758,23 @@ def common_prefix_len(list1, list2):
     return min_len


+def does_model_support_prompt_cache(model: nn.Module) -> bool:
+    """
+    Check if a model supports prompt caching.
+
+    Models should explicitly declare cache support via the `supports_prompt_cache`
+    class attribute. If not declared, defaults to True (standard autoregressive
+    models use KV cache by default).
+
+    Args:
+        model (nn.Module): The model to check.
+
+    Returns:
+        bool: True if the model supports prompt caching, False otherwise.
+    """
+    return getattr(model, "supports_prompt_cache", True)
+
+
 def does_model_support_input_embeddings(model: nn.Module) -> bool:
     """
     Check if the model supports input_embeddings in its call signature.
diff --git a/tests/test_models.py b/tests/test_models.py
index 05719625..1cb6fae3 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1296,6 +1296,36 @@ def test_gpt_oss(self):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

+    def test_llada2_moe(self):
+        from mlx_lm.models import llada2_moe
+
+        args = llada2_moe.ModelArgs(
+            model_type="llada2_moe",
+            hidden_size=128,
+            intermediate_size=256,
+            max_position_embeddings=1000,
+            moe_intermediate_size=256,
+            num_experts=4,
+            num_shared_experts=1,
+            norm_topk_prob=True,
+            num_attention_heads=4,
+            num_experts_per_tok=2,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            rms_norm_eps=1e-5,
+            rope_theta=1000,
+            vocab_size=1000,
+            first_k_dense_replace=2,
+            routed_scaling_factor=1.0,
+            score_function="sigmoid",
+            n_group=2,
+            topk_group=1,
+        )
+        model = llada2_moe.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_all_models(self):
         test_configs = [
             {
@@ -1939,6 +1969,28 @@ def test_all_models(self):
                 "vocab_size": 32,
                 "intermediate_size": 128,
             },
+            {
+                "model_type": "llada2_moe",
+                "hidden_size": 128,
+                "intermediate_size": 256,
+                "max_position_embeddings": 1000,
+                "moe_intermediate_size": 256,
+                "num_experts": 4,
+                "num_shared_experts": 1,
+                "norm_topk_prob": True,
+                "num_attention_heads": 4,
+                "num_experts_per_tok": 2,
+                "num_hidden_layers": 4,
+                "num_key_value_heads": 2,
+                "rms_norm_eps": 1e-5,
+                "rope_theta": 1000,
+                "vocab_size": 1000,
+                "first_k_dense_replace": 2,
+                "routed_scaling_factor": 1.0,
+                "score_function": "sigmoid",
+                "n_group": 2,
+                "topk_group": 1,
+            },
             {
                 "model_type": "minimax",
                 "hidden_size": 128,