Merged: changes from all 24 commits

Commits:
c0c3c34  BUILD CUDA 13 (johnnynunez, Oct 9, 2025)
3bd4670  Update action.yml (johnnynunez, Oct 9, 2025)
f0e2384  Update action.yml (johnnynunez, Oct 15, 2025)
37eed3a  Update flash-attention hash (johnnynunez, Oct 15, 2025)
db622e4  fix error cccl (johnnynunez, Oct 21, 2025)
920ade3  Update requirements.txt (johnnynunez, Oct 21, 2025)
5beaf3f  Merge branch 'facebookresearch:main' into main (johnnynunez, Oct 21, 2025)
c2407a6  fix error pytorch 2.9.0 in CI (johnnynunez, Oct 21, 2025)
3e2e11e  Update linters_reusable.yml (johnnynunez, Oct 28, 2025)
bfb2271  Update CUDA toolkit and Python versions in workflow (johnnynunez, Oct 28, 2025)
d434cb3  Update Python version from 3.9 to 3.10 (johnnynunez, Oct 28, 2025)
ea44071  upstream (johnnynunez, Oct 28, 2025)
dbe25a2  lint (johnnynunez, Oct 28, 2025)
b819a23  Add use-github-cache option to CUDA setup action (johnnynunez, Oct 28, 2025)
a62a9a9  Update cuda-toolkit action to use N-Storm fork (johnnynunez, Oct 29, 2025)
f732af6  Fix CUDA architecture list format in setup.py (johnnynunez, Oct 29, 2025)
40872c0  Update TORCH_CUDA_ARCH_LIST for toolkit versioning (johnnynunez, Oct 29, 2025)
6be38b9  try fix windows (johnnynunez, Oct 30, 2025)
6ba993e  Merge remote-tracking branch 'origin/main' (johnnynunez, Oct 30, 2025)
6f6e99e  avoid compile fa3 windows with cu130 (johnnynunez, Oct 30, 2025)
d5acb15  Update CUDA version from 13.0.1 to 13.0.2 (johnnynunez, Oct 30, 2025)
3b8fcb2  Modify CUDA version check for Windows platform (johnnynunez, Oct 30, 2025)
7281508  Update CUDA version from 13.0.2 to 13.0.1 (johnnynunez, Oct 30, 2025)
fc6b421  Update setup.py (johnnynunez, Oct 30, 2025)
6 changes: 4 additions & 2 deletions .github/actions/setup-build-cuda/action.yml
@@ -24,9 +24,10 @@ runs:
print(sys.version)
cushort = "${{ inputs.toolkit_short_version }}"
# Version uploaded to pypi (rather than PyTorch s3)
TORCH_CUDA_DEFAULT = "128" # since pytorch 2.8.0
TORCH_CUDA_DEFAULT = "130" # since pytorch 2.9.0
# https://github.com/Jimver/cuda-toolkit/blob/master/src/links/linux-links.ts
full_version, install_script = {
"130": ("13.0.1", "https://developer.download.nvidia.com/compute/cuda/13.0.1/local_installers/cuda_13.0.1_580.82.07_linux.run"),
"129": ("12.9.0", "https://developer.download.nvidia.com/compute/cuda/12.9.1/local_installers/cuda_12.9.1_575.57.08_linux.run"),
"128": ("12.8.1", "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run"),
# (Build with nvcc 12.8 on linux even when building for 12.6 to avoid seg fault in Flash3 build)
@@ -52,7 +53,8 @@ runs:
- name: Install cuda
if: runner.os == 'Windows' && inputs.toolkit_type == 'cuda'
id: cuda-toolkit
-uses: Jimver/[email protected]
+# Using N-Storm fork until https://github.com/Jimver/cuda-toolkit/issues/395 is resolved
+uses: N-Storm/[email protected]
with:
cuda: ${{ steps.cuda_info.outputs.CUDA_VERSION }}
method: network
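For reference, the short-version lookup in the inline script above can be exercised standalone; a minimal sketch, with cushort hard-coded instead of read from inputs.toolkit_short_version:

```python
# Minimal sketch of the short-version -> (full version, installer URL)
# mapping used by the action's inline Python; values copied from the
# diff above, cushort hard-coded for illustration.
cushort = "130"
full_version, install_script = {
    "130": ("13.0.1", "https://developer.download.nvidia.com/compute/cuda/13.0.1/local_installers/cuda_13.0.1_580.82.07_linux.run"),
    "129": ("12.9.0", "https://developer.download.nvidia.com/compute/cuda/12.9.1/local_installers/cuda_12.9.1_575.57.08_linux.run"),
    "128": ("12.8.1", "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run"),
}[cushort]
print(full_version)    # 13.0.1
print(install_script)  # .../cuda_13.0.1_580.82.07_linux.run
```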
2 changes: 1 addition & 1 deletion .github/actions/setup-env-build/action.yml
@@ -26,7 +26,7 @@ runs:

CONDA_INSTALL_CMD = "micromamba create python=${{ inputs.python }} zlib pip ninja ccache=4.8 -c conda-forge -q -y"

-conda_env_key = CONDA_INSTALL_CMD + "[cu129][v2]"
+conda_env_key = CONDA_INSTALL_CMD + "[cu130][v2]"
for file in sorted(glob.glob("requirement*.txt")):
conda_env_key += f"\n########## {file}\n"
conda_env_key += Path(file).read_text()
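The string is a cache key: it concatenates the micromamba create command, a manual tag, and the contents of every requirements file, so bumping [cu129] to [cu130] forces CI to rebuild the cached environment. A hedged sketch of the idea, assuming the action hashes this key to name the cache entry:

```python
# Sketch: rebuild-triggering cache key for the build environment.
# Hashing the key to a cache name is an assumption about the action.
import glob
import hashlib
from pathlib import Path

CONDA_INSTALL_CMD = "micromamba create python=3.10 zlib pip ninja ccache=4.8 -c conda-forge -q -y"
conda_env_key = CONDA_INSTALL_CMD + "[cu130][v2]"
for file in sorted(glob.glob("requirement*.txt")):
    conda_env_key += f"\n########## {file}\n"
    conda_env_key += Path(file).read_text()
print(hashlib.sha256(conda_env_key.encode()).hexdigest()[:16])
```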
1 change: 0 additions & 1 deletion .github/workflows/linters.yml
@@ -7,4 +7,3 @@ on:
jobs:
repo:
uses: ./.github/workflows/linters_reusable.yml

4 changes: 2 additions & 2 deletions .github/workflows/linters_reusable.yml
@@ -15,9 +15,9 @@ jobs:
with:
fetch-depth: 0
- name: Setup Python
-uses: actions/setup-python@v2
+uses: actions/setup-python@v5
with:
-python-version: '3.9'
+python-version: '3.10'
- name: Run pre-script
if: ${{ inputs.pre-script }}
run: ${{ inputs.pre-script }}
16 changes: 8 additions & 8 deletions .github/workflows/rocm_ci.yml
@@ -1,6 +1,6 @@
name: rocm-ci

on:
pull_request:
types: [labeled, synchronize, reopened]
workflow_dispatch: {}
@@ -43,23 +43,23 @@ jobs:

export GIT_BRANCH=${GITHUB_BASE_REF:-${GITHUB_REF#refs/heads/}}
echo GIT_BRANCH = $GIT_BRANCH

export ROCM_PATH=/opt/rocm
echo ROCM_PATH = $ROCM_PATH

hipcc --version
rocm-smi
rocminfo | grep "gfx"

- name: Setup build env
run: |
conda create -n xformers python=3.11
export PATH=/opt/conda/envs/xformers/bin:$PATH
python -VV

python -m pip install -U torch --index-url=https://download.pytorch.org/whl/rocm6.2
python -c "import torch; print(f'PyTorch version {torch.__version__}')"

python -m pip install ninja scipy pytest pytest-html

- name: Pre-build clean
@@ -72,16 +72,16 @@
run: |
export PATH=/opt/conda/envs/xformers/bin:$PATH
export MAX_JOBS=20

python -m pip install -e ./_xformers --verbose
python -m xformers.info

- name: Run python tests
run: |
export PATH=/opt/conda/envs/xformers/bin:$PATH

python -m pytest --html=test_mem_eff_attention.html --self-contained-html -rpfs ./_xformers/tests/test_mem_eff_attention.py

- name: Archive logs
if: '!cancelled()'
uses: actions/upload-artifact@v4
4 changes: 2 additions & 2 deletions .github/workflows/rocm_docker.yml
@@ -12,13 +12,13 @@ jobs:
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ vars.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v6
with:
14 changes: 7 additions & 7 deletions .github/workflows/wheels.yml
@@ -29,16 +29,16 @@ jobs:
environ = os.environ

# All builds are python-version agnostic,
-# and built with python 3.9
-PYTHON_VERSION = "3.9"
+# and built with python 3.10
+PYTHON_VERSION = "3.10"
# NOTE: Don't forget to update `upload_pt`'s matrix
# when changing the CUDA/ROCM versions below!
-CU_VERSIONS = ['126', '128', '129']
+CU_VERSIONS = ['126', '129', '130']
ROCM_VERSIONS = ['6.4']

include = []
for os in ['8-core-ubuntu', 'windows-8-core']:
-for torch_version in ['2.8.0']:
+for torch_version in ['2.9.0']:
# CUDA builds
for cuda_short_version in CU_VERSIONS:
if cuda_short_version < "124" and "windows" in os:
@@ -88,7 +88,7 @@ jobs:
uses: ./.github/workflows/wheels_upload_pip.yml
with:
twine_username: __token__
filter: "*torch2.8.0+cu128*"
filter: "*torch2.9.0+cu130*"
execute: ${{ github.repository == 'facebookresearch/xformers' && github.event_name != 'pull_request' }}
secrets:
twine_password: ${{ secrets.PYPI_TOKEN }}
@@ -100,13 +100,13 @@ jobs:
matrix:
suffix:
- cu126
-  - cu128
- cu129
+  - cu130
- rocm6.4
uses: ./.github/workflows/wheels_upload_s3.yml
with:
aws_role: "arn:aws:iam::749337293305:role/pytorch_bot_uploader_role"
s3_path: s3://pytorch/whl/${{ matrix.suffix }}/
aws_s3_cp_extra_args: --acl public-read
filter: "*torch2.8.0+${{ matrix.suffix }}*"
filter: "*torch2.9.0+${{ matrix.suffix }}*"
execute: ${{ github.repository == 'facebookresearch/xformers' && github.ref_type == 'tag' }}
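A standalone sketch of how the CUDA half of this build matrix expands (the tuple layout is illustrative; the real inline script in wheels.yml defines the actual fields):

```python
# Sketch: expansion of the CUDA wheel build matrix after this PR.
CU_VERSIONS = ['126', '129', '130']

include = []
for os_name in ['8-core-ubuntu', 'windows-8-core']:
    for torch_version in ['2.9.0']:
        for cuda_short_version in CU_VERSIONS:
            # Mirrors the guard in the diff; the string comparison is
            # valid because all short versions are three digits.
            if cuda_short_version < "124" and "windows" in os_name:
                continue
            include.append((os_name, torch_version, cuda_short_version))

print(len(include))  # 6 CUDA combinations (ROCm entries are separate)
```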
8 changes: 6 additions & 2 deletions .github/workflows/wheels_build.yml
@@ -53,9 +53,13 @@ jobs:
run:
shell: bash
steps:
-- if: contains(inputs.toolkit_type, 'cuda') && fromJSON(inputs.toolkit_short_version) >= 120
+- if: contains(inputs.toolkit_type, 'cuda') && fromJSON(inputs.toolkit_short_version) >= 120 && fromJSON(inputs.toolkit_short_version) < 130
run: |
-echo "TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 9.0a" >> ${GITHUB_ENV}
+echo "TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 8.0 9.0a" >> ${GITHUB_ENV}
+
+- if: contains(inputs.toolkit_type, 'cuda') && fromJSON(inputs.toolkit_short_version) >= 130
+run: |
+echo "TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 8.0 9.0a 10.0a 10.3a 11.0a 12.0a 12.1a" >> ${GITHUB_ENV}

- if: runner.os == 'Windows'
run: git config --system core.longpaths true
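In short, toolkits in the [12.0, 13.0) range append 8.0 and 9.0a to TORCH_CUDA_ARCH_LIST, while CUDA 13 toolkits also get the newer compute capabilities they can target. A hedged sketch of the same gating (the helper is illustrative, not part of the workflow):

```python
# Sketch of the two workflow steps above: extra architectures appended
# to TORCH_CUDA_ARCH_LIST, keyed on the CUDA toolkit short version.
def extra_archs(toolkit_short_version: int) -> str:
    if 120 <= toolkit_short_version < 130:
        return "8.0 9.0a"
    if toolkit_short_version >= 130:
        # CUDA 13 toolchains can additionally target 10.x/11.x/12.x
        # compute capabilities.
        return "8.0 9.0a 10.0a 10.3a 11.0a 12.0a 12.1a"
    return ""

assert extra_archs(129) == "8.0 9.0a"
assert extra_archs(130).split()[-1] == "12.1a"
```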
6 changes: 3 additions & 3 deletions .github/workflows/win-build.yml
@@ -61,8 +61,8 @@ jobs:
uses: ./.github/actions/setup-build-cuda
with:
toolkit_type: "cuda"
toolkit_short_version: "128"
python: "3.9"
toolkit_short_version: "130"
python: "3.10"

- name: Remove internal code
run: |
@@ -73,7 +73,7 @@

- name: Install build dependencies
run: |
-$PY -m pip install wheel setuptools ninja torch==2.8.0 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu126
+$PY -m pip install wheel setuptools ninja torch==2.9.0 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu130
git config --global --add safe.directory "*"
$PY -c "import torch; print('torch', torch.__version__)"
$PY -c "import torch; print('torch.cuda', torch.version.cuda)"
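The two torch print statements double as a sanity check; a hedged sketch of what they should confirm after the cu130 install above (expected values, not guaranteed output):

```python
# Sketch: verify the installed torch matches the CUDA 13 toolkit the
# Windows wheel is built against. Version strings are expectations.
import torch

print("torch", torch.__version__)        # e.g. 2.9.0+cu130
print("torch.cuda", torch.version.cuda)  # e.g. 13.0
assert torch.version.cuda is not None and torch.version.cuda.startswith("13.")
```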
2 changes: 0 additions & 2 deletions .gitignore
@@ -71,5 +71,3 @@ xformers/csrc/attention/hip_fmha/instances/*_hip.h
xformers/csrc/attention/hip_decoder/*.cu
xformers/csrc/attention/hip_decoder/*.hip
xformers/csrc/attention/hip_decoder/*_hip.h


2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
# Example requirement, can be anything that pip knows
# install with `pip install -r requirements.txt`, and make sure that CI does the same
-torch >= 2.8
+torch >= 2.9
numpy
12 changes: 10 additions & 2 deletions setup.py
@@ -176,7 +176,9 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int):
return []
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
-if cuda_version >= 1208:
+if cuda_version >= 1300:
+DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0;11.0;12.0"
+elif cuda_version >= 1208:
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0;12.0"
elif cuda_version >= 1108:
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0"
@@ -281,9 +283,15 @@ def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
return []
if cuda_version < 1203:
return []
+if ((sys.platform == "win32" or platform.system() == "Windows")
+and cuda_version >= 1300):
+return []
archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
if archs_list is None:
-if torch.cuda.get_device_capability("cuda") != (9, 0):
+if torch.cuda.get_device_capability("cuda") != (
+9,
+0,
+) and torch.cuda.get_device_capability("cuda") != (8, 0):
return []
archs_list = "8.0 9.0a"
nvcc_archs_flags = []
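Two notes on the numbers above. The cuda_version thresholds (1203, 1208, 1300) suggest an integer encoding of the nvcc version; a sketch of the presumed scheme (an assumption, not a quote of setup.py):

```python
# Presumed encoding: CUDA major.minor -> major * 100 + minor,
# so 12.8 -> 1208 and 13.0 -> 1300. Illustrative only.
def encode_cuda_version(version: str) -> int:
    major, minor = (int(part) for part in version.split(".")[:2])
    return major * 100 + minor

assert encode_cuda_version("12.3") == 1203
assert encode_cuda_version("12.8") == 1208
assert encode_cuda_version("13.0") == 1300
```

The new early return also means FlashAttention-3 is skipped entirely on Windows with CUDA >= 13.0 (the "avoid compile fa3 windows with cu130" commit), and the device-capability fallback now accepts Ampere (8, 0) as well as Hopper (9, 0) when TORCH_CUDA_ARCH_LIST is unset.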
4 changes: 1 addition & 3 deletions tests/readme_test_on_rocm.txt
@@ -3,11 +3,9 @@

2. verify testing for generic fmha inference on ROCM

#> pytest tests/test_mem_eff_attention.py::test_forward

3. verify testing for decoder fmha inference on ROCM

#> pytest tests/test_mem_eff_attention.py::test_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_decoder


2 changes: 1 addition & 1 deletion third_party/cutlass
Submodule cutlass updated 1271 files
2 changes: 1 addition & 1 deletion third_party/flash-attention
Submodule flash-attention updated 51 files
+33 −0 .github/workflows/pre-commit.yaml
+1 −0 .gitignore
+27 −0 .pre-commit-config.yaml
+9 −8 benchmarks/benchmark_attn.py
+1 −1 csrc/cutlass
+5 −1 csrc/flash_attn_ck/mha_bwd.cpp
+4 −0 csrc/flash_attn_ck/mha_fwd.cpp
+4 −0 csrc/flash_attn_ck/mha_varlen_fwd.cpp
+4 −0 flash_attn/cute/.flake8
+9 −1 flash_attn/cute/__init__.py
+11 −6 flash_attn/cute/ampere_helpers.py
+71 −0 flash_attn/cute/barrier.py
+268 −0 flash_attn/cute/benchmark.py
+688 −0 flash_attn/cute/benchmark_mask_mod.py
+331 −185 flash_attn/cute/blackwell_helpers.py
+38 −41 flash_attn/cute/block_info.py
+592 −0 flash_attn/cute/block_sparsity.py
+340 −0 flash_attn/cute/copy_utils.py
+124 −0 flash_attn/cute/cute_dsl_utils.py
+424 −301 flash_attn/cute/flash_bwd.py
+534 −314 flash_attn/cute/flash_bwd_postprocess.py
+211 −137 flash_attn/cute/flash_bwd_preprocess.py
+2,363 −0 flash_attn/cute/flash_bwd_sm100.py
+825 −975 flash_attn/cute/flash_bwd_sm90.py
+1,247 −502 flash_attn/cute/flash_fwd.py
+668 −281 flash_attn/cute/flash_fwd_sm100.py
+68 −20 flash_attn/cute/hopper_helpers.py
+694 −108 flash_attn/cute/interface.py
+325 −175 flash_attn/cute/mask.py
+285 −0 flash_attn/cute/mask_definitions.py
+4 −2 flash_attn/cute/mma_sm100_desc.py
+10 −4 flash_attn/cute/named_barrier.py
+217 −23 flash_attn/cute/pipeline.py
+1 −1 flash_attn/cute/pyproject.toml
+32 −9 flash_attn/cute/seqlen_info.py
+213 −65 flash_attn/cute/softmax.py
+115 −3 flash_attn/cute/tile_scheduler.py
+146 −34 flash_attn/cute/utils.py
+16 −6 flash_attn/flash_attn_interface.py
+1 −1 flash_attn/flash_attn_triton_amd/bwd_prefill_split.py
+3 −2 flash_attn/flash_attn_triton_amd/fwd_prefill.py
+8 −14 flash_attn/flash_attn_triton_amd/interface_fa.py
+4 −5 flash_attn/flash_attn_triton_amd/utils.py
+16 −4 hopper/flash_api_stable.cpp
+1 −1 hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp
+8 −1 hopper/setup.py
+20 −2 setup.py
+424 −117 tests/cute/test_flash_attn.py
+298 −0 tests/cute/test_flash_attn_varlen.py
+467 −0 tests/cute/test_mask_mod.py
+499 −0 tests/cute/test_score_mod.py
3 changes: 1 addition & 2 deletions xformers/benchmarks/readme_benchmark_on_rocm.txt
@@ -8,10 +8,9 @@

3. Benchmark for decoder fmha inference on ROCM

#> python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py

4. Other Benchmarks for fmha inference on ROCM

#> python xformers/benchmarks/benchmark_attn_decoding.py
#> python xformers/benchmarks/benchmark_mem_eff_attention_mqa.py

24 changes: 12 additions & 12 deletions xformers/csrc/attention/hip_decoder/CMakeLists.txt
@@ -36,14 +36,14 @@ set_target_properties(${exe_name} ${splitk_exe_name} PROPERTIES LINKER_LANGUAGE
set_target_properties(${exe_name} ${splitk_exe_name} PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(${exe_name} ${splitk_exe_name} PROPERTIES HIP_ARCHITECTURES ${GPU_TARGETS})

target_compile_options(${exe_name} PUBLIC
-fno-gpu-rdc
$<$<CONFIG:Debug>:
--save-temps
>
)

target_compile_options(${splitk_exe_name} PUBLIC
-fno-gpu-rdc
$<$<CONFIG:Debug>:
--save-temps
@@ -52,13 +52,13 @@ target_compile_options(${splitk_exe_name} PUBLIC
>
)

target_include_directories(${exe_name} PUBLIC
${ck_include} # ck includes
${torch_include} # aten includes
${torch_include}/torch/csrc/api/include # torch includes
)

target_include_directories(${splitk_exe_name} PUBLIC
${ck_include} # ck includes
${torch_include} # aten includes
${torch_include}/torch/csrc/api/include # torch includes
@@ -93,28 +93,28 @@ target_link_libraries(${splitk_exe_name} PUBLIC
amdhip64
)

target_compile_definitions(${exe_name} PUBLIC
ATTN_FWD_DECODER_MAIN=1
GLIBCXX_USE_CXX11_ABI=1
__HIP_PLATFORM_HCC__=1
USE_ROCM=1
)

target_compile_definitions(${splitk_exe_name} PUBLIC
ATTN_FWD_SPLITK_DECODER_MAIN=1
GLIBCXX_USE_CXX11_ABI=1
__HIP_PLATFORM_HCC__=1
USE_ROCM=1
)

include(CMakePrintHelpers)
cmake_print_properties(TARGETS ${exe_name} ${splitk_exe_name} PROPERTIES
LINK_LIBRARIES
LINK_DIRECTORIES
INCLUDE_DIRECTORIES
COMPILE_DEFINITIONS
COMPILE_OPTIONS
SOURCES
HIP_ARCHITECTURES)

rocm_install(TARGETS ${exe_name} ${splitk_exe_name})
18 changes: 8 additions & 10 deletions xformers/csrc/attention/hip_fmha/GENERATE_INSTANCES.md
@@ -1,16 +1,16 @@

# Instances generator

The instances generator is a simple python tool used to generate several hundred of instances (.cpp files) and their references (.h files).
Without this tool, manually writing those instances and references will be very laborious and easy to get wrong.

The instances generated by this scripts are divided into three categories visible from the scripts:
* Infer -- which refers to instances for calling inference-only kernels
* Forward -- which refers to instances for calling training forward kernels
* Backward -- which refers to instances for calling training backward kernels
The instance generator is for being used by the HIP fmha developers themselves. It is not supposed to be used by the xformers users for
building xformers, since for xformers users, the instances are already well prepared as part of the xformers codes.

## how to use instance generator

@@ -21,13 +21,11 @@
```
* To generate reduced instances (when headdim256 is not required)

```
#> python xformers/csrc/attention/hip_fmha/generate_instances.py --ignore-hd256
```
* More options except for `--ignore-hd256` could be added to suppport further customization in generating instances as required

## where the instances files are located
The instances files and references files are always located under a folder `instances/` that is located under the same directory
as the file `generate_instances.py` itself

