
Commit b04aae7

manueldeprada authored and Cyrilvallez committed
Fix Cache.max_cache_len max value for Hybrid models (#39737)
* fix gemma
* fix min
* fix quant init issue
* fix gemma 3n
* skip quant cache test
* fix modular
* new test for Gemma
* include cyril change

---------

Co-authored-by: Cyril Vallez <[email protected]>
1 parent 0297e59 · commit b04aae7

5 files changed: +82 additions, −40 deletions

src/transformers/cache_utils.py (7 additions & 12 deletions)

@@ -325,8 +325,9 @@ def __init__(self, sliding_window, *args, **kwargs):
             sliding_window (`int`):
                 Effective window size: number of tokens that are kept on each update call.
         """
-        kwargs.pop("max_cache_len", None)
-        super().__init__(*args, max_cache_len=sliding_window, **kwargs)
+        max_cache_len = kwargs.pop("max_cache_len", None)
+        max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
+        super().__init__(*args, max_cache_len=max_cache_len, **kwargs)

     def update(
         self,

@@ -1277,9 +1278,7 @@ def max_batch_size(self) -> int:
     def max_cache_len(self) -> int:
         """Return the maximum cache length of the cache"""
         values = [layer.max_cache_len for layer in self.layers]
-        if len(set(values)) > 1:
-            raise ValueError(f"Max cache length is not consistent across layers: {values}")
-        return values[0]
+        return max(values)

     @property
     def is_compileable(self) -> bool:

@@ -1655,7 +1654,7 @@ class QuantoQuantizedCache(QuantizedCache):
     """

     def __init__(self, **kwargs) -> None:
-        Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
+        DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)


 class HQQQuantizedCache(QuantizedCache):

@@ -1697,7 +1696,7 @@ class HQQQuantizedCache(QuantizedCache):

     def __init__(self, backend="HQQ", **kwargs) -> None:
         assert backend == "HQQ"
-        Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
+        DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)


 class EncoderDecoderCache(Cache):

@@ -1951,10 +1950,6 @@ def parse_layer_args_from_model_config(
         )
     # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window)
     max_cache_len = max_cache_len or config.max_position_embeddings
-    if getattr(config, "sliding_window", None) is not None:
-        sliding_window_len = min(config.sliding_window, max_cache_len)
-    else:
-        sliding_window_len = None
     # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
     head_dim = (
         config.head_dim

@@ -1981,7 +1976,7 @@ def parse_layer_args_from_model_config(
         "layer_device_map": layer_device_map,
         "head_dim": head_dim,
         "num_heads": num_heads,
-        "sliding_window": sliding_window_len,
+        "sliding_window": getattr(config, "sliding_window", None),
     }
     return {k: v for k, v in layer_args.items() if v is not None}
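For readers skimming the hunks above, here is a minimal standalone sketch of the new length semantics, assuming simplified stand-ins rather than the real cache classes: a sliding-window layer clamps its capacity to its window, and the cache-level maximum now reports the largest per-layer value instead of raising when hybrid layers disagree.

# Minimal sketch; `LayerSketch`, `sliding_layer_capacity`, and `cache_max_len`
# are illustrative stand-ins, not the actual transformers classes or helpers.
from dataclasses import dataclass


@dataclass
class LayerSketch:
    max_cache_len: int


def sliding_layer_capacity(sliding_window: int, max_cache_len: int | None) -> int:
    # A sliding-window layer never needs more slots than its window,
    # nor more than the overall cache length that was requested.
    return min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window


def cache_max_len(layers: list[LayerSketch]) -> int:
    # Hybrid caches mix full-attention and sliding layers, so per-layer
    # lengths legitimately differ; report the largest one.
    return max(layer.max_cache_len for layer in layers)


full_layer = LayerSketch(max_cache_len=128)
sliding_layer = LayerSketch(max_cache_len=sliding_layer_capacity(8, 128))  # -> 8
assert cache_max_len([full_layer, sliding_layer]) == 128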

src/transformers/models/gemma3n/modeling_gemma3n.py (12 additions & 14 deletions)

@@ -30,7 +30,7 @@
 import torch.nn.functional as F

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, HybridCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs

@@ -1328,22 +1328,20 @@ def forward(
         query_states = query_states.transpose(1, 2)

         if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
-            # Device of past layer may be different from current one
-            indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device)
             # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
-            if isinstance(past_key_value, HybridCache) and self.is_sliding:
-                max_length = past_key_value.sliding_window
-                indices = (
-                    slice(0, max_length)
-                    if cache_position.shape[0] > max_length
-                    else cache_position.clamp(min=0, max=max_length - 1)
-                )
+            layer = past_key_value.layers[self.kv_shared_layer_index]
+            # Device of past layer may be different from current one
+            indices = cache_position.to(layer.keys.device)
+            # Sliding window cache layers might have smaller size (for full layers, we never go beyond)
+            if isinstance(layer, SlidingWindowLayer):
+                if cache_position.shape[0] > layer.get_max_cache_shape():
+                    indices = slice(0, layer.get_max_cache_shape())
+                else:
+                    indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)

             # Device of past layer may be different from current one
-            key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device)
-            value_states = (
-                past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device)
-            )
+            key_states = layer.keys[:, :, indices].to(query_states.device)
+            value_states = layer.values[:, :, indices].to(query_states.device)
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)
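The index handling in this hunk can be illustrated with plain tensors. The sketch below is an assumption-level rendering, where the `max_cache_shape` parameter plays the role of `layer.get_max_cache_shape()` and a random tensor stands in for the shared layer's cached keys.

import torch

def shared_layer_indices(cache_position: torch.Tensor, max_cache_shape: int):
    # Prefill longer than the sliding-window buffer: read the whole (small) layer.
    if cache_position.shape[0] > max_cache_shape:
        return slice(0, max_cache_shape)
    # Otherwise clamp the positions so they stay inside the buffer.
    return cache_position.clamp(min=0, max=max_cache_shape - 1)

keys = torch.randn(1, 2, 8, 4)      # (batch, heads, window=8, head_dim)
prefill = torch.arange(10)          # 10 tokens, longer than the window
decode = torch.tensor([12])         # a single decode step past the window

assert keys[:, :, shared_layer_indices(prefill, 8)].shape[2] == 8
assert keys[:, :, shared_layer_indices(decode, 8)].shape[2] == 1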

src/transformers/models/gemma3n/modular_gemma3n.py (12 additions & 14 deletions)

@@ -23,7 +23,7 @@
 import torch.nn.functional as F

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, HybridCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs

@@ -1769,22 +1769,20 @@ def forward(
         query_states = query_states.transpose(1, 2)

         if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
-            # Device of past layer may be different from current one
-            indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device)
             # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
-            if isinstance(past_key_value, HybridCache) and self.is_sliding:
-                max_length = past_key_value.sliding_window
-                indices = (
-                    slice(0, max_length)
-                    if cache_position.shape[0] > max_length
-                    else cache_position.clamp(min=0, max=max_length - 1)
-                )
+            layer = past_key_value.layers[self.kv_shared_layer_index]
+            # Device of past layer may be different from current one
+            indices = cache_position.to(layer.keys.device)
+            # Sliding window cache layers might have smaller size (for full layers, we never go beyond)
+            if isinstance(layer, SlidingWindowLayer):
+                if cache_position.shape[0] > layer.get_max_cache_shape():
+                    indices = slice(0, layer.get_max_cache_shape())
+                else:
+                    indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)

             # Device of past layer may be different from current one
-            key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device)
-            value_states = (
-                past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device)
-            )
+            key_states = layer.keys[:, :, indices].to(query_states.device)
+            value_states = layer.values[:, :, indices].to(query_states.device)
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)

tests/models/gemma3/test_modeling_gemma3.py (46 additions & 0 deletions)

@@ -151,6 +151,52 @@ def test_eager_padding_matches_padding_free_with_position_ids(self):
     def test_sdpa_padding_matches_padding_free_with_position_ids(self):
         pass

+    def test_generation_beyond_sliding_window_tiny_model(self):
+        """Test generation with a tiny randomly initialised model whose input length is larger than the `sliding_window`.
+        The model is configured with both `full_attention` and `sliding_attention` layers to make sure the hybrid cache
+        and mask slicing logic is covered.
+        """
+        config = Gemma3TextConfig.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
+        config.attn_implementation = "eager"
+        config.layer_types = ["full_attention", "sliding_attention"]
+        config.sliding_window = 8
+        config.max_position_embeddings = 128
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-Gemma3ForCausalLM", config=config
+        ).to(torch_device)
+
+        input_len = 10
+        input_ids = torch.tensor(
+            [
+                [42300, 241087, 255445, 81315, 193760, 184471, 67719, 98191, 210651, 124725],
+                [102294, 205314, 226646, 62020, 60245, 68025, 251839, 114053, 4695, 175511],
+            ],
+            device=torch_device,
+        )
+        attention_mask = torch.ones_like(input_ids).to(torch_device)
+        with torch.no_grad():
+            _ = model.generate(
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=1,
+                do_sample=False,
+                use_cache=True,
+                cache_implementation="hybrid",
+            )
+        # 2 generations are needed to trigger https://github.com/huggingface/transformers/issues/39711
+        # since it requires model._cache to have been previously initialized
+        output = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=5,
+            do_sample=False,
+            use_cache=True,
+            cache_implementation="hybrid",
+        )
+        generated_sequences = output[:, input_len:].cpu()
+        EXPECTED_OUTPUT = torch.tensor([[90109, 90109, 90109, 83191, 83191], [246901, 69832, 69832, 69832, 62288]])
+        torch.testing.assert_close(generated_sequences, EXPECTED_OUTPUT)
+

 class Gemma3Vision2TextModelTester:
     def __init__(

tests/models/gemma3n/test_modeling_gemma3n.py (5 additions & 0 deletions)

@@ -431,6 +431,11 @@ def test_contrastive_generate_low_memory(self):
     def test_dola_decoding_sample(self):
         pass

+    @pytest.mark.generate
+    @unittest.skip("Gemma3n does not support QuantizedCache as it performs cache manipulation in the forward pass")
+    def test_generate_with_quant_cache(self):
+        pass
+

 class Gemma3nVision2TextModelTester:
     text_config = {"activation_sparsity_pattern": None}
