Align Int4Tensor implementation details with the design of Float8Tensor #2687
base: jerryzh168/stack/10
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2687
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 Pending as of commit f9695e4 with merge base d2e791b. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
  @classmethod
- def from_float(
+ def to_int4(
why rename this? from_float seems cleaner, maybe from_high_precision might be even more clear. Same feedback for Float8Tensor.
this is to align with Float8Tensor; sure, I can rename this to from_high_precision for both
how about from_float_hp? this is a bit more precise than high_precision, I think
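For concreteness, a minimal sketch of what the rename under discussion could look like; the signature and the group_size default are illustrative assumptions, not the PR's actual code:

```python
import torch


class Int4Tensor(torch.Tensor):
    @classmethod
    def from_float_hp(cls, hp_tensor: torch.Tensor, group_size: int = 128):
        """Construct a packed int4 tensor from a high-precision (hp) float tensor.

        Hypothetical name and signature per the thread above; body elided.
        """
        ...
```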
  res = torch.ops.fbgemm.bf16i4bf16_rowwise(
      input_tensor,
-     weight_tensor._data.contiguous(),
+     weight_tensor.qdata.contiguous(),
is it expected that the tensors are not contiguous? if not, can we assert for this instead of calling contiguous?
the non-contiguity comes from reshape ops like transpose and view, I think, but the kernel needs these to be contiguous. I can try changing these to asserts and doing the contiguous call on the user side to see if it works
I would have expected the weights to be stored in a format aligned with what the kernel needs, without any need for just-in-time layout transforms. Does this match how the current code works?
normally it is, but the weights also go through some transformations, like the ones we listed in test_moe_weight_reshape_ops, which make the weight / scale etc. non-contiguous, I think. I can try calling contiguous in user code instead; that might be cleaner
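A minimal sketch of the alternative being weighed here: assert contiguity inside the op implementation and leave any copy to the caller. The qdata name comes from the diff above; the scale and zero_point attribute names and the exact bf16i4bf16_rowwise argument order are assumptions:

```python
import torch


def _bf16i4_mm(input_tensor, weight_tensor):
    # Fail loudly instead of silently copying, so a layout change introduced
    # by reshape ops like transpose/view is visible at the call site.
    assert weight_tensor.qdata.is_contiguous(), (
        "expected contiguous qdata; call .contiguous() in user code first"
    )
    return torch.ops.fbgemm.bf16i4bf16_rowwise(
        input_tensor,
        weight_tensor.qdata,
        weight_tensor.scale,       # assumed attribute name
        weight_tensor.zero_point,  # assumed attribute name
    )
```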
@register_quantize_module_handler(TestOnlyMoEQuantConfig)
def moe_quant_fn(module, config: TestOnlyMoEQuantConfig):
this is really confusing, could you share the result of print(model) after this function has been applied? if it's going to print the model with parameters wrapped in Int4Tensor, can we just wrap the parameters directly without all of these layers of abstraction?
if this is working around the fact that quantize_ needs to work on modules, IMO we should change quantize_ to handle this instead of working around it; seems important for MoEs.
yeah, the parameters are wrapped in Int4Tensor; this is just applying quantization to each of the MoE weights: w1, w2 and w3. I can inline these for now and follow up on how to have an API for weights + configs separately; a sketch of the direct-wrapping option is below.
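A sketch of the "wrap the parameters directly" option for the three expert weights named above; the Int4Tensor import path and the from_float_hp constructor (from the naming thread) are hypothetical:

```python
import torch.nn as nn

from torchao.quantization import Int4Tensor  # assumed import path


def quantize_moe_weights(module: nn.Module, group_size: int = 128) -> nn.Module:
    # Wrap each expert weight (w1, w2, w3) directly in Int4Tensor,
    # with no extra module-handler indirection.
    for name in ("w1", "w2", "w3"):
        hp_weight = getattr(module, name)
        q_weight = Int4Tensor.from_float_hp(hp_weight, group_size=group_size)  # hypothetical API
        setattr(module, name, nn.Parameter(q_weight, requires_grad=False))
    return module
```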
@@ -177,3 +178,63 @@ def create_model_and_input_data(
     else:
         raise ValueError(f"Unknown model type: {model_type}")
     return model, input_data
+
+class Experts(nn.Module):
maybe call it something like FeedForwardWithExperts? Experts is ambiguous
this is adapted from https://github.com/meta-llama/llama-models/blob/a9c89c471f793423afd4cc3ca8671d6e56fe64cb/models/llama4/moe.py#L22; how about renaming to LLama4Experts to make it more specific?
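For reference, a minimal sketch of a batched-experts feed-forward in the style of the linked llama4 moe.py; the dimensions, initialization, and silu gating here are illustrative rather than a copy of that file:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Experts(nn.Module):
    """Batched expert FFN: one bmm per projection across all experts."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))
        self.w3 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_experts, tokens_per_expert, dim]
        h = F.silu(torch.bmm(x, self.w1)) * torch.bmm(x, self.w3)
        return torch.bmm(h, self.w2)  # [num_experts, tokens_per_expert, dim]
```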
Stacked PRs:
- optional_tensor_names in TorchAOBaseTensor #2710
- Align Int4Tensor implementation details with the design of Float8Tensor (this PR)
Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized Tensor; data is [N, K/2] and scale/zero_point have shape [K/group_size, N].
Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops):
- Added VERSION 2 for Int4WeightOnlyConfig
- Migrated op implementation and tests from #2387
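As a quick sanity check of those shapes, with illustrative values N=128, K=256, group_size=64 (numbers chosen only for this example):

```python
import torch

N, K, group_size = 128, 256, 64
qdata = torch.zeros(N, K // 2, dtype=torch.uint8)  # two int4 values packed per byte
scale = torch.zeros(K // group_size, N, dtype=torch.bfloat16)
zero_point = torch.zeros(K // group_size, N, dtype=torch.bfloat16)
print(qdata.shape, scale.shape, zero_point.shape)
# torch.Size([128, 128]) torch.Size([4, 128]) torch.Size([4, 128])
```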
Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py