Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags.append("-DNVTE_ENABLE_ROCSHMEM=ON")

else:
cmake_flags.append("-DUSE_ROCM=OFF")
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
cmake_flags.extend(("-DUSE_ROCM=OFF", "-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)))

if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
assert (
Expand Down
16 changes: 16 additions & 0 deletions transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ namespace dequantize_kernel {
#ifdef __HIP_PLATFORM_AMD__
#include "rocm_dequantize_mxfp8.cuh"
#else
// Numerator of ITERATIONS (= CHUNK_DIM_Y / BUFFER_DIM_Y).
// Presumably the chunk extent along Y handled by one thread block — TODO confirm against the kernel body.
constexpr size_t CHUNK_DIM_Y = 128;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these from the original upstream? Or are these from the rocm adjusted code?

// Compile-time tiling constants for the non-HIP (CUDA) build of the MXFP8
// dequantize kernel declared right after this group.
constexpr size_t CHUNK_DIM_X = 128;
// Also used as the __launch_bounds__ argument of dequantize_mxfp8_kernel.
constexpr size_t THREADS_PER_CHUNK = 128;
// Presumably the number of staging buffers (double buffering) — TODO confirm against the kernel body.
constexpr size_t BUFFERS_NUM = 2;

// Divisor used to derive THREADS_PER_CHUNK_X_ROWWISE from CHUNK_DIM_X.
constexpr size_t ELEMS_PER_THREAD = 16;
// NOTE(review): the original note here read "only 32 is supported", which
// conflicts with the value 16 — confirm which one is intended.
constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported (see NOTE above)
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128

constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16
constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16
// Integer division above yields 0 if BUFFER_DIM_Y exceeds CHUNK_DIM_Y; reject that configuration at compile time.
static_assert(ITERATIONS >= 1);

template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
Expand Down
1 change: 0 additions & 1 deletion transformer_engine/common/util/ptx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {

__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#ifndef __HIP_PLATFORM_AMD__
constexpr bool is_blackwell = false;
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint16_t out;
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/common/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ struct Numeric_Traits<fp8e4m3> {
#endif
};

#if !defined(__HIP_DEVICE_COMPILE__)
#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__)
template <bool FNUZ>
struct Numeric_Traits_fp8e4m3: public Numeric_Traits<fp8e4m3> {
static constexpr double maxNorm = FNUZ ? 240 : 448;
Expand All @@ -1055,7 +1055,7 @@ struct Quantized_Limits {
static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
static constexpr float emax = 1 << max_unbiased_exponent;
static constexpr float emax_rcp = 1.0 / emax;
#if !defined(__HIP_DEVICE_COMPILE__)
#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__)
static constexpr struct {
operator float() const {
if (std::is_same<T, fp8e4m3>::value) {
Expand All @@ -1068,10 +1068,10 @@ struct Quantized_Limits {
} max_norm = {};
// dummy value for kernels host path compilation
static constexpr float max_norm_rcp = std::numeric_limits<float>::signaling_NaN();
#else // !defined(__HIP_DEVICE_COMPILE__)
#else // defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__)
static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
static constexpr float max_norm_rcp = 1.0 / max_norm;
#endif // !defined(__HIP_DEVICE_COMPILE__)
#endif // defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__)
};

} // namespace transformer_engine
Expand Down
Loading