diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 27c26f7a8..9a23efc25 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -214,7 +214,8 @@ list(APPEND transformer_engine_cuda_sources
     fused_router/fused_topk_with_score_function.cu
     recipe/current_scaling.cu
     recipe/delayed_scaling.cu
-    recipe/fp8_block_scaling.cu)
+    recipe/fp8_block_scaling.cu
+    recipe/nvfp4.cu)
 
 list(APPEND transformer_engine_cuda_arch_specific_sources
     cast/cast.cu
@@ -238,8 +239,7 @@ if(USE_CUDA)
     fused_attn/fused_attn_fp8.cu
     fused_attn/utils.cu
     swizzle/swizzle.cu
-    swizzle/swizzle_block_scaling.cu
-    recipe/nvfp4.cu)
+    swizzle/swizzle_block_scaling.cu)
   list(APPEND transformer_engine_cuda_arch_specific_sources
     gemm/cutlass_grouped_gemm.cu
     transpose/quantize_transpose_square_blockwise.cu)
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index 37766f5ce..8eded3622 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -87,7 +87,7 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 @functools.lru_cache(maxsize=None)
 def check_nvfp4_support() -> Tuple[bool, str]:
     if IS_HIP_EXTENSION:
-        return False, "ROCm TE currently not supporting NVFP4"
+        return True, ""
     """Return if nvfp4 support is available"""
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""