diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 8560d4cd..66511b5b 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -181,7 +181,8 @@ def _run_forward( sequence_parallel_context=None, ) -> torch.Tensor: using_torch_fallback = _ensure_linear_attention_kernels(mod) - from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states + modeling_module = import_module(mod.__class__.__module__) + apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states') local_attention_mask = attention_mask if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index 759a222f..5b54117c 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,5 +1,6 @@ import torch import transformers +from importlib import import_module from packaging.version import Version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -12,16 +13,35 @@ def _is_qwen35_model(hf_config) -> bool: return 'qwen3_5' in getattr(hf_config, 'model_type', '') -def _find_qwen35_classes(module: Optional[torch.nn.Module], hf_config, enable_sp: bool): +def _iter_qwen35_class_pairs(): + class_specs = ( + ( + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'Qwen3_5DecoderLayer', + 'Qwen3_5GatedDeltaNet', + ), + ( + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + 'Qwen3_5MoeDecoderLayer', + 'Qwen3_5MoeGatedDeltaNet', + ), + ) + for module_name, decoder_class_name, gdn_class_name in class_specs: + try: + modeling_module = import_module(module_name) + yield getattr(modeling_module, decoder_class_name), getattr(modeling_module, gdn_class_name) + except Exception: + continue + + +def _find_qwen35_class_pairs(module: Optional[torch.nn.Module], hf_config, enable_sp: bool): if module is None or enable_sp or not _is_qwen35_model(hf_config): - return None, None - try: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet - except Exception: - return None, None - if any(isinstance(submodule, Qwen3_5GatedDeltaNet) for submodule in module.modules()): - return Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet - return None, None + return () + class_pairs = [] + for decoder_layer_cls, gated_delta_net_cls in _iter_qwen35_class_pairs(): + if any(isinstance(submodule, gated_delta_net_cls) for submodule in module.modules()): + class_pairs.append((decoder_layer_cls, gated_delta_net_cls)) + return tuple(class_pairs) def _get_flash_linear_attention_kernels(): @@ -84,104 +104,112 @@ class GatedDeltaNetPaddingFreePatch(Patch): def __call__(self, module, *args, **kwargs): del args - Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet = _find_qwen35_classes( + qwen35_class_pairs = _find_qwen35_class_pairs( module, kwargs.get('hf_config'), bool(kwargs.get('enable_sp', False)), ) - if Qwen3_5DecoderLayer is None or Qwen3_5GatedDeltaNet is None: - return - if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): + if not qwen35_class_pairs: return module._twinkle_gdn_padding_free_patched = True - if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): - origin_decoder_forward = Qwen3_5DecoderLayer.forward - - def decoder_forward( - layer, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_values=None, - cache_position: Optional[torch.Tensor] = None, - **extra_kwargs, - ): - if getattr(layer, 'layer_type', None) != 'linear_attention': - return call_with_supported_kwargs( - origin_decoder_forward, - layer, + for decoder_layer_cls, gated_delta_net_cls in qwen35_class_pairs: + if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): + continue + + if not getattr(decoder_layer_cls, '_twinkle_padding_free_cu_seqlens_patched', False): + origin_decoder_forward = decoder_layer_cls.forward + + def decoder_forward( + layer, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values=None, + cache_position: Optional[torch.Tensor] = None, + _origin_decoder_forward=origin_decoder_forward, + **extra_kwargs, + ): + if getattr(layer, 'layer_type', None) != 'linear_attention': + return call_with_supported_kwargs( + _origin_decoder_forward, + layer, + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **extra_kwargs, + ) + cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None) + extra_kwargs.pop('cu_seq_lens_k', None) + extra_kwargs.pop('max_length_q', None) + extra_kwargs.pop('max_length_k', None) + + residual = hidden_states + hidden_states = layer.input_layernorm(hidden_states) + hidden_states = layer.linear_attn( hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - cache_position=cache_position, - **extra_kwargs, - ) - cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None) - extra_kwargs.pop('cu_seq_lens_k', None) - extra_kwargs.pop('max_length_q', None) - extra_kwargs.pop('max_length_k', None) - - residual = hidden_states - hidden_states = layer.input_layernorm(hidden_states) - hidden_states = layer.linear_attn( - hidden_states=hidden_states, - cache_params=past_key_values, - cache_position=cache_position, - attention_mask=attention_mask, - cu_seq_lens_q=cu_seq_lens_q, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = layer.post_attention_layernorm(hidden_states) - hidden_states = layer.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - Qwen3_5DecoderLayer.forward = decoder_forward - Qwen3_5DecoderLayer._twinkle_padding_free_cu_seqlens_patched = True - - if not getattr(Qwen3_5GatedDeltaNet, '_twinkle_padding_free_gdn_patched', False): - origin_forward = Qwen3_5GatedDeltaNet.forward - patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch() - - def forward( - mod, - hidden_states: torch.Tensor, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - cu_seq_lens_q: Optional[torch.Tensor] = None, - **extra_kwargs, - ): - if cu_seq_lens_q is None: - return call_with_supported_kwargs( - origin_forward, - mod, - hidden_states, - cache_params=cache_params, + cache_params=past_key_values, cache_position=cache_position, attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, **extra_kwargs, ) - return _patch_gdn_kernels_for_cu_seqlens( + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = layer.post_attention_layernorm(hidden_states) + hidden_states = layer.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + return hidden_states + + decoder_layer_cls.forward = decoder_forward + decoder_layer_cls._twinkle_padding_free_cu_seqlens_patched = True + + if not getattr(gated_delta_net_cls, '_twinkle_padding_free_gdn_patched', False): + origin_forward = gated_delta_net_cls.forward + patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch() + + def forward( mod, - cu_seqlens=cu_seq_lens_q, - patch_chunk_rule=patch_chunk_rule, - origin_forward=origin_forward, - forward_args=(hidden_states, ), - forward_kwargs={ - 'cache_params': cache_params, - 'cache_position': cache_position, - 'attention_mask': attention_mask, - 'cu_seq_lens_q': cu_seq_lens_q, - **extra_kwargs, - }, - ) + hidden_states: torch.Tensor, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + cu_seq_lens_q: Optional[torch.Tensor] = None, + _origin_forward=origin_forward, + _patch_chunk_rule=patch_chunk_rule, + **extra_kwargs, + ): + if cu_seq_lens_q is None: + return call_with_supported_kwargs( + _origin_forward, + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + **extra_kwargs, + ) + return _patch_gdn_kernels_for_cu_seqlens( + mod, + cu_seqlens=cu_seq_lens_q, + patch_chunk_rule=_patch_chunk_rule, + origin_forward=_origin_forward, + forward_args=(hidden_states, ), + forward_kwargs={ + 'cache_params': cache_params, + 'cache_position': cache_position, + 'attention_mask': attention_mask, + 'cu_seq_lens_q': cu_seq_lens_q, + **extra_kwargs, + }, + ) - Qwen3_5GatedDeltaNet.forward = forward - Qwen3_5GatedDeltaNet._twinkle_padding_free_gdn_patched = True + gated_delta_net_cls.forward = forward + gated_delta_net_cls._twinkle_padding_free_gdn_patched = True diff --git a/tests/transformers/test_qwen35_linear_attention_sp.py b/tests/transformers/test_qwen35_linear_attention_sp.py index d1fe9126..d46c423c 100644 --- a/tests/transformers/test_qwen35_linear_attention_sp.py +++ b/tests/transformers/test_qwen35_linear_attention_sp.py @@ -6,7 +6,6 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -import torch.nn.functional as F from datetime import timedelta from transformers.modeling_flash_attention_utils import is_flash_attn_available from transformers.utils.import_utils import is_flash_linear_attention_available @@ -15,20 +14,26 @@ from twinkle.loss import CrossEntropyLoss from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelStrategy, sequence_parallel from twinkle.model.transformers.strategy.sequence_parallel.linear_attention_sp import Qwen3_5GatedDeltaNetUlyssesPatch -from twinkle.model.transformers.strategy.sequence_parallel.utils import get_cu_seqlens_from_position_ids from twinkle.utils import DeviceMesh, selective_log_softmax try: from transformers import Qwen3_5ForCausalLM, Qwen3_5TextConfig - from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35 _HAS_QWEN35 = True except Exception: Qwen3_5ForCausalLM = None Qwen3_5TextConfig = None - hf_qwen35 = None _HAS_QWEN35 = False +try: + from transformers import Qwen3_5MoeForCausalLM, Qwen3_5MoeTextConfig + + _HAS_QWEN35_MOE = True +except Exception: + Qwen3_5MoeForCausalLM = None + Qwen3_5MoeTextConfig = None + _HAS_QWEN35_MOE = False + if is_flash_linear_attention_available(): from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE @@ -45,7 +50,7 @@ LOSS_ATOL = 5e-3 GRAD_RTOL = 5e-3 GRAD_ATOL = 2e-3 -_HAS_FLA_PREFILL = bool(_HAS_QWEN35 and _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None) +_HAS_FLA_PREFILL = bool(_FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None) def _hf_compatible_fla_causal_conv1d_fn(x, weight, bias=None, activation=None, seq_idx=None): @@ -62,7 +67,7 @@ def _hf_compatible_fla_causal_conv1d_fn(x, weight, bias=None, activation=None, s return mixed_qkv.transpose(1, 2).contiguous() -def _force_fla_causal_conv(model: Qwen3_5ForCausalLM) -> Qwen3_5ForCausalLM: +def _force_fla_causal_conv(model: torch.nn.Module) -> torch.nn.Module: for layer in model.model.layers: linear_attn = getattr(layer, 'linear_attn', None) if linear_attn is not None: @@ -71,26 +76,26 @@ def _force_fla_causal_conv(model: Qwen3_5ForCausalLM) -> Qwen3_5ForCausalLM: return model -def _force_packed_linear_attention(model: Qwen3_5ForCausalLM, position_ids: torch.Tensor) -> Qwen3_5ForCausalLM: - packed_cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32) +def _force_packed_linear_attention(model: torch.nn.Module, position_ids: torch.Tensor) -> torch.nn.Module: - def _make_packed_forward(cu_seqlens: torch.Tensor): + def _make_packed_forward(real_position_ids: torch.Tensor): def _packed_forward(mod, hidden_states, cache_params=None, cache_position=None, attention_mask=None): packed_ctx = SimpleNamespace( world_size=1, sp_world_size=1, + real_position_ids=real_position_ids.to(device=hidden_states.device), extra_kwargs={ - 'is_packed': True, - 'cu_seq_lens_q': cu_seqlens.to(dtype=torch.int32, device=hidden_states.device), + 'padding_free': True, }) + packed_ctx._extract_real_position_ids = lambda ids: ids + packed_ctx.pad = lambda tensor, padding_value=-1, position_ids=None: tensor return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( mod, hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask, - cu_seq_lens_q=packed_ctx.extra_kwargs['cu_seq_lens_q'], sequence_parallel_context=packed_ctx, ) @@ -99,7 +104,7 @@ def _packed_forward(mod, hidden_states, cache_params=None, cache_position=None, for layer in model.model.layers: linear_attn = getattr(layer, 'linear_attn', None) if linear_attn is not None: - linear_attn.forward = MethodType(_make_packed_forward(packed_cu_seqlens), linear_attn) + linear_attn.forward = MethodType(_make_packed_forward(position_ids), linear_attn) return model @@ -149,13 +154,13 @@ def _model_dtype() -> torch.dtype: def _build_tiny_qwen35(device: torch.device, *, attn_implementation: str = 'sdpa', - layer_types: list[str] | None = None) -> Qwen3_5ForCausalLM: + layer_types: list[str] | None = None, + model_kind: str = 'dense') -> torch.nn.Module: if layer_types is None: layer_types = ['linear_attention', 'linear_attention'] - config = Qwen3_5TextConfig( + common_config_kwargs = dict( vocab_size=128, hidden_size=64, - intermediate_size=256, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, @@ -172,15 +177,33 @@ def _build_tiny_qwen35(device: torch.device, attention_dropout=0.0, use_cache=False, ) - config._attn_implementation = attn_implementation - model = Qwen3_5ForCausalLM(config) + if model_kind == 'dense': + config = Qwen3_5TextConfig( + intermediate_size=256, + **common_config_kwargs, + ) + config._attn_implementation = attn_implementation + model = Qwen3_5ForCausalLM(config) + elif model_kind == 'moe': + config = Qwen3_5MoeTextConfig( + moe_intermediate_size=64, + shared_expert_intermediate_size=64, + num_experts=4, + num_experts_per_tok=2, + output_router_logits=False, + **common_config_kwargs, + ) + config._attn_implementation = attn_implementation + model = Qwen3_5MoeForCausalLM(config) + else: + raise ValueError(f'Unknown Qwen3.5 test model kind: {model_kind}') model = _force_fla_causal_conv(model) model.to(device=device, dtype=_model_dtype()) model.eval() return model -def _make_strategy(model: Qwen3_5ForCausalLM, world_size: int) -> SequenceParallelStrategy: +def _make_strategy(model: torch.nn.Module, world_size: int) -> SequenceParallelStrategy: strategy = SequenceParallelStrategy( device_mesh=DeviceMesh.from_sizes( world_size=world_size, @@ -237,7 +260,7 @@ def _make_packed_train_batch(device: torch.device): return input_ids, attention_mask, position_ids, labels -def _get_qkv_weight(model: Qwen3_5ForCausalLM) -> torch.nn.Parameter: +def _get_qkv_weight(model: torch.nn.Module) -> torch.nn.Parameter: for layer in model.model.layers: linear_attn = getattr(layer, 'linear_attn', None) if linear_attn is not None: @@ -264,7 +287,7 @@ def _compute_training_path_loss( return result['loss'], num_tokens -def _average_qkv_grad_over_group(model: Qwen3_5ForCausalLM, group: dist.ProcessGroup | None) -> torch.Tensor: +def _average_qkv_grad_over_group(model: torch.nn.Module, group: dist.ProcessGroup | None) -> torch.Tensor: grad = _get_qkv_weight(model).grad if grad is None: raise AssertionError('No qkv gradient collected from Qwen3.5 linear attention layer.') @@ -283,12 +306,18 @@ def _run_prefill_alignment_worker(rank: int, port: int, attn_implementation: str = 'sdpa', layer_types: list[str] | None = None, - packed: bool = False): + packed: bool = False, + model_kind: str = 'dense'): device = _init_dist(rank, world_size, port) try: _set_determinism(1234) - baseline_model = _build_tiny_qwen35(device, attn_implementation=attn_implementation, layer_types=layer_types) + baseline_model = _build_tiny_qwen35( + device, + attn_implementation=attn_implementation, + layer_types=layer_types, + model_kind=model_kind, + ) sp_model = copy.deepcopy(baseline_model) input_ids, attention_mask, position_ids, labels = ( _make_packed_train_batch(device) if packed else _make_train_batch(device)) @@ -312,6 +341,7 @@ def _run_prefill_alignment_worker(rank: int, 'input_ids': input_ids, 'position_ids': position_ids, 'labels': labels, + 'padding_free': packed, }) local_labels = processed_inputs['labels'] sp_outputs = sp_model( @@ -372,6 +402,26 @@ def test_qwen35_mixed_attention_prefill_logits_and_qkv_grad_alignment(self): join=True, ) + @pytest.mark.skipif(not _HAS_QWEN35_MOE, reason='transformers Qwen3.5-MoE is not available in this environment') + def test_qwen35_moe_linear_attention_prefill_logits_and_qkv_grad_alignment(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'sdpa', ['linear_attention', 'linear_attention'], False, 'moe'), + nprocs=WORLD_SIZE, + join=True, + ) + + @pytest.mark.skipif(not _HAS_QWEN35_MOE, reason='transformers Qwen3.5-MoE is not available in this environment') + def test_qwen35_moe_mixed_attention_prefill_logits_and_qkv_grad_alignment(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'sdpa', ['full_attention', 'linear_attention'], False, 'moe'), + nprocs=WORLD_SIZE, + join=True, + ) + @pytest.mark.skipif(not is_flash_attn_available(), reason='requires flash_attention_2 support in transformers') def test_qwen35_linear_attention_prefill_logits_and_qkv_grad_alignment_fa2(self): port = _find_free_port()