
NVFP4 dequantization#505

Open
aris134 wants to merge 1 commit into dev from amartin/nvfp4-dequant

Conversation


@aris134 aris134 commented Mar 25, 2026

Description

Fixes https://github.com/ROCm/frameworks-internal/issues/15998

Enable NVFP4 dequantization on AMD GPUs (gfx950) and add a unit test.

Type of change

  • Documentation change (documentation-only change, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Enable compilation of NVFP4 dequantization kernel for AMD GPU
  • Add unit test that verifies NVFP4 dequantization works on gfx950

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@aris134 aris134 self-assigned this Mar 26, 2026
@aris134 aris134 marked this pull request as ready for review March 26, 2026 13:16
ASSERT_EQ(err, hipSuccess) << hipGetErrorString(err);

const float amax = 1.0f;
input.set_tensor_amax(amax);
Contributor

set_scale() instead?

Collaborator

Yeah, I think for dequantization, the scale is needed

Author

@aris134 aris134 Apr 2, 2026

This leads to a memory-fault runtime error, whereas the current method (set_tensor_amax) works fine. Leaving it as is for now.

Contributor

Please double-check: quantization does not need amax, so dequantization should not need it either.

Collaborator

@ipanfilo ipanfilo left a comment

This is based on PR #472. To avoid reviewing the same changes twice, let's wait for that PR to merge first.

@aris134 aris134 force-pushed the amartin/nvfp4-dequant branch from 2682291 to 1d0a70e Compare April 2, 2026 15:12
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label Apr 2, 2026
Comment on lines +160 to +178
#ifdef __HIP_PLATFORM_AMD__
static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
0x0101010101010101ULL, 0x0202020202020202ULL,
0x0404040404040404ULL, 0x0808080808080808ULL,
0x1010101010101010ULL, 0x2020202020202020ULL,
0x4040404040404040ULL, 0x8080808080808080ULL};
#else
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
#endif

// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val,
#ifdef __HIP_PLATFORM_AMD__
                                          uint64_t groupMask) {
#else
                                          unsigned int groupMask) {
#endif
Contributor

I think the changes in this file are due to a merge error and should not be necessary.
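As a side note on the mask tables themselves: the 64-bit HIP masks extend the 32-bit CUDA pattern to a 64-lane wavefront, with each mask selecting every 8th lane. A small host-side sketch (helper names are mine) that checks this interpretation:

```cpp
#include <cassert>
#include <cstdint>

// Group masks for AMD wavefront-64, copied from the PR diff; on CUDA the
// same byte pattern is truncated to 32 bits for a 32-lane warp.
constexpr uint64_t kGroupMasks[8] = {
    0x0101010101010101ULL, 0x0202020202020202ULL,
    0x0404040404040404ULL, 0x0808080808080808ULL,
    0x1010101010101010ULL, 0x2020202020202020ULL,
    0x4040404040404040ULL, 0x8080808080808080ULL};

// Lane l participates in group g exactly when l % 8 == g, i.e. the eight
// groups are strided across the wavefront with stride 8.
inline bool lane_in_group(unsigned lane, unsigned group) {
  return (kGroupMasks[group] >> lane) & 1u;
}
```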

Comment on lines +34 to +36
size_t divide_round_up(size_t x, size_t y) {
return (x + y - 1) / y;
}
Contributor

Isn't this part of test_common.h?

Comment on lines +56 to +65
const uint8_t bits = static_cast<uint8_t>(dis(gen));

fp8e4m3 candidate;
std::memcpy(&candidate, &bits, sizeof(bits));

const float decoded = static_cast<float>(candidate);
if (std::isfinite(decoded)) {
scale_buffer[idx] = candidate;
break;
}
Collaborator

This section of code generating a valid fp8e4m3 is reused for the 2D scales as well; let's consolidate it to avoid maintaining duplicated copies.

for (size_t block = 0; block < mathematical_blocks_per_row; ++block) {
const size_t idx = row * physical_row_stride + block;

while (true) {
Collaborator

By the way, is there a way to generate fp8e4m3 without a while-loop retry? I understand that with repeated tries we will eventually find a finite fp8e4m3, but it wastes random draws and execution time. Fp8e4m3 is well documented; perhaps we can work out which bit patterns give finite values?
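For what it's worth, a rejection-free variant seems possible. In the OCP FP8 E4M3 ("e4m3fn") encoding there are no infinities; the only non-finite bit patterns are the two NaNs, 0x7F and 0xFF (exponent and mantissa all ones). That leaves 254 finite codes, which can be sampled directly. A sketch (helper name is mine):

```cpp
#include <cassert>
#include <cstdint>
#include <random>

// Draw a uniformly random finite fp8e4m3 bit pattern without retrying.
// E4M3 has 256 codes; only 0x7F and 0xFF are NaN, so 254 are finite.
inline uint8_t random_finite_e4m3_bits(std::mt19937 &gen) {
  std::uniform_int_distribution<int> dis(0, 253);  // 254 finite codes
  int v = dis(gen);
  if (v >= 0x7F) ++v;  // skip 0x7F; 0xFF falls outside the shifted range
  return static_cast<uint8_t>(v);
}
```

This draws exactly one random number per scale instead of looping on std::isfinite().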

std::memcpy(&candidate, &bits, sizeof(bits));

const float decoded = static_cast<float>(candidate);
if (std::isfinite(decoded)) {
Collaborator

Scales also need to be non-negative, right?

Comment on lines +307 to +312
generate_1d_scales(host_scales_rowwise_1d.get(),
unpadded_blocks_Y,
unpadded_blocks_X,
scales_stride,
gen,
fp8_dis);
Collaborator

alignment?

Comment on lines +315 to +320
generate_1d_scales(host_scales_colwise_1d.get(),
unpadded_blocks_Y_t,
unpadded_blocks_X_t,
scales_stride_t,
gen,
fp8_dis);
Collaborator

alignment?

Comment on lines +43 to +45
const size_t mathematical_rows,
const size_t mathematical_blocks_per_row,
const size_t physical_row_stride,
Collaborator

You can reuse the existing names (unpadded_blocks_Y, unpadded_blocks_X, and scales_stride)

}

// Decode a single FP4 (E2M1) value from packed storage.
float get_fp4_value(const fp4e2m1* data, const size_t mathematical_idx) {
Collaborator

Only the scales have the padding/alignment distinction (mathematical/unpadded index vs. padded index). I recall that for rowwise/columnwise data we don't have this padding issue? If so, we can rename mathematical_idx -> idx.
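For reference, a minimal host-side sketch of what an E2M1 decode like get_fp4_value does. The eight E2M1 magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6} with bit 3 as the sign; two values are packed per byte. Taking the low nibble for even indices is an assumption about the packing order:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// FP4 E2M1 lookup table: sign bit (bit 3) followed by 2 exponent bits and
// 1 mantissa bit, giving magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
static constexpr float kE2M1Lut[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Decode the idx-th FP4 value from a packed byte array (two values per
// byte; low-nibble-first order is an assumption here).
inline float decode_fp4(const uint8_t *packed, size_t idx) {
  const uint8_t byte = packed[idx / 2];
  const uint8_t nibble = (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  return kE2M1Lut[nibble];
}
```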

Comment on lines +311 to +316
float *amax_gpu = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float)));
NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(),
sizeof(float), cudaMemcpyHostToDevice));

tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape);
Collaborator

use from_cpu()

void Tensor::from_cpu() const {

tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
}

void set_tensor_amax(float amax) {
Collaborator

Guard our ROCm-specific changes with a macro.

constexpr size_t scale_tensor_alignment_X_colwise = 128;
#endif

static constexpr float E2M1_LUT[16] = {
Collaborator

ROCm-specific contents.
