
Add support for resharding for fbgemm configs #2387


Open
wants to merge 1 commit into main from fbgemm-reshard

Conversation

jerryzh168
Contributor

Summary:
Added transpose and cat op support, as well as some custom transpose/reshape/unflatten support for resharding.

In the future we should probably provide examples of using distributed checkpointing for resharding.

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Jun 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2387

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 23c46f4 with merge base 6243040:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 16, 2025
@jerryzh168 requested a review from drisspg on June 16, 2025 20:44
@jerryzh168 added the topic: improvement label on June 16, 2025
@jerryzh168 force-pushed the fbgemm-reshard branch 2 times, most recently from 6525c15 to 9e128c1 on June 16, 2025 20:49

drisspg commented Jun 16, 2025

Why are these ops needed? Is it for DCP (distributed checkpointing)?

cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
self.assertEqual(cat_weight1.shape, (512, 128))
self.assertEqual(cat_weight2.shape, (256, 256))
Contributor

@drisspg Jun 16, 2025


Can you also assert equality of bits?

Contributor Author


sure
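For reference, a minimal sketch (not the PR's actual test code) of what asserting bit equality after cat could look like, assuming this is the fp8 path and the weight subclass exposes its raw data as float8_data (per the Tensor Attributes docstring quoted further down); the 256-row split assumes each linear has a (256, 128) weight:

# bit-for-bit comparison of the concatenated raw data against the inputs (sketch only)
self.assertTrue(
    torch.equal(
        cat_weight1.float8_data[:256].view(torch.uint8),
        linear1.weight.float8_data.view(torch.uint8),
    )
)
self.assertTrue(
    torch.equal(
        cat_weight1.float8_data[256:].view(torch.uint8),
        linear2.weight.float8_data.view(torch.uint8),
    )
)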

data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
float8_data: (batch_size, output_channel, input_channel)
scale: (batch_size, output_channel) (since it's per row quantization)
data_to_scale_dim: {0: 0, 1: 1}
Contributor


This explanation isn't very helpful / I don't know what this is doing.

Contributor Author


this is a bit confusing, removed
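For context, a small illustration (not from the PR) of the shapes the removed docstring was describing, for per-row fp8 quantization of a batched weight; the sizes below are made up:

import torch

batch_size, output_channel, input_channel = 4, 256, 128
float8_data = torch.zeros(batch_size, output_channel, input_channel, dtype=torch.float8_e4m3fn)
scale = torch.zeros(batch_size, output_channel, dtype=torch.float32)  # one scale per row
# data_to_scale_dim maps each float8_data dim to the corresponding scale dim;
# the input_channel dim (2) has no entry because a row shares a single scale across it
data_to_scale_dim = {0: 0, 1: 1}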

)

def _transpose_and_reshape(self):
"""This is added for resharding support, since the resharding logic for the model we are
Contributor


Do these next two functions need to be methods, or can they be implementations of the actual ops?

Contributor Author


These should be methods; they're specific to the hack we are doing.

assert len(self.shape) == 3, (
    f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
)
dim0, dim1, dim2 = self.shape
Contributor


I don't understand this.

Contributor Author


This is specific to the hack: we transpose the weight first and then quantize, so (dim0, dim2, dim1) is the original shape.

We are restoring the tensor to its original shape here for resharding.
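A shape-bookkeeping sketch of what is being described (illustrative sizes, not the PR's implementation): the weight is transposed before quantization, so the quantized tensor's shape is (dim0, dim2, dim1) relative to the checkpoint's (dim0, dim1, dim2), and resharding needs the original layout back:

import torch

original = torch.randn(8, 256, 128)                   # (dim0, dim1, dim2) as stored in the checkpoint
to_quantize = original.transpose(1, 2).contiguous()   # (dim0, dim2, dim1), the layout used for quantization
dim0, dim1, dim2 = to_quantize.shape                  # dims as seen by the quantized tensor
restored_shape = (dim0, dim2, dim1)                   # transpose back to the checkpoint layout for resharding
assert restored_shape == original.shape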

@jerryzh168 force-pushed the fbgemm-reshard branch 2 times, most recently from bed8189 to 4e562e2 on June 17, 2025 01:22
TODO: needs padding for cutlass kernels

Tensor Attributes:
float8_data: float8 raw data, dtype torchao.float8.config.e4m3_dtype
Contributor


To clarify, does this mean there is a dependency on torchao.float8.config.e4m3_dtype? If so, I think the dependency should be refactored away to a common utility; it's not expected for that config to affect anything other than the torchao.float8 workflow.
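One possible shape for that refactor (a sketch, not an existing torchao module; the module path and constant name below are assumptions) is to host the dtype in a shared utility so the fbgemm tensor subclasses don't import the float8 training config:

# hypothetical shared location, e.g. torchao/utils.py
import torch

e4m3_dtype = torch.float8_e4m3fn  # standard e4m3 dtype; a ROCm build might swap in torch.float8_e4m3fnuz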

Groupwise int4 weight only quantization
Tensor Attributes:
packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
Contributor


When is a weight a 3D tensor, and why is the batch dimension in here? Could you share a specific example?
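As a toy illustration (not torchao's packing code, and the nibble order here is an assumption), the K/2 last dimension comes from packing two int4 values into each byte, which works the same way for the 2D and 3D layouts:

import torch

B, N, K = 2, 256, 128
w_int4 = torch.randint(0, 16, (B, N, K), dtype=torch.uint8)   # unpacked int4 values in [0, 16)
packed = (w_int4[..., ::2] << 4) | w_int4[..., 1::2]          # pack two nibbles per byte
assert packed.shape == (B, N, K // 2)                         # matches the (B, N, K/2) layout above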

Labels
CLA Signed · topic: improvement
4 participants