Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ csrc/flash_attn_v3/*_sm90.cu
csrc/flash_attn_v3/instantiations/
csrc/flashmask_v2/*_sm90.cu
csrc/flashmask_v2/instantiations/

flashmask/flash_mask.py
flashmask/flash_mask/flashmask_attention_v3/instantiations/
292 changes: 0 additions & 292 deletions flashmask/flash_mask/CMakeLists.txt

This file was deleted.

68 changes: 55 additions & 13 deletions flashmask/flash_mask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = []

# ============================================================
# FA3: C++/CUDA compiled extension (requires paddle + .so)
# ============================================================
# The FA3 path registers Paddle custom ops from a compiled shared object.
# Everything is wrapped in try/except ImportError so that environments
# without paddle still import this package (FA4 may still be usable).
_fa3_available = False
try:
    import os as _os
    import paddle

    _so_loaded = False

    # First attempt: load the .so from an already-installed `flash_mask`
    # module (pip-installed builds expose the extension as the module file).
    try:
        import flash_mask as _flash_mask_module
        _so_path = _flash_mask_module.__file__
        if _so_path and _so_path.endswith('.so'):
            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(_so_path)
            _so_loaded = True
    except Exception:
        # Best-effort: any failure here just falls through to the next lookup.
        pass

    # Second attempt: probe known build-output locations relative to this file.
    if not _so_loaded:
        _curr_dir = _os.path.dirname(_os.path.abspath(__file__))
        _parent_dir = _os.path.dirname(_curr_dir)
        # NOTE(review): the build path hard-codes cpython-310/x86_64 — only
        # matches that interpreter/platform combination; confirm if others
        # are supported.
        _possible_paths = [
            _os.path.join(_parent_dir, "build", "flash_mask",
                          "lib.linux-x86_64-cpython-310", "flash_mask.so"),
            _os.path.join(_parent_dir, "flash_mask.so"),
        ]
        for _so_path in _possible_paths:
            if _os.path.exists(_so_path):
                paddle.utils.cpp_extension.load_op_meta_info_and_register_op(_so_path)
                _so_loaded = True
                break

    if _so_loaded:
        # Re-export under a version-qualified name so it cannot collide with
        # the FA4 `flashmask_attention` exported below.
        from .flashmask_attention_v3.interface import flashmask_attention as flashmask_attention_v3
        __all__.append("flashmask_attention_v3")
        _fa3_available = True
    else:
        print("[WARNING] flash_mask.so not found, FA3 custom ops not available")
except ImportError:
    pass  # paddle not installed, skip FA3

# ============================================================
# FA4: Pure Python + CUTLASS DSL (no .so needed)
# ============================================================
_fa4_available = False
try:
    from .cute import flash_attention, flashmask_attention
    __all__ += ["flash_attention", "flashmask_attention"]
    _fa4_available = True
except ImportError:
    pass  # cute module not installed or dependencies missing

# Surface a single actionable warning when the package imported but is inert.
if not _fa3_available and not _fa4_available:
    print("[WARNING] flash_mask: neither FA3 nor FA4 is available. "
          "Check your installation.")
Loading