
Align Int4Tensor implementation details with the design of Float8Tensor #2687


Open
jerryzh168 wants to merge 1 commit into jerryzh168/stack/10 from jerryzh168/stack/16

Conversation

@jerryzh168 jerryzh168 commented Aug 5, 2025

Stacked PRs:


Align Int4Tensor implementation details with the design of Float8Tensor

Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized Tensor: data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N].

Multiple fixes for Int4Tensor to align it with the design of Float8Tensor (only calling fbgemm ops)
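For reference, a minimal sketch (not the actual Int4Tensor code) of the layout described above, assuming symmetric groupwise quantization; the helper name and the exact nibble-packing order used by the fbgemm kernels are assumptions made for illustration:

```python
import torch

# Hypothetical helper illustrating the shapes in the summary:
# qdata: [N, K/2] (two int4 values packed per uint8),
# scale: [K/group_size, N]. zero_point (same shape as scale) is omitted
# because this sketch uses symmetric quantization.
def pack_int4_groupwise(w: torch.Tensor, group_size: int = 128):
    N, K = w.shape
    w_grouped = w.float().reshape(N, K // group_size, group_size)
    max_abs = w_grouped.abs().amax(dim=-1, keepdim=True)        # [N, K/group_size, 1]
    scale = (max_abs / 7.0).clamp(min=1e-6)
    q = (w_grouped / scale).round().clamp(-8, 7) + 8            # unsigned nibbles in [0, 15]
    q = q.to(torch.uint8).reshape(N, K)
    qdata = (q[:, ::2] << 4) | q[:, 1::2]                       # [N, K/2]; packing order is illustrative
    scale = scale.reshape(N, K // group_size).t().contiguous()  # [K/group_size, N]
    return qdata, scale

qdata, scale = pack_int4_groupwise(torch.randn(256, 1024, dtype=torch.bfloat16))
assert qdata.shape == (256, 512) and scale.shape == (8, 256)
```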

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2687

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit f9695e4 with merge base d2e791b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Aug 5, 2025
Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized Tensor: data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N].

Multiple fixes for Int4Tensor to align it with the design of Float8Tensor (only calling fbgemm ops):
* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated op implementation and tests from #2387

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 1d84542 to 4874773 on August 5, 2025 03:25
@meta-cla meta-cla bot added the CLA Signed label Aug 5, 2025
@jerryzh168 jerryzh168 added the topic: not user facing label Aug 5, 2025
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 5, 2025 18:39
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 4874773 to 1beccb0 on August 5, 2025 18:39
jerryzh168 added a commit that referenced this pull request Aug 5, 2025
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 5, 2025 18:39
 @classmethod
-def from_float(
+def to_int4(
Contributor:

Why rename this? from_float seems cleaner; maybe from_high_precision would be even clearer. Same feedback for Float8Tensor.

Contributor (Author):

This is to align with Float8Tensor. Sure, I can rename this to from_high_precision for both.

Contributor (Author):

How about from_float_hp? I think it's a bit more precise than high_precision.

 res = torch.ops.fbgemm.bf16i4bf16_rowwise(
     input_tensor,
-    weight_tensor._data.contiguous(),
+    weight_tensor.qdata.contiguous(),
Contributor:

Is it expected that the tensors are not contiguous? If not, can we assert on this instead of calling contiguous()?

@jerryzh168 (Contributor, Author) commented Aug 5, 2025:

The non-contiguity comes from reshape ops like transpose and view, I think, but the kernel needs these to be contiguous. I can try changing these to asserts and doing the contiguous() call on the user side to see if that works.

Contributor:

I would have expected the weights to be stored in a format aligned with what the kernel needs, without any need for just-in-time layout transforms. Does this match how the current code works?

@jerryzh168 (Contributor, Author) commented Aug 5, 2025:

Normally it is, but the weights also go through some transformations, like the ones we listed in test_moe_weight_reshape_ops, which make the weight, scale, etc. non-contiguous, I think. I can try calling contiguous() in user code instead; that might be cleaner.
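A minimal sketch of the alternative being discussed, assuming the caller (rather than the op wrapper) becomes responsible for materializing contiguous tensors after transpose/view; the argument list is abbreviated to match the truncated diff shown above and the error message is illustrative:

```python
# Assert contiguity in the op wrapper and push the .contiguous() call to user
# code (e.g. after the reshape ops exercised in test_moe_weight_reshape_ops).
assert weight_tensor.qdata.is_contiguous(), (
    "expected contiguous int4 qdata; call .contiguous() after transpose/view"
)
res = torch.ops.fbgemm.bf16i4bf16_rowwise(
    input_tensor,
    weight_tensor.qdata,  # no implicit copy here anymore
    # remaining scale / zero_point arguments unchanged from the diff above
)
```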

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 5, 2025 23:30
jerryzh168 added a commit that referenced this pull request Aug 5, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 1beccb0 to 5f6306e on August 5, 2025 23:30
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 5, 2025 23:30


@register_quantize_module_handler(TestOnlyMoEQuantConfig)
def moe_quant_fn(module, config: TestOnlyMoEQuantConfig):
Contributor:

This is really confusing; could you share the result of print(model) after this function has been applied?

If it's going to print the model with parameters wrapped in Int4Tensor, can we just wrap the parameters directly, without all of these layers of abstraction?

Contributor:

If this is working around the fact that quantize_ needs to work on modules, IMO we should change quantize_ to handle this instead of working around it. Seems important for MoEs.

Contributor (Author):

Yeah, the parameters are wrapped in Int4Tensor; this is just applying quantization to each of the MoE weights: w1, w2 and w3.

I can inline these for now, and follow up on how to have an API for weights + configs separately.
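For context, a rough sketch of the "wrap the parameters directly" alternative raised above, assuming the experts module exposes w1/w2/w3 as plain nn.Parameters; `quantize_fn` is a stand-in for whatever Int4Tensor constructor lands (from_float / to_int4 / from_float_hp per the thread above), and the helper name is hypothetical:

```python
import torch
import torch.nn as nn
from typing import Callable

# Replace each expert weight with a Parameter wrapping the quantized tensor,
# instead of going through a module-level config handler.
def quantize_moe_expert_weights(
    module: nn.Module,
    quantize_fn: Callable[[torch.Tensor], torch.Tensor],
) -> nn.Module:
    for name in ("w1", "w2", "w3"):
        hp_weight = getattr(module, name)
        setattr(module, name, nn.Parameter(quantize_fn(hp_weight), requires_grad=False))
    return module
```

With this shape, print(model) would show w1/w2/w3 as Parameters whose data is an Int4Tensor, which is what the print(model) question above is asking about.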

@@ -177,3 +178,63 @@ def create_model_and_input_data(
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return model, input_data


class Experts(nn.Module):
Contributor:

Maybe call it something like FeedForwardWithExperts? Experts is ambiguous.

Contributor (Author):

This is adapted from https://github.com/meta-llama/llama-models/blob/a9c89c471f793423afd4cc3ca8671d6e56fe64cb/models/llama4/moe.py#L22; how about renaming it to Llama4Experts to make it more specific?
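For readers without the link handy, a condensed sketch of the kind of module under discussion, loosely adapted from the linked llama-models Experts; dimension names, initialization, and the per-expert routing are simplified and not the exact test fixture in this PR:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Llama4Experts(nn.Module):
    """Batched expert FFN: w1/w2/w3 hold the weights of all experts at once."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        for w in (self.w1, self.w2, self.w3):
            nn.init.normal_(w, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_experts, tokens_per_expert, dim], already routed per expert
        h = F.silu(torch.bmm(x, self.w1)) * torch.bmm(x, self.w3)
        return torch.bmm(h, self.w2)  # [num_experts, tokens_per_expert, dim]
```

Quantizing w1/w2/w3 and then exercising transpose/view on them is what produces the non-contiguous tensors discussed in the bf16i4bf16_rowwise thread above.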

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 6, 2025 01:07
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 5f6306e to 6bd3106 on August 6, 2025 01:08
jerryzh168 added a commit that referenced this pull request Aug 6, 2025
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 6, 2025 22:10
jerryzh168 added a commit that referenced this pull request Aug 6, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 402bf2c to 70aba27 on August 6, 2025 22:10
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 6, 2025 22:10
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 6, 2025 22:16
jerryzh168 added a commit that referenced this pull request Aug 6, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 70aba27 to 8d51aed on August 6, 2025 22:16
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 6, 2025 22:16
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 6, 2025 23:27
jerryzh168 added a commit that referenced this pull request Aug 6, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 8d51aed to b2d49bd on August 6, 2025 23:27
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 6, 2025 23:27
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 7, 2025 02:57
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from b2d49bd to aee3dbb on August 7, 2025 02:57
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 7, 2025 02:57
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/10 branch from 7868bcf to ceac84c on August 7, 2025 02:58
jerryzh168 added a commit that referenced this pull request Aug 7, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from aee3dbb to 4a50bf7 on August 7, 2025 02:58
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 7, 2025 03:37
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 4a50bf7 to 7a21719 on August 7, 2025 03:37
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 7, 2025 03:37
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 7, 2025 03:51
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 7a21719 to 0040a5f on August 7, 2025 03:51
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 7, 2025 03:51
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 7, 2025 04:29
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 0040a5f to f9695e4 on August 7, 2025 04:29
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 7, 2025 04:29
Labels
CLA Signed, topic: not user facing
2 participants