From 6818c250654cb75ff2b715a7e611776261a8ae2f Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 11:57:27 +0000 Subject: [PATCH 01/13] Add gfx950 build support + fp16 fix + index type fix --- fbgemm_gpu/cmake/Hip.cmake | 8 ++++++++ .../embedding_backward_split_template.cu | 2 +- ..._backward_split_device_kernel_template.hip | 2 +- .../include/fbgemm_gpu/rocm/cdna_guard.h | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 20 ++++++++++++++++++- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 2 +- 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 17640b7254..2011a34c33 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,6 +78,14 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) + list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -ggdb) + + # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) + #list(APPEND HIP_CXX_FLAGS -adhln) + list(APPEND HIP_CXX_FLAGS -save-temps) + list(APPEND HIP_CXX_FLAGS -fverbose-asm) + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 186a9d529f..7bc4427355 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1195,7 +1195,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..5acc61382e 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..c96da01063 
100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -215,6 +215,24 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } + +}; + + template < typename emb_t, int32_t embedding_dim, @@ -471,7 +489,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index 35d2d87fa5..c1812123ec 100644 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - std::execution::par, + // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From 9355c411dd91eef9807000130a472cb406deae09 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 13:16:41 +0000 Subject: [PATCH 02/13] Change int64_t to index_t as template parameters in load_raw_per_warp --- .../rocm/embedding_backward_split_device_kernel_template.hip | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 5acc61382e..d5841d6e00 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -452,7 +452,7 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); From 97610e35ea235509c0d60b97a62536ad12a2d06b Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 14:39:22 +0000 Subject: [PATCH 03/13] Implement llvm fp16 buffer load for gfx950 --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c96da01063..4b33fd1422 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t 
srsrc, @@ -72,7 +77,12 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, From f1cef5b1d22d593122848b71e85c657d123b99ad Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:23:47 +0000 Subject: [PATCH 04/13] Fix c-style half to float cast --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 4b33fd1422..238a83440a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -261,7 +261,14 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } From 9576ab295169e5667a315f557c1e7bb8a782a3a8 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:24:29 +0000 Subject: [PATCH 05/13] Patch 256 half stores --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 238a83440a..974eae2594 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -294,6 +294,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + } +}; + + template <> struct store_row_per_warp { static __device__ void run(float* acc, float* p_output, int lane_id) { From 2dc021b3cec13e7daa4dcde81ade9b599791b242 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Fri, 8 Aug 2025 05:02:58 +0000 Subject: [PATCH 06/13] cta_per_row workgroup optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7bc4427355..6e6796f0ef 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1044,7 +1044,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1055,7 +1055,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t 
cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, (kMaxThreads/2)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From 55cb57e3083a59b3cc10eef88f3dde1f2f7df987 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 11 Aug 2025 21:06:48 +0000 Subject: [PATCH 07/13] Added mi350 guards --- ...ding_backward_split_indice_weights_template.cu | 15 ++++++++++++++- .../backward/embedding_backward_split_template.cu | 10 ++++++++++ .../forward/embedding_forward_split_template.cu | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 1afb2943bb..8f190d04d2 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); - + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 6e6796f0ef..a09f7d3886 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 6574bda45e..bbd62a8bbc --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -454,6 +458,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} From c1f444e676942c4e09e32d39479dddccb3024869 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 15:09:39 +0000 Subject: [PATCH 08/13] Fix index overflow in row load --- ..._backward_split_device_kernel_template.hip | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d5841d6e00..d1a874805a 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -290,7 +290,7 @@ 
__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -420,7 +420,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -441,7 +441,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); From 6088a661cababd9c48244f027eb76cebd6541af6 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 20:13:09 
+0000 Subject: [PATCH 09/13] cta_per_row workgroup reduce by 4 optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index a09f7d3886..fd97aea59f 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1054,7 +1054,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1065,7 +1065,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/2)), + div_round_up(total_unique_indices, (kMaxThreads/4)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From 659f76636537449f61d6bffffc628855124f0e73 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 13 Aug 2025 13:21:38 +0000 Subject: [PATCH 10/13] Fix mixed_D frontend to backend connection --- .../training/backward/embedding_backward_split_template.cu | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 1 + .../split_table_batched_embeddings_ops_training.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fd97aea59f..4732a7a0ec 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1205,7 +1205,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 3720f1ea42..20c055e917 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -698,6 +698,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} // Default values for Dynamo tracing diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index fe8fad0af1..d69d685136 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ 
b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -808,7 +808,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2262,6 +2262,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2280,6 +2281,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2289,6 +2291,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) From e18ac1d210e41ad5c7fe128a3f9ec4c595f20984 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 20:39:49 +0000 Subject: [PATCH 11/13] warpReduction DPP version --- fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 0d65c4798a..841b121018 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,7 @@ #include #endif #include - +#include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace { inline int get_device_sm_cnt_() { @@ -138,11 +138,11 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; +return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); } DEVICE_INLINE void syncwarp() { From 781d039ae274d2537ed326f115d4c945a90a205c Mon Sep 17 00:00:00 2001 From: zhimding Date: Tue, 19 Aug 2025 06:25:00 +0000 Subject: [PATCH 12/13] apply unroll on vbe and not weighted kernel --- ..._backward_split_device_kernel_template.cuh | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..753c4ff0f4 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -141,6 +141,37 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( ? 
sorted_indice_weights[segment_start + sl_j] : 0.0; {%- endif %} + const int32_t d = threadIdx.x * VEC_WIDTH; + {%- if not weighted and vbe %} + for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) { + const auto grad_offset_j0 = SHFL_SYNC(grad_offset, j); + const auto grad_offset_j1 = SHFL_SYNC(grad_offset, j + 1); + const auto grad_offset_j2 = SHFL_SYNC(grad_offset, j + 2); + const auto grad_offset_j3 = SHFL_SYNC(grad_offset, j + 3); + const auto grad_offset_j4 = SHFL_SYNC(grad_offset, j + 4); + const auto grad_offset_j5 = SHFL_SYNC(grad_offset, j + 5); + const auto grad_offset_j6 = SHFL_SYNC(grad_offset, j + 6); + const auto grad_offset_j7 = SHFL_SYNC(grad_offset, j + 7); + if (threadIdx.x * VEC_WIDTH < D) { + Vec4TAcc grad_out_vec0 = Vec4TAcc(&grad_output[0][grad_offset_j0 + d]); + Vec4TAcc grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc(); + grad_sum[0].add_(grad_out_vec0); + grad_sum[0].add_(grad_out_vec1); + grad_sum[0].add_(grad_out_vec2); + grad_sum[0].add_(grad_out_vec3); + grad_sum[0].add_(grad_out_vec4); + grad_sum[0].add_(grad_out_vec5); + grad_sum[0].add_(grad_out_vec6); + grad_sum[0].add_(grad_out_vec7); + } + } + {%- else %} for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) { {%- if nobag %} int32_t l_j = SHFL_SYNC(l, j); @@ -180,6 +211,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} } } + {%- endif %} } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} From 436aacda583f80f62e3deccb84dfb7c8850f2dc2 Mon Sep 17 00:00:00 2001 From: zhimding Date: Tue, 19 Aug 2025 15:29:03 +0000 Subject: [PATCH 13/13] add vec for loop --- ..._backward_split_device_kernel_template.cuh | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 753c4ff0f4..6c2f5b0575 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -144,6 +144,9 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( const int32_t d = threadIdx.x * VEC_WIDTH; {%- if not weighted and vbe %} for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) { + + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + const auto grad_offset_j0 = SHFL_SYNC(grad_offset, j); const auto grad_offset_j1 = SHFL_SYNC(grad_offset, j + 1); const auto grad_offset_j2 = SHFL_SYNC(grad_offset, j + 2); @@ -152,24 +155,30 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( const auto grad_offset_j5 = SHFL_SYNC(grad_offset, j + 5); const auto grad_offset_j6 = SHFL_SYNC(grad_offset, j + 6); const auto 
grad_offset_j7 = SHFL_SYNC(grad_offset, j + 7); - if (threadIdx.x * VEC_WIDTH < D) { - Vec4TAcc grad_out_vec0 = Vec4TAcc(&grad_output[0][grad_offset_j0 + d]); - Vec4TAcc grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc(); - Vec4TAcc grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc(); - Vec4TAcc grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc(); - Vec4TAcc grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc(); - Vec4TAcc grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc(); - Vec4TAcc grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc(); - Vec4TAcc grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc(); - grad_sum[0].add_(grad_out_vec0); - grad_sum[0].add_(grad_out_vec1); - grad_sum[0].add_(grad_out_vec2); - grad_sum[0].add_(grad_out_vec3); - grad_sum[0].add_(grad_out_vec4); - grad_sum[0].add_(grad_out_vec5); - grad_sum[0].add_(grad_out_vec6); - grad_sum[0].add_(grad_out_vec7); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { + const int32_t d = {{ d }}; + if (threadIdx.x * VEC_WIDTH < D) { + Vec4TAcc grad_out_vec0 = Vec4TAcc(&grad_output[0][grad_offset_j0 + d]); + Vec4TAcc grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc(); + grad_sum[vec].add_(grad_out_vec0); + grad_sum[vec].add_(grad_out_vec1); + grad_sum[vec].add_(grad_out_vec2); + grad_sum[vec].add_(grad_out_vec3); + grad_sum[vec].add_(grad_out_vec4); + grad_sum[vec].add_(grad_out_vec5); + grad_sum[vec].add_(grad_out_vec6); + grad_sum[vec].add_(grad_out_vec7); + } } + } {%- else %} for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) {
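Patch 01 ("fp16 fix") and patch 05 ("Patch 256 half stores") both move c10::Half rows two elements at a time: the dim-256 store writes one packed pair at out[lane_id] and a second at out[lane_id + 64], so each of the 64 lanes owns four halves of the row. The stand-alone sketch below only illustrates that layout; the pointer names are hypothetical, __half stands in for c10::Half, and the real code goes through the raw-buffer intrinsics rather than plain loads and stores (the template arguments of the specializations are stripped in this excerpt).

#include <hip/hip_fp16.h>  // __half, half2

// Layout sketch for a 256-wide half row split across a 64-lane wavefront:
// viewed as 128 half2 values, lane L touches slots L and L + 64.
__device__ void copy_row_dim256_half(const __half* src_row,
                                     __half* dst_row,
                                     int lane_id) {
  auto src = reinterpret_cast<const half2*>(src_row);
  auto dst = reinterpret_cast<half2*>(dst_row);
  dst[lane_id]      = src[lane_id];       // row elements 0..127, two per lane
  dst[lane_id + 64] = src[lane_id + 64];  // row elements 128..255, two per lane
}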
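Patches 06 and 09 shrink the cta_per_row launch in two steps, first to kMaxThreads/2 and then to kMaxThreads/4 worth of work per workgroup. A quick check of the net effect, assuming the usual fbgemm_gpu values of kMaxThreads = 1024 and kWarpSize = 64 on CDNA (neither constant appears in this excerpt):

// Host-side arithmetic only; both constants below are assumptions.
constexpr int kMaxThreads = 1024;  // assumed fbgemm_gpu default
constexpr int kWarpSize = 64;      // assumed CDNA wavefront size

constexpr int groups_before = kMaxThreads / kWarpSize;       // 16 warps per workgroup
constexpr int groups_after = (kMaxThreads / 4) / kWarpSize;  // 4 warps per workgroup

// cta_per_row_grid_size now divides total_unique_indices by kMaxThreads/4 = 256
// instead of 1024, so roughly 4x more (smaller) blocks are launched.
static_assert(groups_before == 16 && groups_after == 4,
              "net effect of patches 06 and 09 under the assumed constants");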
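Patch 11 replaces the shuffle-based loop in warpReduceAllSum with the DPP-backed rocm::wave_reduce declared in rocm/split_embeddings_common.h. The sketch below restates the removed loop with comments and then the call that replaces it; shfl_xor is a stand-in, with an assumed signature, for the fbgemm_gpu helper of the same name.

#include <hip/hip_runtime.h>

// Stand-in for the fbgemm_gpu shfl_xor helper (signature assumed); on ROCm the
// sync mask is not needed, so it is accepted and ignored here.
template <typename T>
__device__ T shfl_xor(T val, int lane_mask, int width, unsigned /*sync_mask*/) {
  return __shfl_xor(val, lane_mask, width);
}

// The butterfly reduction that patch 11 removes: each step adds the value held
// by the lane whose id differs in exactly one bit, so after log2(ReduceWidth)
// steps every lane holds the sum over the whole group.
template <typename T, int ReduceWidth = 64>
__device__ T warp_sum_butterfly(T val, unsigned sync_mask = ~0u) {
#pragma unroll
  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
    val += shfl_xor(val, mask, ReduceWidth, sync_mask);
  }
  return val;
}

// Patch 11 collapses the loop above into a single DPP-based call with the same
// all-lane result:
//   return rocm::wave_reduce<rocm::reduce_op::sum, T, ReduceWidth>(val);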