Add support for resharding for fbgemm configs #2387
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2387
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 23c46f4 with merge base 6243040.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 6525c15 to 9e128c1
Why are these ops needed? Is it for DCP?
test/dtypes/test_fbgemm_fp8.py (Outdated)
cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
self.assertEqual(cat_weight1.shape, (512, 128))
self.assertEqual(cat_weight2.shape, (256, 256))
Can you also assert equality of bits?
sure
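For reference, a minimal sketch of what a bit-level equality check could look like, assuming the quantized weights expose float8_data and scale attributes (names taken from the fbgemm_fp8_tensor.py docstring quoted below; the real subclass API may differ):

import torch

def assert_same_bits(test_case, actual, expected):
    # Hypothetical helper: compare the raw float8 payloads bit for bit by
    # reinterpreting them as uint8, and require exact equality of the scales.
    test_case.assertTrue(
        torch.equal(
            actual.float8_data.view(torch.uint8),
            expected.float8_data.view(torch.uint8),
        )
    )
    test_case.assertTrue(torch.equal(actual.scale, expected.scale))

In the cat test above this could be applied to slices of cat_weight1/cat_weight2 against the original quantized weights, assuming slicing is supported for the subclass.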
torchao/dtypes/fbgemm_fp8_tensor.py (Outdated)
data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
float8_data: (batch_size, output_channel, input_channel)
scale: (batch_size, output_channel) (since it's per row quantization)
data_to_scale_dim: {0: 0, 1: 1}
This explanation isn't very helpful; I don't know what this is doing.
This is a bit confusing, removed it.
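For readers who hit the same confusion, one way to read the removed mapping (an illustrative guess, not the actual implementation): with per-row quantization the scale carries one value per row of each batch, so scale dims 0 and 1 line up with float8_data dims 0 and 1, and the input_channel dim of float8_data has no counterpart in scale. Shapes below are hypothetical:

import torch

# Illustrative shapes only.
batch_size, output_channel, input_channel = 4, 256, 128
float8_data = torch.empty(batch_size, output_channel, input_channel, dtype=torch.float8_e4m3fn)
scale = torch.ones(batch_size, output_channel, dtype=torch.float32)
# data_to_scale_dim maps a float8_data dim to the scale dim it corresponds to;
# dim 2 (input_channel) is absent because it is reduced over during quantization.
data_to_scale_dim = {0: 0, 1: 1}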
)

def _transpose_and_reshape(self):
    """This is added for resharding support, since the resharding logic for the model we are
Do these next two functions need to be methods, or can they be implementations of the actual ops?
These should be methods; it's specific to the hack we are doing.
assert len(self.shape) == 3, (
    f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
)
dim0, dim1, dim2 = self.shape
I don't understand this.
This is specific to the hack: we transpose the weight first and then quantize, so (dim0, dim2, dim1) is the original shape.
Here we are restoring the tensor to its original shape for resharding.
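To make the shape bookkeeping concrete, a rough sketch of what is being described (an approximation, not the code in this PR): the current payload is (dim0, dim1, dim2) because the weight was transposed before quantization, and the original weight shape was (dim0, dim2, dim1), so restoring it for resharding means transposing the last two dims back:

import torch

def restore_original_layout(float8_data: torch.Tensor) -> torch.Tensor:
    # Sketch only: swap the last two dims back to the pre-quantization layout
    # and make the result contiguous so it can be resharded.
    assert float8_data.ndim == 3, (
        f"Only expected to be used when the Tensor is 3D, got {float8_data.ndim}D"
    )
    return float8_data.transpose(1, 2).contiguous()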
Force-pushed from bed8189 to 4e562e2
Force-pushed from 4e562e2 to 23c46f4
TODO: needs padding for cutlass kernels

Tensor Attributes:
float8_data: float8 raw data, dtype torchao.float8.config.e4m3_dtype
To clarify, does this mean there is a dependency on torchao.float8.config.e4m3_dtype? If so, I think the dependency should be refactored away to a common utility; it's not expected for that config to affect anything other than the torchao.float8 workflow.
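As a concrete picture of the suggested refactor, something along these lines could live in a shared utility; the function name and placement here are hypothetical, not existing torchao API:

import torch

def preferred_e4m3_dtype() -> torch.dtype:
    # Hypothetical shared helper: pick the e4m3 float8 variant for the current
    # backend locally, instead of importing torchao.float8.config.e4m3_dtype.
    # ROCm builds historically use the fnuz variant; CUDA/CPU use e4m3fn.
    if torch.version.hip is not None:
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn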
Groupwise int4 weight only quantization
Tensor Attributes:
packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
When is a weight a 3D tensor, and why is the batch dimension in here? Could you share a specific example?
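Not answering for the author, but one common case where a leading batch dimension shows up in a weight is a stack of per-expert linear weights (e.g. an MoE layer) fed to a batched or grouped GEMM; a small illustration of the resulting packed shape (sizes are made up):

import torch

# B linear weights of shape (N, K) stacked along a leading "expert" dim.
# With int4 packing (two 4-bit values per byte) the last dim of the packed
# payload is halved, giving (B, N, K/2).
B, N, K = 8, 256, 128
weights = torch.randn(B, N, K, dtype=torch.bfloat16)
packed_weight_shape = (B, N, K // 2)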
Summary:
Added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding.
In the future we should probably provide examples for using distributed checkpoint for resharding.
Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
Reviewers:
Subscribers:
Tasks:
Tags:
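Regarding the note in the summary about distributed checkpoint examples, a future example could look roughly like the following; this is a hedged sketch using the public torch.distributed.checkpoint (DCP) API and is not part of this PR:

import torch
import torch.distributed.checkpoint as dcp

def save_for_resharding(model: torch.nn.Module, checkpoint_dir: str) -> None:
    # Save the quantized model's state dict with DCP so a job with a different
    # sharding layout can load it back later.
    dcp.save(model.state_dict(), checkpoint_id=checkpoint_dir)

def load_resharded(model: torch.nn.Module, checkpoint_dir: str) -> None:
    # Load into the destination model's state dict in place; DCP can reshard
    # the saved tensors to match the destination layout, after which the state
    # dict is applied to the model.
    state_dict = model.state_dict()
    dcp.load(state_dict, checkpoint_id=checkpoint_dir)
    model.load_state_dict(state_dict)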