8 changes: 8 additions & 0 deletions fbgemm_gpu/cmake/Hip.cmake
@@ -78,6 +78,14 @@ if(HIP_FOUND)
list(APPEND HIP_CXX_FLAGS -mf16c)
list(APPEND HIP_CXX_FLAGS -mfma)
list(APPEND HIP_CXX_FLAGS -std=c++20)
list(APPEND HIP_CXX_FLAGS -g)
list(APPEND HIP_CXX_FLAGS -ggdb)

# list(APPEND HIP_CXX_FLAGS -Wa,-adhln)
# list(APPEND HIP_CXX_FLAGS -adhln)
list(APPEND HIP_CXX_FLAGS -save-temps)
list(APPEND HIP_CXX_FLAGS -fverbose-asm)

set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS})
# Ask hcc to generate device code during compilation so we can use
@@ -141,6 +141,46 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
? sorted_indice_weights[segment_start + sl_j]
: 0.0;
{%- endif %}
const int32_t d = threadIdx.x * VEC_WIDTH;
{%- if not weighted and vbe %}
for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) {

{%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}

const auto grad_offset_j0 = SHFL_SYNC(grad_offset, j);
const auto grad_offset_j1 = SHFL_SYNC(grad_offset, j + 1);
const auto grad_offset_j2 = SHFL_SYNC(grad_offset, j + 2);
const auto grad_offset_j3 = SHFL_SYNC(grad_offset, j + 3);
const auto grad_offset_j4 = SHFL_SYNC(grad_offset, j + 4);
const auto grad_offset_j5 = SHFL_SYNC(grad_offset, j + 5);
const auto grad_offset_j6 = SHFL_SYNC(grad_offset, j + 6);
const auto grad_offset_j7 = SHFL_SYNC(grad_offset, j + 7);

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
const int32_t d = {{ d }};
if (threadIdx.x * VEC_WIDTH < D) {
Vec4TAcc<grad_t> grad_out_vec0 = Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j0 + d]);
Vec4TAcc<grad_t> grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc<grad_t>();
Vec4TAcc<grad_t> grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc<grad_t>();
Vec4TAcc<grad_t> grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc<grad_t>();
Vec4TAcc<grad_t> grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc<grad_t>();
Vec4TAcc<grad_t> grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc<grad_t>();
Vec4TAcc<grad_t> grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc<grad_t>();
Vec4TAcc<grad_t> grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc<grad_t>();
grad_sum[vec].add_(grad_out_vec0);
grad_sum[vec].add_(grad_out_vec1);
grad_sum[vec].add_(grad_out_vec2);
grad_sum[vec].add_(grad_out_vec3);
grad_sum[vec].add_(grad_out_vec4);
grad_sum[vec].add_(grad_out_vec5);
grad_sum[vec].add_(grad_out_vec6);
grad_sum[vec].add_(grad_out_vec7);
}
}

}
{%- else %}
for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) {
{%- if nobag %}
int32_t l_j = SHFL_SYNC(l, j);
@@ -180,6 +220,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
{%- endif %}
}
}
{%- endif %}
}
{%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}
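The new VBE path above trades the generic per-sample loop for an 8-way unrolled one: lane j of the thread group holds grad_offset for sample sl + j, SHFL_SYNC broadcasts each of the eight offsets to every lane, and all lanes then accumulate their slice of those eight rows. A minimal standalone CUDA sketch of that shuffle-broadcast pattern follows; the function name, the fixed 32-lane warp, and the flat float layout are illustrative assumptions, not the template's actual types (the kernel itself goes through SHFL_SYNC, Vec4TAcc, and a templated group size).

// Sketch only: each lane owns one sample's row offset; offsets are broadcast
// lane-by-lane with __shfl_sync and every lane accumulates a strided slice
// of eight rows per outer iteration.
__device__ void accumulate_rows_unrolled(
    const float* grad_output,
    int64_t grad_offset,   // this lane's row offset into grad_output
    int valid_rows,        // number of valid samples, at most one per lane
    int D,                 // embedding dimension
    float* grad_sum) {     // this lane's partial sums, ceil(D / 32) entries
  const unsigned full_mask = 0xffffffffu;
  const int lane = threadIdx.x & 31;
  for (int j = 0; j < valid_rows; j += 8) {
    #pragma unroll
    for (int k = 0; k < 8; ++k) {
      if (j + k >= valid_rows) {
        break;
      }
      // Broadcast lane (j + k)'s offset to the whole warp.
      const int64_t row = __shfl_sync(full_mask, grad_offset, j + k);
      // Each lane accumulates a strided slice of that row.
      for (int d = lane; d < D; d += 32) {
        grad_sum[d / 32] += grad_output[row + d];
      }
    }
  }
}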

15 changes: 14 additions & 1 deletion fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
100644 → 100755
@@ -23,6 +23,10 @@
#include "fbgemm_gpu/utils/assert_macros.h"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

{%- if is_rocm %}
#include "fbgemm_gpu/rocm/cdna_guard.h"
{%- endif %}

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

@@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

CUDA_DEVICE_GUARD(dev_weights);

#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
} else {
// Supported CDNA architecture detected (gfx90a/gfx942/gfx950, i.e. including MI350).
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

const auto T = D_offsets.size(0) - 1;
TORCH_CHECK_GT(T, 0);
// offsets = [B x T + 1]
16 changes: 13 additions & 3 deletions fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
100644 → 100755
@@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}(

CUDA_DEVICE_GUARD(dev_weights);

#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
} else {
// Supported CDNA architecture detected (gfx90a/gfx942/gfx950, i.e. including MI350).
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

{%- if nobag and not is_index_select %}
auto max_D = D;
{%- endif %}
@@ -1044,7 +1054,7 @@ Tensor {{ embedding_cuda_op }}(

// Compute shared memory size for cta_per_row
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize;
const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes(
&num_cta_per_row_groups,
[&] (int num_groups) {
@@ -1055,7 +1065,7 @@ Tensor {{ embedding_cuda_op }}(
);

const int32_t cta_per_row_grid_size = std::min(
div_round_up(total_unique_indices, kMaxThreads),
div_round_up(total_unique_indices, (kMaxThreads/4)),
get_max_thread_blocks_());
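For context on the /4 above: shrinking the per-block thread budget reduces both the number of warp-sized groups in each cta_per_row block (and with it the dynamic shared memory each block requests) and the work assigned to each block, so more, smaller blocks get launched. The rough numbers below assume kMaxThreads = 1024 and kWarpSize = 64, which are typical for CDNA; the actual constants are defined elsewhere in the codebase and may differ.

// Back-of-the-envelope only; not part of the template.
#include <cstdint>
#include <cstdio>

constexpr int32_t kMaxThreads = 1024; // assumption for illustration
constexpr int32_t kWarpSize = 64;     // assumption (CDNA wavefront size)

constexpr int32_t div_round_up(int32_t a, int32_t b) { return (a + b - 1) / b; }

int main() {
  const int32_t groups_before = kMaxThreads / kWarpSize;       // 16 groups per block
  const int32_t groups_after = (kMaxThreads / 4) / kWarpSize;  // 4 groups per block
  const int32_t total_unique_indices = 100000;                 // example workload
  const int32_t grid_before = div_round_up(total_unique_indices, kMaxThreads);     // 98 blocks
  const int32_t grid_after = div_round_up(total_unique_indices, kMaxThreads / 4);  // 391 blocks
  std::printf("groups per block: %d -> %d, grid size: %d -> %d\n",
              groups_before, groups_after, grid_before, grid_after);
  return 0;
}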

FBGEMM_LAUNCH_KERNEL(
@@ -1195,7 +1205,7 @@ Tensor {{ embedding_cuda_op }}(
const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half
|| dev_weights.scalar_type() == at::ScalarType::Float;

if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna())
if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna())
{
constexpr int segments_per_workgroup = 4;
{%- for kDimSize in [64, 128, 160, 192, 256] %}
@@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
const int32_t segment_length_mod = segment_length & length_mask;

cache_t grad_acc[dword_per_row];
int32_t infos[segment_unroll];
uint32_t infos[segment_unroll];
grad_t grad_data[dword_per_row * segment_prefetch];
emb_t emb_data[dword_per_row];
float indice_weights[segment_unroll];
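The int32_t to uint32_t change for infos matters because of how the entries are unpacked below: the table index lives in the high bits and the bag index in the low info_B_num_bits bits. With a signed type, infos[k] >> info_B_num_bits sign-extends on mainstream compilers whenever the top bit happens to be set and yields a bogus table index; the mask for the bag index is unaffected. A small host-side illustration follows; the bit width 26 is an example value, not the one the kernel actually computes at runtime.

// Illustration only; mirrors the unpacking pattern used in this kernel.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t info_B_num_bits = 26; // example; the real value is computed at runtime
  const uint32_t info_B_mask = (1u << info_B_num_bits) - 1;

  const uint32_t table_index = 45; // large enough to set the top bit once shifted
  const uint32_t bag_index = 123456;
  const uint32_t packed = (table_index << info_B_num_bits) | bag_index;

  // Unsigned unpacking recovers both fields.
  assert((packed >> info_B_num_bits) == table_index);
  assert((packed & info_B_mask) == bag_index);

  // Stored as int32_t, the same bit pattern is a negative value and the
  // arithmetic right shift sign-extends: the table index comes out as -19
  // here instead of 45. The bag-index mask still works either way.
  const int32_t packed_signed = static_cast<int32_t>(packed);
  assert((packed_signed >> info_B_num_bits) != static_cast<int32_t>(table_index));
  return 0;
}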
@@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[0] >> info_B_num_bits;
bag_index = infos[0] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

{%- if nobag %}
@@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[1] >> info_B_num_bits;
bag_index = infos[1] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
if constexpr (!weighted){
#pragma unroll
@@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j] >> info_B_num_bits;
bag_index = infos[j] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
@@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j + 1] >> info_B_num_bits;
bag_index = infos[j + 1] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
}

@@ -290,7 +290,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j] >> info_B_num_bits;
bag_index = infos[j] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
@@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j + 1] >> info_B_num_bits;
bag_index = infos[j + 1] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
}

@@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[0] >> info_B_num_bits;
bag_index = infos[0] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

{%- if nobag %}
@@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[1] >> info_B_num_bits;
bag_index = infos[1] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

if constexpr (!weighted) {
@@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j] >> info_B_num_bits;
bag_index = infos[j] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
@@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j + 1] >> info_B_num_bits;
bag_index = infos[j + 1] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
}

@@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j] >> info_B_num_bits;
bag_index = infos[j] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
@@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
table_index = infos[j + 1] >> info_B_num_bits;
bag_index = infos[j + 1] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
}

@@ -420,7 +420,7 @@ L_tail_grad_acc:
table_index = infos[0] >> info_B_num_bits;
bag_index = infos[0] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[0], lane_id);
@@ -441,7 +441,7 @@ L_tail_grad_acc:
table_index = infos[0] >> info_B_num_bits;
bag_index = infos[0] & info_B_mask;
{%- endif %}
load_row_per_warp<grad_t, embedding_dim, int32_t>::run(
load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[0], lane_id, indice_weights[0]);
@@ -452,7 +452,7 @@
}

// load the old emb weight data
load_row_per_warp<emb_t, embedding_dim, int64_t>::run(
load_row_per_warp<emb_t, embedding_dim, index_t>::run(
&emb_data[0], emb_idx, p_emb_table, lane_id);
optimizer_t optimizer(opt_karg);
optimizer.template update<dword_per_row, segment_split>(grad_acc, emb_data, emb_idx);
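The repeated int32_t to index_t substitution in the load_row_per_warp instantiations above is about the width of the row-offset arithmetic: once a table (or the concatenated gradient output) is large enough, a 32-bit row * embedding_dim product wraps. A toy illustration with made-up sizes (not taken from this PR) follows.

// Toy numbers only; shows why offsets should be computed in the index type
// (possibly 64-bit) rather than a hard-coded 32-bit integer.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_rows = 40000000; // hypothetical large embedding table
  const int64_t embedding_dim = 256;
  const int64_t row = num_rows - 1;

  const int64_t offset64 = row * embedding_dim; // 10,239,999,744
  const uint32_t offset32 =
      static_cast<uint32_t>(row) * static_cast<uint32_t>(embedding_dim); // wraps mod 2^32

  std::printf("64-bit offset: %lld\n", static_cast<long long>(offset64));
  std::printf("32-bit offset: %u\n", offset32); // 1,650,065,152 -- points at the wrong row
  return 0;
}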
14 changes: 14 additions & 0 deletions fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
100644 → 100755
@@ -31,6 +31,10 @@
#include "fbgemm_gpu/utils/dispatch_macros.h"
{%- endif %}

{%- if is_rocm %}
#include "fbgemm_gpu/rocm/cdna_guard.h"
{%- endif %}

{%- if not is_index_select %}
////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
@@ -454,6 +458,16 @@ batch_index_select_dim0_codegen_forward_cuda(

CUDA_DEVICE_GUARD(dev_weights);

#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
} else {
// Supported CDNA architecture detected (gfx90a/gfx942/gfx950, i.e. including MI350).
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

{%- if not nobag %}
int32_t T = D_offsets.numel() - 1;
{%- else %}
@@ -698,6 +698,7 @@ class {{ autograd_func }} :
TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value.");
const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
const auto mixed_D = aux_bool[IDX_MIXED_D];
{%- endif %}

// Default values for Dynamo tracing
@@ -808,7 +808,7 @@ def __init__( # noqa C901
assert (
self.pooling_mode != PoolingMode.NONE
), "Mixed dimension tables only supported for pooling tables."

self.mixed_D = mixed_D
assert all(
cd == compute_devices[0] for cd in compute_devices
), "Heterogenous compute_devices are NOT supported!"
@@ -2262,6 +2262,7 @@ def forward( # noqa: C901
row_counter,
iter_int,
self.max_counter.item(),
mixed_D=self.mixed_D,
),
)
elif self._used_rowwise_adagrad_with_global_weight_decay:
@@ -2280,6 +2281,7 @@
# `Optional[Tensor]` but got `Union[Module, Tensor]`.
prev_iter_dev=self.prev_iter_dev,
gwd_lower_bound=self.gwd_lower_bound,
mixed_D=self.mixed_D,
),
)
else:
@@ -2289,6 +2291,7 @@
common_args,
self.optimizer_args,
momentum1,
mixed_D=self.mixed_D,
),
)

2 changes: 1 addition & 1 deletion fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h
@@ -38,7 +38,7 @@
namespace fbgemm_gpu::rocm {

[[nodiscard]] inline bool is_supported_cdna() {
const std::set<std::string> supported_archs{"gfx942", "gfx90a"};
const std::set<std::string> supported_archs{"gfx942", "gfx90a", "gfx950"};
int device_id = 0;
HIP_CHECK(hipGetDevice(&device_id));
hipDeviceProp_t dev_props;
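For reference, the guard that gates the ROCm-specific paths above boils down to a device-architecture allow-list; this hunk only extends the set with gfx950. A hedged sketch of how such a check can be completed is below; the real implementation lives in cdna_guard.h and may differ in its error handling and in how gcnArchName is parsed.

// Sketch in the spirit of is_supported_cdna(); not the actual implementation.
#include <hip/hip_runtime.h>
#include <set>
#include <string>

inline bool is_supported_cdna_sketch() {
  const std::set<std::string> supported_archs{"gfx942", "gfx90a", "gfx950"};
  int device_id = 0;
  if (hipGetDevice(&device_id) != hipSuccess) {
    return false;
  }
  hipDeviceProp_t dev_props;
  if (hipGetDeviceProperties(&dev_props, device_id) != hipSuccess) {
    return false;
  }
  // gcnArchName is typically of the form "gfx942:sramecc+:xnack-";
  // compare only the gfx prefix against the allow-list.
  const std::string arch_full{dev_props.gcnArchName};
  const std::string arch = arch_full.substr(0, arch_full.find(':'));
  return supported_archs.count(arch) > 0;
}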