Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def _run_forward(
sequence_parallel_context=None,
) -> torch.Tensor:
using_torch_fallback = _ensure_linear_attention_kernels(mod)
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
modeling_module = import_module(mod.__class__.__module__)
apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states')
Comment thread
meichangsu1 marked this conversation as resolved.

local_attention_mask = attention_mask
if torch.is_tensor(attention_mask) and attention_mask.dim() == 2:
Expand Down
220 changes: 124 additions & 96 deletions src/twinkle/patch/gdn_padding_free.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import transformers
from importlib import import_module
from packaging.version import Version
from transformers.utils.import_utils import is_flash_linear_attention_available
from typing import Optional
Expand All @@ -12,16 +13,35 @@ def _is_qwen35_model(hf_config) -> bool:
return 'qwen3_5' in getattr(hf_config, 'model_type', '')


def _find_qwen35_classes(module: Optional[torch.nn.Module], hf_config, enable_sp: bool):
def _iter_qwen35_class_pairs():
class_specs = (
(
'transformers.models.qwen3_5.modeling_qwen3_5',
'Qwen3_5DecoderLayer',
'Qwen3_5GatedDeltaNet',
),
(
'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe',
'Qwen3_5MoeDecoderLayer',
'Qwen3_5MoeGatedDeltaNet',
),
)
for module_name, decoder_class_name, gdn_class_name in class_specs:
try:
modeling_module = import_module(module_name)
yield getattr(modeling_module, decoder_class_name), getattr(modeling_module, gdn_class_name)
except Exception:
continue


def _find_qwen35_class_pairs(module: Optional[torch.nn.Module], hf_config, enable_sp: bool):
if module is None or enable_sp or not _is_qwen35_model(hf_config):
return None, None
try:
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
except Exception:
return None, None
if any(isinstance(submodule, Qwen3_5GatedDeltaNet) for submodule in module.modules()):
return Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
return None, None
return ()
class_pairs = []
for decoder_layer_cls, gated_delta_net_cls in _iter_qwen35_class_pairs():
if any(isinstance(submodule, gated_delta_net_cls) for submodule in module.modules()):
class_pairs.append((decoder_layer_cls, gated_delta_net_cls))
return tuple(class_pairs)


def _get_flash_linear_attention_kernels():
Expand Down Expand Up @@ -84,104 +104,112 @@ class GatedDeltaNetPaddingFreePatch(Patch):

def __call__(self, module, *args, **kwargs):
del args
Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet = _find_qwen35_classes(
qwen35_class_pairs = _find_qwen35_class_pairs(
module,
kwargs.get('hf_config'),
bool(kwargs.get('enable_sp', False)),
)
if Qwen3_5DecoderLayer is None or Qwen3_5GatedDeltaNet is None:
return
if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False):
if not qwen35_class_pairs:
return
module._twinkle_gdn_padding_free_patched = True

if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False):
origin_decoder_forward = Qwen3_5DecoderLayer.forward

def decoder_forward(
layer,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values=None,
cache_position: Optional[torch.Tensor] = None,
**extra_kwargs,
):
if getattr(layer, 'layer_type', None) != 'linear_attention':
return call_with_supported_kwargs(
origin_decoder_forward,
layer,
for decoder_layer_cls, gated_delta_net_cls in qwen35_class_pairs:
if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False):
continue

if not getattr(decoder_layer_cls, '_twinkle_padding_free_cu_seqlens_patched', False):
origin_decoder_forward = decoder_layer_cls.forward

def decoder_forward(
layer,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values=None,
cache_position: Optional[torch.Tensor] = None,
_origin_decoder_forward=origin_decoder_forward,
**extra_kwargs,
):
if getattr(layer, 'layer_type', None) != 'linear_attention':
return call_with_supported_kwargs(
_origin_decoder_forward,
layer,
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**extra_kwargs,
)
cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None)
extra_kwargs.pop('cu_seq_lens_k', None)
extra_kwargs.pop('max_length_q', None)
extra_kwargs.pop('max_length_k', None)

residual = hidden_states
hidden_states = layer.input_layernorm(hidden_states)
hidden_states = layer.linear_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**extra_kwargs,
)
cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None)
extra_kwargs.pop('cu_seq_lens_k', None)
extra_kwargs.pop('max_length_q', None)
extra_kwargs.pop('max_length_k', None)

residual = hidden_states
hidden_states = layer.input_layernorm(hidden_states)
hidden_states = layer.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
cu_seq_lens_q=cu_seq_lens_q,
)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
hidden_states = layer.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states

Qwen3_5DecoderLayer.forward = decoder_forward
Qwen3_5DecoderLayer._twinkle_padding_free_cu_seqlens_patched = True

if not getattr(Qwen3_5GatedDeltaNet, '_twinkle_padding_free_gdn_patched', False):
origin_forward = Qwen3_5GatedDeltaNet.forward
patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch()

def forward(
mod,
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask: Optional[torch.Tensor] = None,
cu_seq_lens_q: Optional[torch.Tensor] = None,
**extra_kwargs,
):
if cu_seq_lens_q is None:
return call_with_supported_kwargs(
origin_forward,
mod,
hidden_states,
cache_params=cache_params,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
cu_seq_lens_q=cu_seq_lens_q,
**extra_kwargs,
)
return _patch_gdn_kernels_for_cu_seqlens(
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = layer.post_attention_layernorm(hidden_states)
hidden_states = layer.mlp(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
Comment thread
meichangsu1 marked this conversation as resolved.

decoder_layer_cls.forward = decoder_forward
decoder_layer_cls._twinkle_padding_free_cu_seqlens_patched = True

if not getattr(gated_delta_net_cls, '_twinkle_padding_free_gdn_patched', False):
origin_forward = gated_delta_net_cls.forward
patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch()

def forward(
mod,
cu_seqlens=cu_seq_lens_q,
patch_chunk_rule=patch_chunk_rule,
origin_forward=origin_forward,
forward_args=(hidden_states, ),
forward_kwargs={
'cache_params': cache_params,
'cache_position': cache_position,
'attention_mask': attention_mask,
'cu_seq_lens_q': cu_seq_lens_q,
**extra_kwargs,
},
)
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask: Optional[torch.Tensor] = None,
cu_seq_lens_q: Optional[torch.Tensor] = None,
_origin_forward=origin_forward,
_patch_chunk_rule=patch_chunk_rule,
**extra_kwargs,
):
if cu_seq_lens_q is None:
return call_with_supported_kwargs(
_origin_forward,
mod,
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
**extra_kwargs,
)
return _patch_gdn_kernels_for_cu_seqlens(
mod,
cu_seqlens=cu_seq_lens_q,
patch_chunk_rule=_patch_chunk_rule,
origin_forward=_origin_forward,
forward_args=(hidden_states, ),
forward_kwargs={
'cache_params': cache_params,
'cache_position': cache_position,
'attention_mask': attention_mask,
'cu_seq_lens_q': cu_seq_lens_q,
**extra_kwargs,
},
)

Qwen3_5GatedDeltaNet.forward = forward
Qwen3_5GatedDeltaNet._twinkle_padding_free_gdn_patched = True
gated_delta_net_cls.forward = forward
gated_delta_net_cls._twinkle_padding_free_gdn_patched = True
Loading
Loading