
Commit 635a8ab

update FLASH_VER_LAST to include flash-attn==2.8.0.post2 (#1296)
* update FLASH_VER_LAST
* update inequality
* add package parsing
* small formatting
* formatting
* string formatting
* formatting
* formatting
* move import, use torch parser
* linting error
1 parent 35bfb51 commit 635a8ab

File tree

1 file changed: +6 -5 lines changed


xformers/ops/fmha/flash.py

Lines changed: 6 additions & 5 deletions
@@ -10,6 +10,7 @@
 from typing import Any, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
+from torch._vendor.packaging.version import parse as parse_version
 
 from ..common import get_operator, register_operator
 from .attn_bias import (
@@ -72,15 +73,15 @@
     _C_flashattention = flash_attn.flash_attn_interface.flash_attn_gpu
 
     FLASH_VERSION = flash_attn.__version__
-    FLASH_VER_MIN = (2, 7, 1)
-    FLASH_VER_LAST = (2, 8, 0)  # last supported, inclusive
-    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+    FLASH_VER_MIN = parse_version("2.7.1")
+    FLASH_VER_LAST = parse_version("2.8.0.post2")  # last supported, inclusive
+    flash_ver_parsed = parse_version(FLASH_VERSION)
     if (
         flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
     ) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
         raise ImportError(
-            f"Requires Flash-Attention version >={'.'.join([str(i) for i in FLASH_VER_MIN])},"
-            f"<={'.'.join([str(i) for i in FLASH_VER_LAST])} "
+            f"Requires Flash-Attention version >={FLASH_VER_MIN},"
+            f"<={FLASH_VER_LAST} "
             f"but got {FLASH_VERSION}."
         )
     VARLEN_LSE_PACKED = True
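
Why the switch matters: the old tuple parse keeps only the first three dot-separated components, so a post-release like flash-attn 2.8.0.post2 compared as plain (2, 8, 0). A minimal sketch of the behavior difference, using the same vendored parser the commit imports (the assertions are illustrative, not taken from the diff):

    from torch._vendor.packaging.version import parse as parse_version

    # PEP 440 ordering: a post-release sorts after its base version ...
    assert parse_version("2.8.0.post2") > parse_version("2.8.0")
    assert parse_version("2.7.1") <= parse_version("2.8.0.post2")

    # ... whereas the old tuple-based parse silently discards ".post2",
    # so the check could not distinguish 2.8.0 from 2.8.0.post2.
    naive = tuple(int(s) for s in "2.8.0.post2".split(".")[:3])
    assert naive == (2, 8, 0)

As before, setting XFORMERS_IGNORE_FLASH_VERSION_CHECK=1 in the environment bypasses the check for versions outside the supported range.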
