
Conversation

jscaldwell55

Summary

Fixes #2856 - DTensor/torch.Tensor mixed type error in Llama4 LoRA fine-tuning

Problem

When using LoRA fine-tuning with LinearCrossEntropyLoss and custom_sharded_layers, users encounter a tensor type mismatch error:
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

This happens because:

  • LinearCrossEntropyLoss uses model.output for the final projection
  • LoRA configs typically set custom_sharded_layers = ['tok_embeddings'] without including 'output'
  • FSDP wraps only the layers listed in custom_sharded_layers as DTensors
  • This creates a mismatch when computing loss (DTensor hidden states × regular Tensor output weights)
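
To make the mismatch concrete: whichever operand ends up as the DTensor, mixing it with a plain torch.Tensor in the same matmul raises this error. Below is a minimal sketch (illustrative only, not torchtune code) that shards only the embedding to show one such combination; it assumes a 2-GPU launch such as torchrun --nproc_per_node=2 repro.py.

```python
# Illustrative repro sketch (not torchtune code); assumes a 2-GPU launch, e.g.
#   torchrun --nproc_per_node=2 repro.py
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2; older releases expose this under torch.distributed._composable.fsdp

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2,))

tok_embeddings = nn.Embedding(128, 16, device="cuda")
fully_shard(tok_embeddings, mesh=mesh)  # mirrors custom_sharded_layers = ['tok_embeddings']
# tok_embeddings.weight is now a DTensor sharded across the mesh.

hidden = torch.randn(4, 16, device="cuda")  # activations remain plain torch.Tensors
logits = F.linear(hidden, tok_embeddings.weight)
# RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor ...
```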

Solution

Added validation that checks if LinearCrossEntropyLoss is used with custom_sharded_layers and ensures 'output' is included in the list. This provides a clear, actionable error message at setup time rather than a cryptic error during training.

Implementation

  • Created shared validation module recipes/validation.py to avoid code duplication
  • Added validation to both full_finetune_distributed.py and lora_finetune_distributed.py recipes
  • Validation is called in _setup_model before FSDP wrapping occurs
  • Added comprehensive unit tests covering various edge cases
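
As a rough sketch, the check in recipes/validation.py could look something like the following; the function name and signature here are placeholders, not necessarily what the PR implements.

```python
# Hypothetical sketch of the validation described above; names and signature are
# placeholders, not the PR's actual code in recipes/validation.py.
from typing import Optional

from torchtune.modules.loss import LinearCrossEntropyLoss


def validate_custom_sharded_layers(
    loss_fn: object,
    custom_sharded_layers: Optional[list[str]],
) -> None:
    """Fail fast at setup time instead of with a cryptic DTensor error mid-training."""
    if not isinstance(loss_fn, LinearCrossEntropyLoss):
        return  # only LinearCrossEntropyLoss projects through model.output at loss time
    if not custom_sharded_layers:
        return  # nothing custom-sharded, so no partial-sharding mismatch to guard against
    if "output" not in custom_sharded_layers:
        raise ValueError(
            "When using LinearCrossEntropyLoss with custom_sharded_layers, 'output' "
            "must be included to ensure tensor compatibility. Example: "
            "custom_sharded_layers = ['tok_embeddings', 'output']."
        )
```

In the recipes this would be called from _setup_model, before FSDP wrapping, as described above.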

Testing

  • Unit tests added in tests/recipes/test_validation.py
  • Tests cover: missing output, correct config, None/empty layers, disabled parallelism, non-LinearCrossEntropyLoss
  • No changes to existing functionality - only adds validation
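
For reference, tests of the kind described above might look roughly like this sketch, which targets the hypothetical helper from the previous snippet and assumes LinearCrossEntropyLoss constructs with its defaults; the PR's actual tests live in tests/recipes/test_validation.py.

```python
# Sketch only; the real tests are in tests/recipes/test_validation.py.
import pytest
import torch.nn as nn
from torchtune.modules.loss import LinearCrossEntropyLoss

# Hypothetical import path for the helper sketched above; adjust to however
# recipes/validation.py is actually exposed.
from recipes.validation import validate_custom_sharded_layers


def test_missing_output_raises():
    with pytest.raises(ValueError, match="'output' must be included"):
        validate_custom_sharded_layers(LinearCrossEntropyLoss(), ["tok_embeddings"])


def test_output_included_passes():
    validate_custom_sharded_layers(LinearCrossEntropyLoss(), ["tok_embeddings", "output"])


@pytest.mark.parametrize("layers", [None, []])
def test_no_custom_sharded_layers_passes(layers):
    validate_custom_sharded_layers(LinearCrossEntropyLoss(), layers)


def test_other_loss_functions_pass():
    validate_custom_sharded_layers(nn.CrossEntropyLoss(), ["tok_embeddings"])
```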

Example Error Message

When misconfigured, users will now see:
ValueError: When using LinearCrossEntropyLoss with custom_sharded_layers, 'output' must be included to ensure tensor compatibility. Example: custom_sharded_layers = ['tok_embeddings', 'output'].

This guides users to the correct configuration immediately.

@jscaldwell55
Author

jscaldwell55 commented Aug 4, 2025

@nathan-az Hey Nathan, apologies for the confusion. I had meant to revise the current PR but created a new branch without thinking. I wanted to give you a little context for this. After looking into it, here's what I found:

  • The full fine-tuning config correctly includes custom_sharded_layers: ['tok_embeddings', 'output']
  • But, LoRA fine-tuning configs typically only include ['tok_embeddings']
  • This makes sense for LoRA's use case, but causes issues with LinearCrossEntropyLoss which needs the output layer for its projection

Key changes from the previous attempt:

  • No modifications to cross_entropy_loss.py
  • Validation logic in a shared module to avoid duplication
  • Error message that guides users to the fix
  • Comprehensive unit tests
  • Validation placed in _setup_model where both variables are in scope

I guess my "fix" here is just making sure users get a descriptive error message rather than a runtime error during training, particularly for LoRA users who might not realize they need to include 'output' when using LinearCrossEntropyLoss.

Thanks again for your review. Your notes were super helpful, and I learned some cool stuff working through this :)
If you think further revision is needed just let me know.

@nathan-az
Collaborator

nathan-az commented Aug 4, 2025

Hey @jscaldwell55 - thanks again for your work here! Two requests, if you have the time.
Could you confirm you're using the latest (or near latest) nightlies? There were some changes made to the model root resharding logic during FSDP in response to some recent PyTorch changes, and this error is reminiscent of those.

In addition, could you provide an updated traceback/error logs showing what happens when you run your config on main, either here or in the original issue? The reason I ask is that a lot of the loss logic was updated after that bug was reported, where we stopped extracting the weight directly, so the old logs don't help much now. It would be good to see where the error actually reports now.

It's not immediately obvious to me why having tok_embeddings sharded separately (i.e. in custom_sharded_layers) would necessitate also including output, but having updated logs may make it clearer.

@jscaldwell55
Author

jscaldwell55 commented Aug 14, 2025

@nathan-az I got some time this morning to work on this and completed the testing you requested. Here are my findings:

Test Environment

  • PyTorch: 2.6.0+cu124 (latest available in Colab)
  • Torchtune: 0.6.1
  • Model Used: meta-llama/Llama-2-7b-hf with LoRA

Test Results

Single-GPU Testing

✅ Training runs successfully without DTensor errors in single-GPU mode with the problematic configuration:

```yaml
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
custom_sharded_layers: ['tok_embeddings']  # Without 'output'
```

However, single-GPU doesn't trigger the FSDP DTensor wrapping that causes the issue.

Multi-GPU Evidence

The original issue #2856 from June 29, 2025 shows the error still occurring with:

  • PyTorch 2.8.0.dev20250625
  • 8xA100 GPUs
  • The same configuration causing the exact DTensor/Tensor mismatch

Key Finding: The error explicitly occurs at cross_entropy_loss.py:71:

```python
logits = F.linear(hidden_chunk, weight)  # DTensor × regular Tensor = Error
```

Analysis

The issue manifests only in multi-GPU FSDP scenarios where:

  • tok_embeddings in custom_sharded_layers → wrapped as DTensor
  • output NOT in custom_sharded_layers → remains regular Tensor
  • LinearCrossEntropyLoss uses output.weight for projection
  • Mixed tensor types cause the runtime error
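
One way to confirm (or refute) this analysis on a multi-GPU run would be a quick diagnostic after FSDP wrapping, along these lines. This is not part of the PR; model is whatever the recipe has just wrapped in _setup_model.

```python
# Diagnostic sketch: print whether the relevant parameters ended up as DTensors
# after sharding. Not part of the PR; run inside the recipe after FSDP wrapping.
from torch.distributed.tensor import DTensor  # public module in recent PyTorch releases


def report_sharding(model) -> None:
    for name, param in model.named_parameters():
        if name.startswith(("tok_embeddings", "output")):
            kind = "DTensor" if isinstance(param, DTensor) else "torch.Tensor"
            print(f"{name}: {kind}")
```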

Note on Multi-GPU Testing

I wasn't able to generate updated error logs from multi-GPU testing due to resource constraints (I only have access to a single GPU via Colab). The single-GPU test doesn't trigger the FSDP DTensor wrapping, so it runs without errors.

The most recent multi-GPU error logs I could find are from issue #2856, which show the error occurring at the same location in the code (cross_entropy_loss.py:71). While these logs are from PyTorch 2.8.0.dev20250625, the error mechanism appears unchanged based on:

  • The code structure in current main still has the same F.linear call
  • The FSDP wrapping logic with custom_sharded_layers works the same way
  • The loss function still extracts model.output.weight for projection

For next steps, would you like me to:

  1. Try to find a multi-GPU environment for definitive testing?
  2. Proceed with the validation in this PR as a defensive measure, given the evidence?

Please let me know what you're thinking re best path forward. Really appreciate your patience and feedback as I work through this :)

@nathan-az
Collaborator

Hey mate, thanks for checking in on this. I have access to hardware again so ran some tests.

I ran two configs - the LLaMA 3.1 8B LoRA config and the LLaMA 3.2 1B LoRA config, both with custom_sharded_layers: ['tok_embeddings'].

I was not able to replicate the issue with 8B, but I was with 1B. I believe the key difference in architecture between these two is that the 1B model uses tied weights between the token embeddings and final output projection. I think this is likely why we see issues when one is sharded and the other is not.

If this is correct, I think adding a layer of validation would be useful, but that it should factor this in.

@jscaldwell55
Author

jscaldwell55 commented Aug 15, 2025

@nathan-az Awesome, thanks for running those tests!

This definitely makes sense to me - models with tied embeddings (where tok_embeddings.weight is the same tensor as output.weight) would have this issue when only one is in custom_sharded_layers. The 1B model uses weight tying for efficiency, while the 8B model has separate weights.

I'll update the validation to be more precise:

  1. Check if the model uses tied weights (when model.tok_embeddings.weight is model.output.weight)
  2. Only enforce the validation for tied-weight models using LinearCrossEntropyLoss
  3. Update the error message to explain the tied weights issue

This way we can avoid unnecessary restrictions on models that don't have this architectural constraint.
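
Sketched out, the tied-weights-aware version of the check might look something like the following. This is just a placeholder shape for the logic; in torchtune, tied models may wrap the projection in a TiedLinear, so a real implementation may need to look at the wrapper's tied module rather than a weight attribute on model.output.

```python
# Placeholder sketch of the tied-weights-aware validation; names are illustrative.
from torchtune.modules.loss import LinearCrossEntropyLoss


def _uses_tied_weights(model) -> bool:
    # The check proposed in step 1 above. Note: if the output projection is a
    # TiedLinear-style wrapper, the weight may live at model.output.tied_module.weight.
    out_weight = getattr(getattr(model, "output", None), "weight", None)
    emb_weight = getattr(getattr(model, "tok_embeddings", None), "weight", None)
    return out_weight is not None and out_weight is emb_weight


def validate_sharding_for_tied_weights(model, loss_fn, custom_sharded_layers) -> None:
    if not isinstance(loss_fn, LinearCrossEntropyLoss) or not custom_sharded_layers:
        return
    if _uses_tied_weights(model) and "output" not in custom_sharded_layers:
        raise ValueError(
            "This model ties tok_embeddings.weight and output.weight. When using "
            "LinearCrossEntropyLoss with custom_sharded_layers, include 'output' "
            "(e.g. ['tok_embeddings', 'output']) so both views of the tied weight "
            "are sharded consistently."
        )
```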

Quick question: Should I also check for the reverse case (where output is in custom_sharded_layers but tok_embeddings isn't)? Or would FSDP handle that differently?

@nathan-az
Collaborator

No worries! So - one thing you reported that I haven't been able to replicate: for me, custom sharding of both layers still throws the type mismatch error when the weights are tied.

I'm not an expert in how FSDP (and FSDP2) work, but I see that the TiedLinear class's forward directly accesses the weight attribute from the underlying embedding via return self.linear(x, self.tied_module.weight). My best theory is that since the weight is being accessed directly, the FSDP hook to unshard is not running (i.e. by the time the matmul runs, the weight should already have been gathered into a plain Tensor, but it is still a DTensor).
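
A stripped-down version of that access pattern, just to make the theory concrete (this is not torchtune's actual TiedLinear, only the shape of the call being described):

```python
# Not torchtune's TiedLinear; a minimal stand-in for the access pattern described above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedOutputSketch(nn.Module):
    def __init__(self, tied_module: nn.Embedding):
        super().__init__()
        self.tied_module = tied_module  # the token embedding whose weight is reused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The weight is read as a plain attribute here, so the embedding's own FSDP
        # pre-forward hook (which would all-gather its sharded weight) never fires
        # for this call, and the matmul can see a still-sharded DTensor.
        return F.linear(x, self.tied_module.weight)
```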

@ebsmothers @joecummings sorry for the direct mention - if anything is immediately obvious to either of you (e.g. does my theory have merit, and is there an obvious/clean fix?), that would be great. No pressure to look deeper, I just don't have that much experience with how FSDP does sharding.

IMO - if sharding tied weights doesn't currently work, it's not a huge deal and we can just add validation to confirm. By default, the transformer layers are still sharded when FSDP is used (this can be confirmed in the sharding code), and because these weights are tied, they consume half as much memory as they otherwise would.

@jscaldwell55
Author

@nathan-az I'll hold off on implementing anything until we hear from the other reviewers. Happy to go with whatever makes the most sense to y'all.

Really appreciate you digging into this with me; I learn so much getting into these distributed training edge cases.
