
Commit 43f136e

Raise DeprecationWarning if head_first passed
1 parent 3a7ecbf commit 43f136e
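
The same edit recurs in all 26 touched files: the module-level `import warnings` is removed and each `warnings.warn(...)` guarding the deprecated `head_first` path becomes `raise DeprecationWarning(...)`, so the deprecation is now a hard error rather than a console warning. Below is a minimal sketch of the resulting check; the helper name `_check_head_first` is illustrative only and does not exist in the repository.

# Illustrative sketch of the check as it reads after this commit.
# DeprecationWarning subclasses Exception, so raising it aborts the call.
def _check_head_first(q, head_first: bool = False):
    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    if q.shape[1] < q.shape[2]:
        raise DeprecationWarning(
            f"Input tensor shape suggests potential format mismatch: "
            f"seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )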

26 files changed: +53 -78 lines changed


fla/ops/attn/parallel.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional

 import torch
@@ -713,15 +712,15 @@ def parallel_attn(
         Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
     """
     if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
         )
         q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v))
         if g is not None:
             g = rearrange(g, 'b h t ... -> b t h ...')
     if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "

fla/ops/delta_rule/chunk.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional

 import torch
@@ -280,13 +279,13 @@ def chunk_delta_rule(
     assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)."

     if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
         )
         q, k, v, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta))
     if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "

fla/ops/delta_rule/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
@@ -514,13 +513,13 @@ def fused_recurrent_delta_rule(
        >>> assert ht.allclose(ht_var)
    """
    if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta))
    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "

fla/ops/forgetting_attn/parallel.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional

 import torch
@@ -49,13 +48,13 @@ def parallel_forgetting_attn(
     if cu_seqlens is not None:
         assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
     if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
         )
         q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g))
     if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "

fla/ops/gated_delta_rule/chunk.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional

 import torch
@@ -313,13 +312,13 @@ def chunk_gated_delta_rule(
     assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

     if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
         )
         q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g))
     if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "

fla/ops/gated_delta_rule/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
@@ -283,13 +282,13 @@ def fused_recurrent_gated_delta_rule(
        >>> assert ht.allclose(ht_var)
    """
    if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g))
    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "

fla/ops/generalized_delta_rule/dplr/chunk.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional

 import torch
@@ -318,20 +317,20 @@ def chunk_dplr_delta_rule(
         Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
     """
     if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
         )
         q, k, v, a, b, gk = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b, gk))
     if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "
             "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
         )
     if q.dtype == torch.float32:
-        warnings.warn(
+        raise DeprecationWarning(
             """ChunkDeltaRuleFunction does not support float32. Please use bfloat16.
             If you want to use float32, please solve the issue by yourself."""
         )
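
chunk_dplr_delta_rule is the only file in this set where a dtype warning is converted as well: unsupported float32 inputs now raise instead of warning. A tiny illustrative caller-side sketch:

# With this commit, passing float32 tensors to chunk_dplr_delta_rule triggers the
# raise above rather than a warning; casting to bfloat16 up front avoids it.
import torch
q = torch.randn(1, 256, 4, 64)     # float32 by default; would now trigger the raise
q = q.to(torch.bfloat16)           # supported dtype per the message above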

fla/ops/generalized_delta_rule/dplr/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
@@ -249,13 +248,13 @@ def fused_recurrent_dplr_delta_rule(
            Default: `False`.
    """
    if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, a, b, gk = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b, gk))
    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "

fla/ops/generalized_delta_rule/iplr/chunk.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
@@ -462,13 +461,13 @@ def chunk_iplr_delta_rule(
     assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."

     if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
         )
         q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
     if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "

fla/ops/generalized_delta_rule/iplr/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2024-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
@@ -427,13 +426,13 @@ def fused_recurrent_iplr_delta_rule(

    """
    if head_first:
-        warnings.warn(
+        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
+        raise DeprecationWarning(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
