From 9c4ca46de94657cc41b814eb10aa8f3f202ef05b Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Wed, 10 Jun 2026 10:16:01 +0800 Subject: [PATCH 1/2] support qwen35 moe gdn sp --- .../strategy/sequence_parallel/__init__.py | 13 ++-- .../sequence_parallel/linear_attention_sp.py | 73 +++++++++++-------- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 8388c8e4..d45ac618 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -13,7 +13,7 @@ from twinkle.utils import DeviceMesh from twinkle.utils.transformers_utils import get_llm_model from twinkle.utils.utils import call_with_supported_kwargs, has_signature_parameter -from .linear_attention_sp import Qwen3_5GatedDeltaNetUlyssesPatch +from .linear_attention_sp import Qwen3_5GatedDeltaNetUlyssesPatch, _iter_qwen35_gated_delta_net_classes from .utils import (DistributedAttention, GatherLoss, _derive_sequence_parallel_sizes, _get_seq_groups_from_device_mesh, _get_ulysses_size, _SeqAllToAll, get_config_attr, get_cu_seqlens_from_position_ids, is_hccl_backend, is_moe_config, post_all2all) @@ -409,7 +409,7 @@ def _deepstack_process(_self, hidden_states: torch.Tensor, visual_pos_masks: tor def _is_qwen35_model(model: torch.nn.Module) -> bool: config = getattr(model, 'config', None) model_type = str(getattr(config, 'model_type', '') or '') - if model_type == 'qwen3_5': + if model_type in {'qwen3_5', 'qwen3_5_moe'}: return True architectures = getattr(config, 'architectures', None) or [] @@ -417,16 +417,15 @@ def _is_qwen35_model(model: torch.nn.Module) -> bool: return True model_module = getattr(model.__class__, '__module__', '') or '' - return 'transformers.models.qwen3_5' in model_module + return 'transformers.models.qwen3_5' in model_module or 'transformers.models.qwen3_5_moe' in model_module def _prepare_qwen35_linear_attention(self, model: torch.nn.Module): has_qwen35_linear_attention = self._is_qwen35_model(model) if not has_qwen35_linear_attention: - try: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5GatedDeltaNet - except Exception: + gated_delta_net_classes = tuple(_iter_qwen35_gated_delta_net_classes()) + if not gated_delta_net_classes: return - has_qwen35_linear_attention = any(isinstance(module, Qwen3_5GatedDeltaNet) for module in model.modules()) + has_qwen35_linear_attention = any(isinstance(module, gated_delta_net_classes) for module in model.modules()) if not has_qwen35_linear_attention: return if int(self.rp_world_size or 1) > 1: diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 4a033212..8560d4cd 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -2,6 +2,7 @@ import torch.distributed as dist import torch.nn.functional as F import warnings +from importlib import import_module from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from typing import Any, Optional, Tuple @@ -117,13 +118,27 @@ def _torch_causal_conv1d_fn( mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE return False - from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule + modeling_module = import_module(mod.__class__.__module__) + torch_chunk_gated_delta_rule = getattr(modeling_module, 'torch_chunk_gated_delta_rule') mod.causal_conv1d_fn = _torch_causal_conv1d_fn mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2) return True +def _iter_qwen35_gated_delta_net_classes(): + class_specs = ( + ('transformers.models.qwen3_5.modeling_qwen3_5', 'Qwen3_5GatedDeltaNet'), + ('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', 'Qwen3_5MoeGatedDeltaNet'), + ) + for module_name, class_name in class_specs: + try: + modeling_module = import_module(module_name) + yield getattr(modeling_module, class_name) + except Exception: + continue + + def _get_local_conv_weights( mod: torch.nn.Module, *, @@ -310,41 +325,39 @@ def __call__(self, module, *args, **kwargs): raise NotImplementedError('Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' '(derived ring attention).') - try: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5GatedDeltaNet - except Exception: - return + for gated_delta_net_cls in _iter_qwen35_gated_delta_net_classes(): + if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): + continue - if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): - return + origin_forward = gated_delta_net_cls.forward - origin_forward = Qwen3_5GatedDeltaNet.forward - - def sp_linear_forward( - mod, - hidden_states: torch.Tensor, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - **extra_kwargs, - ): - sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) - if not _sp_is_enabled(sequence_parallel_context): - return origin_forward( + def sp_linear_forward( + mod, + hidden_states: torch.Tensor, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + _origin_forward=origin_forward, + **extra_kwargs, + ): + sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) + if not _sp_is_enabled(sequence_parallel_context): + return _origin_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + **extra_kwargs, + ) + return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( mod, hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask, + sequence_parallel_context=sequence_parallel_context, ) - return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( - mod, - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - sequence_parallel_context=sequence_parallel_context, - ) - Qwen3_5GatedDeltaNet.forward = sp_linear_forward - Qwen3_5GatedDeltaNet._twinkle_sp_linear_patched = True + gated_delta_net_cls.forward = sp_linear_forward + gated_delta_net_cls._twinkle_sp_linear_patched = True From 8a2f835c08e110ca8cc0eb1adddd5b1b1fd71358 Mon Sep 17 00:00:00 2001 From: meichangsu1 <1484603386@qq.com> Date: Thu, 11 Jun 2026 10:20:35 +0800 Subject: [PATCH 2/2] support gdn padding-free & fix sp unit test --- .../sequence_parallel/linear_attention_sp.py | 3 +- src/twinkle/patch/gdn_padding_free.py | 220 ++++++++++-------- .../test_qwen35_linear_attention_sp.py | 96 ++++++-- 3 files changed, 199 insertions(+), 120 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 8560d4cd..66511b5b 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -181,7 +181,8 @@ def _run_forward( sequence_parallel_context=None, ) -> torch.Tensor: using_torch_fallback = _ensure_linear_attention_kernels(mod) - from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states + modeling_module = import_module(mod.__class__.__module__) + apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states') local_attention_mask = attention_mask if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index 759a222f..5b54117c 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,5 +1,6 @@ import torch import transformers +from importlib import import_module from packaging.version import Version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -12,16 +13,35 @@ def _is_qwen35_model(hf_config) -> bool: return 'qwen3_5' in getattr(hf_config, 'model_type', '') -def _find_qwen35_classes(module: Optional[torch.nn.Module], hf_config, enable_sp: bool): +def _iter_qwen35_class_pairs(): + class_specs = ( + ( + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'Qwen3_5DecoderLayer', + 'Qwen3_5GatedDeltaNet', + ), + ( + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + 'Qwen3_5MoeDecoderLayer', + 'Qwen3_5MoeGatedDeltaNet', + ), + ) + for module_name, decoder_class_name, gdn_class_name in class_specs: + try: + modeling_module = import_module(module_name) + yield getattr(modeling_module, decoder_class_name), getattr(modeling_module, gdn_class_name) + except Exception: + continue + + +def _find_qwen35_class_pairs(module: Optional[torch.nn.Module], hf_config, enable_sp: bool): if module is None or enable_sp or not _is_qwen35_model(hf_config): - return None, None - try: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet - except Exception: - return None, None - if any(isinstance(submodule, Qwen3_5GatedDeltaNet) for submodule in module.modules()): - return Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet - return None, None + return () + class_pairs = [] + for decoder_layer_cls, gated_delta_net_cls in _iter_qwen35_class_pairs(): + if any(isinstance(submodule, gated_delta_net_cls) for submodule in module.modules()): + class_pairs.append((decoder_layer_cls, gated_delta_net_cls)) + return tuple(class_pairs) def _get_flash_linear_attention_kernels(): @@ -84,104 +104,112 @@ class GatedDeltaNetPaddingFreePatch(Patch): def __call__(self, module, *args, **kwargs): del args - Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet = _find_qwen35_classes( + qwen35_class_pairs = _find_qwen35_class_pairs( module, kwargs.get('hf_config'), bool(kwargs.get('enable_sp', False)), ) - if Qwen3_5DecoderLayer is None or Qwen3_5GatedDeltaNet is None: - return - if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): + if not qwen35_class_pairs: return module._twinkle_gdn_padding_free_patched = True - if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): - origin_decoder_forward = Qwen3_5DecoderLayer.forward - - def decoder_forward( - layer, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_values=None, - cache_position: Optional[torch.Tensor] = None, - **extra_kwargs, - ): - if getattr(layer, 'layer_type', None) != 'linear_attention': - return call_with_supported_kwargs( - origin_decoder_forward, - layer, + for decoder_layer_cls, gated_delta_net_cls in qwen35_class_pairs: + if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): + continue + + if not getattr(decoder_layer_cls, '_twinkle_padding_free_cu_seqlens_patched', False): + origin_decoder_forward = decoder_layer_cls.forward + + def decoder_forward( + layer, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values=None, + cache_position: Optional[torch.Tensor] = None, + _origin_decoder_forward=origin_decoder_forward, + **extra_kwargs, + ): + if getattr(layer, 'layer_type', None) != 'linear_attention': + return call_with_supported_kwargs( + _origin_decoder_forward, + layer, + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **extra_kwargs, + ) + cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None) + extra_kwargs.pop('cu_seq_lens_k', None) + extra_kwargs.pop('max_length_q', None) + extra_kwargs.pop('max_length_k', None) + + residual = hidden_states + hidden_states = layer.input_layernorm(hidden_states) + hidden_states = layer.linear_attn( hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - cache_position=cache_position, - **extra_kwargs, - ) - cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None) - extra_kwargs.pop('cu_seq_lens_k', None) - extra_kwargs.pop('max_length_q', None) - extra_kwargs.pop('max_length_k', None) - - residual = hidden_states - hidden_states = layer.input_layernorm(hidden_states) - hidden_states = layer.linear_attn( - hidden_states=hidden_states, - cache_params=past_key_values, - cache_position=cache_position, - attention_mask=attention_mask, - cu_seq_lens_q=cu_seq_lens_q, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = layer.post_attention_layernorm(hidden_states) - hidden_states = layer.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - Qwen3_5DecoderLayer.forward = decoder_forward - Qwen3_5DecoderLayer._twinkle_padding_free_cu_seqlens_patched = True - - if not getattr(Qwen3_5GatedDeltaNet, '_twinkle_padding_free_gdn_patched', False): - origin_forward = Qwen3_5GatedDeltaNet.forward - patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch() - - def forward( - mod, - hidden_states: torch.Tensor, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - cu_seq_lens_q: Optional[torch.Tensor] = None, - **extra_kwargs, - ): - if cu_seq_lens_q is None: - return call_with_supported_kwargs( - origin_forward, - mod, - hidden_states, - cache_params=cache_params, + cache_params=past_key_values, cache_position=cache_position, attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, **extra_kwargs, ) - return _patch_gdn_kernels_for_cu_seqlens( + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = layer.post_attention_layernorm(hidden_states) + hidden_states = layer.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + return hidden_states + + decoder_layer_cls.forward = decoder_forward + decoder_layer_cls._twinkle_padding_free_cu_seqlens_patched = True + + if not getattr(gated_delta_net_cls, '_twinkle_padding_free_gdn_patched', False): + origin_forward = gated_delta_net_cls.forward + patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch() + + def forward( mod, - cu_seqlens=cu_seq_lens_q, - patch_chunk_rule=patch_chunk_rule, - origin_forward=origin_forward, - forward_args=(hidden_states, ), - forward_kwargs={ - 'cache_params': cache_params, - 'cache_position': cache_position, - 'attention_mask': attention_mask, - 'cu_seq_lens_q': cu_seq_lens_q, - **extra_kwargs, - }, - ) + hidden_states: torch.Tensor, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + cu_seq_lens_q: Optional[torch.Tensor] = None, + _origin_forward=origin_forward, + _patch_chunk_rule=patch_chunk_rule, + **extra_kwargs, + ): + if cu_seq_lens_q is None: + return call_with_supported_kwargs( + _origin_forward, + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + **extra_kwargs, + ) + return _patch_gdn_kernels_for_cu_seqlens( + mod, + cu_seqlens=cu_seq_lens_q, + patch_chunk_rule=_patch_chunk_rule, + origin_forward=_origin_forward, + forward_args=(hidden_states, ), + forward_kwargs={ + 'cache_params': cache_params, + 'cache_position': cache_position, + 'attention_mask': attention_mask, + 'cu_seq_lens_q': cu_seq_lens_q, + **extra_kwargs, + }, + ) - Qwen3_5GatedDeltaNet.forward = forward - Qwen3_5GatedDeltaNet._twinkle_padding_free_gdn_patched = True + gated_delta_net_cls.forward = forward + gated_delta_net_cls._twinkle_padding_free_gdn_patched = True diff --git a/tests/transformers/test_qwen35_linear_attention_sp.py b/tests/transformers/test_qwen35_linear_attention_sp.py index 848fbb61..7a99baac 100644 --- a/tests/transformers/test_qwen35_linear_attention_sp.py +++ b/tests/transformers/test_qwen35_linear_attention_sp.py @@ -5,7 +5,6 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -import torch.nn.functional as F import unittest from datetime import timedelta from transformers.modeling_flash_attention_utils import is_flash_attn_available @@ -15,20 +14,26 @@ from twinkle.loss import CrossEntropyLoss from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelStrategy, sequence_parallel from twinkle.model.transformers.strategy.sequence_parallel.linear_attention_sp import Qwen3_5GatedDeltaNetUlyssesPatch -from twinkle.model.transformers.strategy.sequence_parallel.utils import get_cu_seqlens_from_position_ids from twinkle.utils import DeviceMesh, selective_log_softmax try: from transformers import Qwen3_5ForCausalLM, Qwen3_5TextConfig - from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35 _HAS_QWEN35 = True except Exception: Qwen3_5ForCausalLM = None Qwen3_5TextConfig = None - hf_qwen35 = None _HAS_QWEN35 = False +try: + from transformers import Qwen3_5MoeForCausalLM, Qwen3_5MoeTextConfig + + _HAS_QWEN35_MOE = True +except Exception: + Qwen3_5MoeForCausalLM = None + Qwen3_5MoeTextConfig = None + _HAS_QWEN35_MOE = False + if is_flash_linear_attention_available(): from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE @@ -42,7 +47,7 @@ LOSS_ATOL = 5e-3 GRAD_RTOL = 5e-3 GRAD_ATOL = 2e-3 -_HAS_FLA_PREFILL = bool(_HAS_QWEN35 and _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None) +_HAS_FLA_PREFILL = bool(_FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None) def _hf_compatible_fla_causal_conv1d_fn(x, weight, bias=None, activation=None, seq_idx=None): @@ -59,7 +64,7 @@ def _hf_compatible_fla_causal_conv1d_fn(x, weight, bias=None, activation=None, s return mixed_qkv.transpose(1, 2).contiguous() -def _force_fla_causal_conv(model: Qwen3_5ForCausalLM) -> Qwen3_5ForCausalLM: +def _force_fla_causal_conv(model: torch.nn.Module) -> torch.nn.Module: for layer in model.model.layers: linear_attn = getattr(layer, 'linear_attn', None) if linear_attn is not None: @@ -68,26 +73,26 @@ def _force_fla_causal_conv(model: Qwen3_5ForCausalLM) -> Qwen3_5ForCausalLM: return model -def _force_packed_linear_attention(model: Qwen3_5ForCausalLM, position_ids: torch.Tensor) -> Qwen3_5ForCausalLM: - packed_cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32) +def _force_packed_linear_attention(model: torch.nn.Module, position_ids: torch.Tensor) -> torch.nn.Module: - def _make_packed_forward(cu_seqlens: torch.Tensor): + def _make_packed_forward(real_position_ids: torch.Tensor): def _packed_forward(mod, hidden_states, cache_params=None, cache_position=None, attention_mask=None): packed_ctx = SimpleNamespace( world_size=1, sp_world_size=1, + real_position_ids=real_position_ids.to(device=hidden_states.device), extra_kwargs={ - 'is_packed': True, - 'cu_seq_lens_q': cu_seqlens.to(dtype=torch.int32, device=hidden_states.device), + 'padding_free': True, }) + packed_ctx._extract_real_position_ids = lambda ids: ids + packed_ctx.pad = lambda tensor, padding_value=-1, position_ids=None: tensor return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( mod, hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask, - cu_seq_lens_q=packed_ctx.extra_kwargs['cu_seq_lens_q'], sequence_parallel_context=packed_ctx, ) @@ -96,7 +101,7 @@ def _packed_forward(mod, hidden_states, cache_params=None, cache_position=None, for layer in model.model.layers: linear_attn = getattr(layer, 'linear_attn', None) if linear_attn is not None: - linear_attn.forward = MethodType(_make_packed_forward(packed_cu_seqlens), linear_attn) + linear_attn.forward = MethodType(_make_packed_forward(position_ids), linear_attn) return model @@ -146,13 +151,13 @@ def _model_dtype() -> torch.dtype: def _build_tiny_qwen35(device: torch.device, *, attn_implementation: str = 'sdpa', - layer_types: list[str] | None = None) -> Qwen3_5ForCausalLM: + layer_types: list[str] | None = None, + model_kind: str = 'dense') -> torch.nn.Module: if layer_types is None: layer_types = ['linear_attention', 'linear_attention'] - config = Qwen3_5TextConfig( + common_config_kwargs = dict( vocab_size=128, hidden_size=64, - intermediate_size=256, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, @@ -169,15 +174,33 @@ def _build_tiny_qwen35(device: torch.device, attention_dropout=0.0, use_cache=False, ) - config._attn_implementation = attn_implementation - model = Qwen3_5ForCausalLM(config) + if model_kind == 'dense': + config = Qwen3_5TextConfig( + intermediate_size=256, + **common_config_kwargs, + ) + config._attn_implementation = attn_implementation + model = Qwen3_5ForCausalLM(config) + elif model_kind == 'moe': + config = Qwen3_5MoeTextConfig( + moe_intermediate_size=64, + shared_expert_intermediate_size=64, + num_experts=4, + num_experts_per_tok=2, + output_router_logits=False, + **common_config_kwargs, + ) + config._attn_implementation = attn_implementation + model = Qwen3_5MoeForCausalLM(config) + else: + raise ValueError(f'Unknown Qwen3.5 test model kind: {model_kind}') model = _force_fla_causal_conv(model) model.to(device=device, dtype=_model_dtype()) model.eval() return model -def _make_strategy(model: Qwen3_5ForCausalLM, world_size: int) -> SequenceParallelStrategy: +def _make_strategy(model: torch.nn.Module, world_size: int) -> SequenceParallelStrategy: strategy = SequenceParallelStrategy( device_mesh=DeviceMesh.from_sizes( world_size=world_size, @@ -234,7 +257,7 @@ def _make_packed_train_batch(device: torch.device): return input_ids, attention_mask, position_ids, labels -def _get_qkv_weight(model: Qwen3_5ForCausalLM) -> torch.nn.Parameter: +def _get_qkv_weight(model: torch.nn.Module) -> torch.nn.Parameter: for layer in model.model.layers: linear_attn = getattr(layer, 'linear_attn', None) if linear_attn is not None: @@ -261,7 +284,7 @@ def _compute_training_path_loss( return result['loss'], num_tokens -def _average_qkv_grad_over_group(model: Qwen3_5ForCausalLM, group: dist.ProcessGroup | None) -> torch.Tensor: +def _average_qkv_grad_over_group(model: torch.nn.Module, group: dist.ProcessGroup | None) -> torch.Tensor: grad = _get_qkv_weight(model).grad if grad is None: raise AssertionError('No qkv gradient collected from Qwen3.5 linear attention layer.') @@ -280,12 +303,18 @@ def _run_prefill_alignment_worker(rank: int, port: int, attn_implementation: str = 'sdpa', layer_types: list[str] | None = None, - packed: bool = False): + packed: bool = False, + model_kind: str = 'dense'): device = _init_dist(rank, world_size, port) try: _set_determinism(1234) - baseline_model = _build_tiny_qwen35(device, attn_implementation=attn_implementation, layer_types=layer_types) + baseline_model = _build_tiny_qwen35( + device, + attn_implementation=attn_implementation, + layer_types=layer_types, + model_kind=model_kind, + ) sp_model = copy.deepcopy(baseline_model) input_ids, attention_mask, position_ids, labels = ( _make_packed_train_batch(device) if packed else _make_train_batch(device)) @@ -309,6 +338,7 @@ def _run_prefill_alignment_worker(rank: int, 'input_ids': input_ids, 'position_ids': position_ids, 'labels': labels, + 'padding_free': packed, }) local_labels = processed_inputs['labels'] sp_outputs = sp_model( @@ -367,6 +397,26 @@ def test_qwen35_mixed_attention_prefill_logits_and_qkv_grad_alignment(self): join=True, ) + @unittest.skipUnless(_HAS_QWEN35_MOE, 'transformers Qwen3.5-MoE is not available in this environment') + def test_qwen35_moe_linear_attention_prefill_logits_and_qkv_grad_alignment(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'sdpa', ['linear_attention', 'linear_attention'], False, 'moe'), + nprocs=WORLD_SIZE, + join=True, + ) + + @unittest.skipUnless(_HAS_QWEN35_MOE, 'transformers Qwen3.5-MoE is not available in this environment') + def test_qwen35_moe_mixed_attention_prefill_logits_and_qkv_grad_alignment(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'sdpa', ['full_attention', 'linear_attention'], False, 'moe'), + nprocs=WORLD_SIZE, + join=True, + ) + @unittest.skipUnless(is_flash_attn_available(), 'requires flash_attention_2 support in transformers') def test_qwen35_linear_attention_prefill_logits_and_qkv_grad_alignment_fa2(self): port = _find_free_port()