diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake
index 17640b7254..2011a34c33 100644
--- a/fbgemm_gpu/cmake/Hip.cmake
+++ b/fbgemm_gpu/cmake/Hip.cmake
@@ -78,6 +78,14 @@ if(HIP_FOUND)
   list(APPEND HIP_CXX_FLAGS -mf16c)
   list(APPEND HIP_CXX_FLAGS -mfma)
   list(APPEND HIP_CXX_FLAGS -std=c++20)
+  list(APPEND HIP_CXX_FLAGS -g)
+  list(APPEND HIP_CXX_FLAGS -ggdb)
+
+  # list(APPEND HIP_CXX_FLAGS -Wa,-adhln)
+  #list(APPEND HIP_CXX_FLAGS -adhln)
+  list(APPEND HIP_CXX_FLAGS -save-temps)
+  list(APPEND HIP_CXX_FLAGS -fverbose-asm)
+
   set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS})

   # Ask hcc to generate device code during compilation so we can use
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh
index b9db6e47f8..6c2f5b0575 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh
@@ -141,6 +141,46 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
         ? sorted_indice_weights[segment_start + sl_j]
         : 0.0;
     {%- endif %}
+    const int32_t d = threadIdx.x * VEC_WIDTH;
+    {%- if not weighted and vbe %}
+    for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) {
+
+      {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
+
+      const auto grad_offset_j0 = SHFL_SYNC(grad_offset, j);
+      const auto grad_offset_j1 = SHFL_SYNC(grad_offset, j + 1);
+      const auto grad_offset_j2 = SHFL_SYNC(grad_offset, j + 2);
+      const auto grad_offset_j3 = SHFL_SYNC(grad_offset, j + 3);
+      const auto grad_offset_j4 = SHFL_SYNC(grad_offset, j + 4);
+      const auto grad_offset_j5 = SHFL_SYNC(grad_offset, j + 5);
+      const auto grad_offset_j6 = SHFL_SYNC(grad_offset, j + 6);
+      const auto grad_offset_j7 = SHFL_SYNC(grad_offset, j + 7);
+
+      #pragma unroll kFixedMaxVecsPerThread
+      for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
+        const int32_t d = {{ d }};
+        if (threadIdx.x * VEC_WIDTH < D) {
+          Vec4TAcc grad_out_vec0 = Vec4TAcc(&grad_output[0][grad_offset_j0 + d]);
+          Vec4TAcc grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc();
+          Vec4TAcc grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc();
+          Vec4TAcc grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc();
+          Vec4TAcc grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc();
+          Vec4TAcc grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc();
+          Vec4TAcc grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc();
+          Vec4TAcc grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc();
+          grad_sum[vec].add_(grad_out_vec0);
+          grad_sum[vec].add_(grad_out_vec1);
+          grad_sum[vec].add_(grad_out_vec2);
+          grad_sum[vec].add_(grad_out_vec3);
+          grad_sum[vec].add_(grad_out_vec4);
+          grad_sum[vec].add_(grad_out_vec5);
+          grad_sum[vec].add_(grad_out_vec6);
+          grad_sum[vec].add_(grad_out_vec7);
+        }
+      }
+
+    }
+    {%- else %}
     for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) {
       {%- if nobag %}
       int32_t l_j = SHFL_SYNC(l, j);
@@ -180,6 +220,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
       {%- endif %}
     }
   }
+  {%- endif %}
 }

 {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
old mode 100644
new mode 100755
index 1afb2943bb..8f190d04d2
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
@@ -23,6 +23,10 @@
 #include "fbgemm_gpu/utils/assert_macros.h"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"

+{%- if is_rocm %}
+#include "fbgemm_gpu/rocm/cdna_guard.h"
+{%- endif %}
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;

@@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
   auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

   CUDA_DEVICE_GUARD(dev_weights);
-
+  #ifdef USE_ROCM
+  if (!rocm::is_supported_cdna()) {
+    TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
+  }
+  else {
+    // Ensure we're running on a supported CDNA architecture (including MI350)
+    TORCH_WARN_ONCE("Running on CDNA architecture");
+  }
+  #endif
+
   const auto T = D_offsets.size(0) - 1;
   TORCH_CHECK_GT(T, 0);
   // offsets = [B x T + 1]
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
old mode 100644
new mode 100755
index 186a9d529f..4732a7a0ec
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}(

   CUDA_DEVICE_GUARD(dev_weights);

+  #ifdef USE_ROCM
+  if (!rocm::is_supported_cdna()) {
+    TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
+  }
+  else {
+    // Ensure we're running on a supported CDNA architecture (including MI350)
+    TORCH_WARN_ONCE("Running on CDNA architecture");
+  }
+  #endif
+
   {%- if nobag and not is_index_select %}
   auto max_D = D;
   {%- endif %}
@@ -1044,7 +1054,7 @@ Tensor {{ embedding_cuda_op }}(
     // Compute shared memory size for cta_per_row
     constexpr auto kCacheAccBytes = sizeof(at::acc_type);
-    int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
+    int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize;
     const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes(
         &num_cta_per_row_groups,
         [&] (int num_groups) {
@@ -1055,7 +1065,7 @@ Tensor {{ embedding_cuda_op }}(
     );

     const int32_t cta_per_row_grid_size = std::min(
-        div_round_up(total_unique_indices, kMaxThreads),
+        div_round_up(total_unique_indices, (kMaxThreads/4)),
         get_max_thread_blocks_());

     FBGEMM_LAUNCH_KERNEL(
@@ -1195,7 +1205,7 @@ Tensor {{ embedding_cuda_op }}(
     const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half ||
         dev_weights.scalar_type() == at::ScalarType::Float;

-    if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna())
+    if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna())
     {
       constexpr int segments_per_workgroup = 4;
       {%- for kDimSize in [64, 128, 160, 192, 256] %}
diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip
index 2fcbba395e..d1a874805a 100644
--- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip
+++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip
@@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
     const int32_t segment_length_mod = segment_length & length_mask;

     cache_t grad_acc[dword_per_row];
-    int32_t infos[segment_unroll];
+    uint32_t infos[segment_unroll];
     grad_t grad_data[dword_per_row * segment_prefetch];
     emb_t emb_data[dword_per_row];
     float indice_weights[segment_unroll];
@@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
         table_index = infos[0] >> info_B_num_bits;
         bag_index = infos[0] & info_B_mask;
         {%- endif %}
-        load_row_per_warp::run(
+        load_row_per_warp::run(
            &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

        {%- if nobag %}
@@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
        table_index = infos[1] >> info_B_num_bits;
        bag_index = infos[1] & info_B_mask;
        {%- endif %}
-       load_row_per_warp::run(
+       load_row_per_warp::run(
           &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
       if constexpr (!weighted){
           #pragma unroll
@@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j] >> info_B_num_bits;
               bag_index = infos[j] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

               accumulate_row_per_warp::run(
@@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j + 1] >> info_B_num_bits;
               bag_index = infos[j + 1] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
           }

@@ -290,7 +290,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j] >> info_B_num_bits;
               bag_index = infos[j] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

               accumulate_row_per_warp::run(
@@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j + 1] >> info_B_num_bits;
               bag_index = infos[j + 1] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
           }

@@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
        table_index = infos[0] >> info_B_num_bits;
        bag_index = infos[0] & info_B_mask;
        {%- endif %}
-       load_row_per_warp::run(
+       load_row_per_warp::run(
           &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

       {%- if nobag %}
@@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
        table_index = infos[1] >> info_B_num_bits;
        bag_index = infos[1] & info_B_mask;
        {%- endif %}
-       load_row_per_warp::run(
+       load_row_per_warp::run(
           &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

       if constexpr (!weighted) {
@@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j] >> info_B_num_bits;
               bag_index = infos[j] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

               accumulate_row_per_warp::run(
@@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j + 1] >> info_B_num_bits;
               bag_index = infos[j + 1] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
           }

@@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j] >> info_B_num_bits;
               bag_index = infos[j] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

               accumulate_row_per_warp::run(
@@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
               table_index = infos[j + 1] >> info_B_num_bits;
               bag_index = infos[j + 1] & info_B_mask;
               {%- endif %}
-               load_row_per_warp::run(
+               load_row_per_warp::run(
                   &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
           }

@@ -420,7 +420,7 @@ L_tail_grad_acc:
        table_index = infos[0] >> info_B_num_bits;
        bag_index = infos[0] & info_B_mask;
        {%- endif %}
-       load_row_per_warp::run(
+       load_row_per_warp::run(
           &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
       accumulate_row_per_warp::run(
           &grad_acc[0], &grad_data[0], lane_id);
@@ -441,7 +441,7 @@ L_tail_grad_acc:
        table_index = infos[0] >> info_B_num_bits;
        bag_index = infos[0] & info_B_mask;
        {%- endif %}
-       load_row_per_warp::run(
+       load_row_per_warp::run(
           &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
       accumulate_row_per_warp::run(
           &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]);
@@ -452,7 +452,7 @@ L_tail_grad_acc:
     }

     // load the old emb weight data
-    load_row_per_warp::run(
+    load_row_per_warp::run(
         &emb_data[0], emb_idx, p_emb_table, lane_id);
     optimizer_t optimizer(opt_karg);
     optimizer.template update(grad_acc, emb_data, emb_idx);
diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
old mode 100644
new mode 100755
index 6574bda45e..bbd62a8bbc
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
@@ -31,6 +31,10 @@
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 {%- endif %}

+{%- if is_rocm %}
+#include "fbgemm_gpu/rocm/cdna_guard.h"
+{%- endif %}
+
 {%- if not is_index_select %}
 ////////////////////////////////////////////////////////////////////////////////
 // Required for op registrations
@@ -454,6 +458,16 @@ batch_index_select_dim0_codegen_forward_cuda(

   CUDA_DEVICE_GUARD(dev_weights);

+  #ifdef USE_ROCM
+  if (!rocm::is_supported_cdna()) {
+    TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
+  }
+  else {
+    // Ensure we're running on a supported CDNA architecture (including MI350)
+    TORCH_WARN_ONCE("Running on CDNA architecture");
+  }
+  #endif
+
   {%- if not nobag %}
   int32_t T = D_offsets.numel() - 1;
   {%- else %}
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index 3720f1ea42..20c055e917 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -698,6 +698,7 @@ class {{ autograd_func }} :
     TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value.");
     const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
     const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
+    const auto mixed_D = aux_bool[IDX_MIXED_D];
     {%- endif %}

     // Default values for Dynamo tracing
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index fe8fad0af1..d69d685136 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -808,7 +808,7 @@ def __init__( # noqa C901
             assert (
                 self.pooling_mode != PoolingMode.NONE
             ), "Mixed dimension tables only supported for pooling tables."
-
+        self.mixed_D = mixed_D
         assert all(
             cd == compute_devices[0] for cd in compute_devices
         ), "Heterogenous compute_devices are NOT supported!"
@@ -2262,6 +2262,7 @@ def forward( # noqa: C901
                         row_counter,
                         iter_int,
                         self.max_counter.item(),
+                        mixed_D=self.mixed_D,
                     ),
                 )
             elif self._used_rowwise_adagrad_with_global_weight_decay:
@@ -2280,6 +2281,7 @@ def forward( # noqa: C901
                         # `Optional[Tensor]` but got `Union[Module, Tensor]`.
                         prev_iter_dev=self.prev_iter_dev,
                         gwd_lower_bound=self.gwd_lower_bound,
+                        mixed_D=self.mixed_D,
                     ),
                 )
             else:
@@ -2289,6 +2291,7 @@ def forward( # noqa: C901
                         common_args,
                         self.optimizer_args,
                         momentum1,
+                        mixed_D=self.mixed_D,
                     ),
                 )
diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h
index b55fd72fce..447613c5fc 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h
@@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm {

 [[nodiscard]] inline bool is_supported_cdna() {
-  const std::set supported_archs{"gfx942", "gfx90a"};
+  const std::set supported_archs{"gfx942", "gfx90a", "gfx950"};
   int device_id = 0;
   HIP_CHECK(hipGetDevice(&device_id));
   hipDeviceProp_t dev_props;
diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h
index b3a56c4b52..974eae2594 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h
@@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16(
     int32x4_t srsrc,
     int32_t voffset,
     int32_t soffset,
-    int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
+    int32_t glc_slc)
+#if defined(__gfx950__)
+    __asm("llvm.amdgcn.raw.buffer.load.i16");
+#else
+    __asm("llvm.amdgcn.raw.buffer.load.f16");
+#endif

 __device__ float llvm_amdgcn_raw_buffer_load_fp32(
     int32x4_t srsrc,
@@ -72,7 +77,12 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2(
     int32x4_t srsrc,
     int32_t voffset,
     int32_t soffset,
-    int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
+    int32_t glc_slc)
+#if defined(__gfx950__)
+    __asm("llvm.amdgcn.raw.buffer.load.i32");
+#else
+    __asm("llvm.amdgcn.raw.buffer.load.v2f16");
+#endif

 __device__ void llvm_amdgcn_raw_buffer_store_fp32(
     float vdata,
@@ -215,6 +225,24 @@ struct load_row_per_warp {
   }
 };

+template
+struct load_row_per_warp {
+  static __device__ void run(
+      c10::Half* emb_data,
+      index_t row_index,
+      const c10::Half* p_emb_table,
+      int lane_id) {
+    load_row_per_warp::run(
+        reinterpret_cast(emb_data),
+        row_index,
+        reinterpret_cast(p_emb_table),
+        lane_id
+    );
+  }
+
+};
+
+
 template <
     typename emb_t,
     int32_t embedding_dim,
@@ -233,7 +261,14 @@ struct accumulate_row_per_warp {
     } else {
 #pragma unroll
       for (int i = 0; i < dword_per_row; i++) {
-        acc[i] += static_cast((float)emb_data[i] * row_weight);
+        if constexpr (std::is_same_v)
+        {
+          acc[i] += static_cast(__half2float(emb_data[i]) * row_weight);
+        }
+        else
+        {
+          acc[i] += static_cast(static_cast(emb_data[i]) * row_weight);
+        }
       }
     }
   }
@@ -259,6 +294,16 @@ struct store_row_per_warp {
   }
 };

+template <>
+struct store_row_per_warp {
+  static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) {
+    auto out = reinterpret_cast(p_output);
+    out[lane_id] = *reinterpret_cast(acc);
+    out[lane_id + 64] = *reinterpret_cast(&acc[2]);
+  }
+};
+
+
 template <>
 struct store_row_per_warp {
   static __device__ void run(float* acc, float* p_output, int lane_id) {
@@ -471,7 +516,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) {
 // of trivial operation with an option to use custom operation
 template
 __device__ __forceinline__ void dpp_reduction(data_t& result) {
-#if defined(__gfx942__) || defined(__gfx90a__)
+#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__)
   if constexpr (std::is_same_v) {
     DPP_REDUCE_F16_F32(add);
     return;
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
index 0d65c4798a..841b121018 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
@@ -21,7 +21,7 @@
 #include
 #endif
 #include
-
+#include "fbgemm_gpu/rocm/split_embeddings_common.h"
 namespace {

 inline int get_device_sm_cnt_() {
@@ -138,11 +138,11 @@ template
 DEVICE_INLINE T warpReduceAllSum(
     T val,
     unsigned shfl_sync_mask = static_cast(kFullWarpMask)) {
-#pragma unroll
-  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
-    val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask);
-  }
-  return val;
+return rocm::wave_reduce<
+    rocm::reduce_op::sum, // Sum reduction
+    T,                    // Data type
+    ReduceWidth           // Wave/Warp size
+    >(val);
 }

 DEVICE_INLINE void syncwarp() {
diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp
index 35d2d87fa5..c1812123ec 100644
--- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp
+++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp
@@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() {
   // Now sort the indices by their tags. Use parallel sort for some extra speed
   // (vector is very large).
   std::sort(
-      std::execution::par,
+      // std::execution::par,
      std::begin(indicesWithTags),
      std::end(indicesWithTags),
      [](const std::pair& lhs,