Conversation

@wesleytruong (Contributor) commented on Aug 29, 2025

Llama4 State Dict Adapter

To verify the correctness of my conversion, I truncated the Llama4 model by loading the HF model, popping all but the first transformer layer, modifying the config file, saving it, and then modifying my torchtitan model_args so I could perform a forward pass on the truncated model as well. My KL divergence results using top_k=16, disabling HF qk_norm, and disabling TT load_balance_coeff were as follows:

Loss for test from_hf:
Mean: 1.2682261285590357e-07
Std: 2.1823700535605894e-07
Min: 1.1850400218407775e-13
Max: 1.2106198710171157e-06
Median: 4.4414797173431e-08

This KL divergence is not low enough to absolutely guarantee correctness, so I dissected the model's hidden states within the transformer block to compare corresponding intermediate values. However, it falls within a tolerable range that could be explained by numerical imprecision. I previously tested the HF->TT->HF conversion on a 2D-sharded state dict and found a similar loss, yet the sanity check (greedy decoding) still matched perfectly.
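For reference, this is roughly the kind of comparison involved (a minimal sketch; the logits tensors and shapes are placeholders, not the actual evaluation script):

```python
import torch
import torch.nn.functional as F

# Placeholder logits from one forward pass of the truncated HF model and of the
# torchtitan model on the same batch; real shapes would be [batch, seq_len, vocab].
hf_logits = torch.randn(2, 16, 1024)
tt_logits = hf_logits + 1e-4 * torch.randn_like(hf_logits)

# Per-token KL(HF || TT) over the vocabulary dimension, computed in float32.
hf_logprobs = F.log_softmax(hf_logits.float(), dim=-1)
tt_logprobs = F.log_softmax(tt_logits.float(), dim=-1)
kl = F.kl_div(tt_logprobs, hf_logprobs, log_target=True, reduction="none").sum(dim=-1)

print(f"Mean: {kl.mean().item():.3e}  Std: {kl.std().item():.3e}")
print(f"Min: {kl.min().item():.3e}  Max: {kl.max().item():.3e}  Median: {kl.median().item():.3e}")
```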

Although I wasn't able to find a configuration or transformation that reduces this KL divergence further, I found the following: even when I properly align the hidden states, by saving the hidden state at the input of TT's MoE and copying it into HF's MoE input during its forward pass, the mean of the hidden state shifts by 20.54% after only a single identical FFN layer plus an elementwise multiplication. Since the input had been exactly aligned and these operations were properly isolated, this leads me to believe the shift must be due to numerical instability or imprecision stemming from the TorchTitan and HuggingFace dtypes.
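A sketch of how this kind of hidden-state alignment can be done with forward pre-hooks; the module paths tt_model.layers[0].moe and hf_model.model.layers[0].feed_forward are illustrative guesses, not the exact attribute names in either codebase:

```python
# tt_model / hf_model are the already-loaded torchtitan and HF models; the
# submodule paths below are illustrative and depend on the actual model classes.
captured = {}

def tt_save_moe_input(module, args):
    # Save the tensor entering torchtitan's MoE block.
    captured["moe_input"] = args[0].detach().clone()

def hf_replace_moe_input(module, args):
    # Returning a new args tuple from a forward pre-hook replaces the module input,
    # so the HF MoE block consumes exactly the tensor captured from torchtitan.
    return (captured["moe_input"].to(args[0].dtype),) + args[1:]

tt_handle = tt_model.layers[0].moe.register_forward_pre_hook(tt_save_moe_input)
hf_handle = hf_model.model.layers[0].feed_forward.register_forward_pre_hook(hf_replace_moe_input)
# Run the torchtitan forward pass first, then the HF forward pass.
```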

Additionally, the TorchTitan and HuggingFace Llama4 implementations are semantically identical. The only difference is that TorchTitan uses a sparsely represented input to compute its MoE forward pass efficiently, whereas HuggingFace uses a dense input plus a mask.
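To illustrate the difference (a toy sketch, not either codebase's actual implementation): a sparse-style dispatch gathers only the tokens routed to each expert, while a dense-style path runs every expert over all tokens and zeroes out unrouted combinations; both give the same result up to floating-point error.

```python
import torch

num_tokens, dim, num_experts, top_k = 8, 4, 3, 2
x = torch.randn(num_tokens, dim)
w = torch.randn(num_experts, dim, dim)            # one toy linear "expert" each
scores = torch.rand(num_tokens, num_experts)
top_scores, top_idx = scores.topk(top_k, dim=-1)  # [num_tokens, top_k]

# Sparse-style: gather only the tokens routed to each expert and compute those.
out_sparse = torch.zeros_like(x)
for e in range(num_experts):
    token_ids, slot = (top_idx == e).nonzero(as_tuple=True)
    weight = top_scores[token_ids, slot].unsqueeze(-1)       # [n_routed, 1]
    out_sparse[token_ids] += weight * (x[token_ids] @ w[e])

# Dense-style: run every expert over all tokens, weighting unrouted ones by 0.
dense_weights = torch.zeros(num_tokens, num_experts).scatter(1, top_idx, top_scores)
out_dense = torch.einsum("te,td,edh->th", dense_weights, x, w)

print(torch.allclose(out_sparse, out_dense, atol=1e-5))
```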

The meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Aug 29, 2025.
@wwwjn (Contributor) commented on Aug 29, 2025

> properly aligning the hidden states by saving the hidden state at the beginning of TT's MOE input and copying to HF's MOE input during its forward pass, after only a single identical ffn layer + elementwise multiplication, that the mean of the hidden state shifts by 20.54%

If the mean value is small, the 20.54% shift can be very small as well.

> popping all but the first transformer layer, modifying the config file, saving it, and modifying my torchtitan model_args in order to perform a forward pass this way as well.

Is the first transformer layer a dense layer or an MoE layer? If possible, you could dump the intermediate states within a transformer layer, e.g. after the attention module and before the MoE module, to check which submodule causes the discrepancy.

@wwwjn (Contributor) left a comment:

Nice work!

}

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
    to_hf_map = {v: k for k, v in self.from_hf_map.items()}
Contributor:

qq: there are several items with value = None in self.from_hf_map. If we reverse the mapping here, what does to_hf_map look like?

@wesleytruong (Contributor Author):

Yes, when we reverse the mapping, any key/value pair containing a None value is dropped. I mainly keep these mappings here to denote which layers may be unmapped between HF and torchtitan, but I can remove the vision layers since they won't be used.
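For reference, a minimal sketch of reversing such a mapping while dropping unmapped (None) entries; the FQNs in from_hf_map below are made-up examples, not the adapter's actual table:

```python
# Hypothetical subset of the HF -> torchtitan FQN mapping; None marks HF weights
# with no torchtitan counterpart (e.g. vision weights).
from_hf_map = {
    "language_model.model.embed_tokens.weight": "tok_embeddings.weight",
    "language_model.model.norm.weight": "norm.weight",
    "vision_model.patch_embedding.linear.weight": None,
}

# Reversing with a filter drops every pair whose value is None, so to_hf_map
# only contains weights that exist on both sides.
to_hf_map = {v: k for k, v in from_hf_map.items() if v is not None}
print(to_hf_map)
```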


if key in to_hf_map:
    # do direct mapping
    if key in "layers.{}.moe.experts.w2":
Contributor:

Could you add a note here on why we need transpose() to get the correct shape?

Contributor:

Some context: in torchtitan we define expert weights as [num_experts, output_dim, input_dim] for efficiency reasons.
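For illustration, a small sketch of the layout difference behind the transpose. The torchtitan [num_experts, output_dim, input_dim] convention is taken from the comment above; the HF layout shown is an assumption consistent with the adapter needing a transpose:

```python
import torch

num_experts, dim, hidden_dim = 16, 8, 32  # toy sizes

# Assumed HF layout for the experts' down projection: [num_experts, hidden_dim, dim],
# arranged so the forward pass can do hidden @ down_proj per expert.
hf_down_proj = torch.randn(num_experts, hidden_dim, dim)

# torchtitan stores expert weights as [num_experts, output_dim, input_dim], so the
# adapter transposes the last two dims when converting.
tt_w2 = hf_down_proj.transpose(-1, -2).contiguous()
print(hf_down_proj.shape, tt_w2.shape)  # (16, 32, 8) -> (16, 8, 32)
```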

# handle splitting values
split_vals = value.chunk(2, dim=-1)
split_vals = [val.transpose(-1, -2) for val in split_vals]
for new_key, split_val in zip(self.combination_plan[key], split_vals):
Contributor:

So the first half is w1 and the second half is w3? I was thinking we could even remove self.combination_plan and directly hardcode the split/concat and the corresponding torchtitan FQNs in from_hf(), since we only have 2 titan FQNs to split, which is not too many. We still need to hardcode the abstract_key either way (e.g. "language_model.model.layers.{}.feed_forward.experts.gate_up_proj").
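For illustration, a minimal sketch of the hardcoded split being suggested here, assuming the fused HF gate_up_proj concatenates the gate and up projections along the last dimension (first half -> w1, second half -> w3); the function name and FQNs are illustrative, not the PR's actual code:

```python
# Hypothetical handling of one fused HF tensor inside from_hf(); "value" would be the
# weight for "language_model.model.layers.{i}.feed_forward.experts.gate_up_proj".
def split_gate_up_proj(layer_idx: int, value):
    # Assumption: first half along the last dim is the gate projection (-> w1),
    # second half is the up projection (-> w3); transpose into torchtitan's
    # [num_experts, output_dim, input_dim] layout.
    gate, up = value.chunk(2, dim=-1)
    return {
        f"layers.{layer_idx}.moe.experts.w1": gate.transpose(-1, -2).contiguous(),
        f"layers.{layer_idx}.moe.experts.w3": up.transpose(-1, -2).contiguous(),
    }
```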

@wesleytruong (Contributor Author) commented on Aug 29, 2025

> properly aligning the hidden states by saving the hidden state at the beginning of TT's MOE input and copying to HF's MOE input during its forward pass, after only a single identical ffn layer + elementwise multiplication, that the mean of the hidden state shifts by 20.54%

> If the mean value is small, the 20.54% shift can be very small as well.

Yes, I agree that since the mean value is small (~1e-5), the actual shift is also small in magnitude. However, my argument is that we have shifted the distribution of the tensor's values by a significant proportion, 20.54%. The tensor that diverges is the input to the experts layer, which plays a major role in the final distribution. If every value we multiply in the experts layer differs by 20% on average, then I think we can reason that this difference will propagate.

> popping all but the first transformer layer, modifying the config file, saving it, and modifying my torchtitan model_args in order to perform a forward pass this way as well.

> Is the first transformer layer a dense layer or an MoE layer? If possible, you could dump the intermediate states within a transformer layer, e.g. after the attention module and before the MoE module, to check which submodule causes the discrepancy.

I used the 17Bx16E model as the base, which has an MoE layer in every transformer layer.

Also, yes, I dumped the intermediate states as you suggested, after many of the important multiplication operations within the MoE layer. What I found is that even when I copy the hidden state after the attention layer, the two hidden states diverge after only 2 floating-point operations: the MoE router FFN, and the elementwise multiplication of the router scores with the original MoE hidden state. I hypothesize the difference is due to how the router scores (range 0-1) and the hidden states (mean value -2.9585115044028498e-05) multiply together and cause more imprecision in bfloat16 than in float32.
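As a toy illustration of the bfloat16-vs-float32 effect being hypothesized (fabricated data, not the actual hidden states):

```python
import torch

torch.manual_seed(0)

# Hidden states with a small mean (order 1e-5) multiplied by router scores in [0, 1],
# once in float32 and once in bfloat16, compared against a float64 reference.
hidden = 1e-5 + 1e-3 * torch.randn(100_000)
scores = torch.rand(100_000)

ref = hidden.double() * scores.double()
fp32 = (hidden * scores).double()
bf16 = (hidden.bfloat16() * scores.bfloat16()).double()

denom = ref.abs().clamp_min(1e-12)
print("fp32 mean relative error:", ((fp32 - ref).abs() / denom).mean().item())
print("bf16 mean relative error:", ((bf16 - ref).abs() / denom).mean().item())
```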

@wwwjn (Contributor) commented on Aug 29, 2025

Thanks for the detailed reply!

> MOE router FFN

Do you mean the whole router, or self.gate in the router (which is a Linear layer)?

> and elementwise multiplication of router scores and original MOE hidden

Do you mean routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)?

I agree that numerics could shift due to precision differences. If the answer to the previous 2 questions is yes, these 2 parts are not simply 2 floating-point operations; they calculate the scores of each expert for each token. A small early numeric difference (e.g. in scores = self.gate(x), or in expert_bias) might result in a token being routed to different experts.
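To make the routing sensitivity concrete (a toy example with fabricated scores): when two experts are nearly tied, a perturbation far below bfloat16 resolution is enough to flip the top-k selection.

```python
import torch

# Two experts with nearly tied router scores for one token.
scores = torch.tensor([0.30, 0.250001, 0.249999, 0.20])
perturbed = scores + torch.tensor([0.0, -1e-5, 1e-5, 0.0])  # tiny numeric wobble

top_k = 2
print(scores.topk(top_k).indices)     # tensor([0, 1])
print(perturbed.topk(top_k).indices)  # tensor([0, 2]) -- the routing flipped
```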

@wesleytruong (Contributor Author) commented on Aug 29, 2025

> Thanks for the detailed reply!
>
> MOE router FFN
>
> Do you mean the whole router, or self.gate in the router (which is a Linear layer)?
>
> and elementwise multiplication of router scores and original MOE hidden
>
> Do you mean routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)?
>
> I agree that numerics could shift due to precision differences. If the answer to the previous 2 questions is yes, these 2 parts are not simply 2 floating-point operations; they calculate the scores of each expert for each token. A small early numeric difference (e.g. in scores = self.gate(x), or in expert_bias) might result in a token being routed to different experts.

Sorry, yes, I meant the nn.Linear gate in the router.

Also, in order to address the potentially different routing, I did the following:

  • Copied the hidden state x from TorchTitan to HF before the router computes scores = self.gate(x)
  • Compared the router_scores tensors after the HF and TT gates numerically and got:
    • HF: Mean: 0.3794191777706146, Std Dev: 0.0902581587433815, Min: 0.12090068310499191, Max: 0.6247883439064026, Shape: torch.Size([100, 16])
    • TT: Mean: 0.379419207572937, Std Dev: 0.0902581587433815, Min: 0.12090068310499191, Max: 0.6247883439064026, Shape: torch.Size([100, 16])
    • Numerically, I interpret this as the router_scores being well-aligned
    • This alignment is important not because of the actual router assignment, but because these values will be elementwise multiplied back with the MoE hidden state according to the router assignment
  • Disabled expert_bias in TorchTitan, since HuggingFace doesn't use it
  • Compared the unpermuted discrete selected_expert_indices returned by torch.topk() in both HF and TT and verified they are an exact match (see the sketch after this list)
    • This verifies that each token is routed to the same experts in both TT and HF
  • Used top_k=16
    • This allows us to compare the hidden state input to the experts layer in TT and HF, as the TT input is sparse and the HF input is dense + mask
    • If top_k != num_experts, the torchtitan hidden state at this point would look like (bs * top_k * seq_len) x dim while the HF hidden state would look like (bs * num_experts * seq_len) x dim

I think that after these changes, the remaining divergence comes down to numeric instability.
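For reference, a minimal sketch of the kind of index/score comparison described above; hf_router_scores and tt_router_scores are placeholders for the dumped tensors, not the actual values:

```python
import torch

# Placeholders for the router outputs dumped from the HF and torchtitan forward
# passes on the same (copied) hidden state; shape [num_tokens, num_experts].
hf_router_scores = torch.rand(100, 16)
tt_router_scores = hf_router_scores.clone()

top_k = 16
hf_top, hf_idx = hf_router_scores.topk(top_k, dim=-1)
tt_top, tt_idx = tt_router_scores.topk(top_k, dim=-1)

# Expert assignment must match exactly; scores only up to floating-point tolerance.
assert torch.equal(hf_idx.sort(dim=-1).values, tt_idx.sort(dim=-1).values)
print("max |score diff|:", (hf_top - tt_top).abs().max().item())
print(f"HF mean {hf_top.mean().item():.6f}  TT mean {tt_top.mean().item():.6f}")
```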

To clarify the semantics, this is the model architecture of the portion I'm comparing:

  • TransformerBlock
    • Attention
    • MOE layer # realigned hidden state at this input
      • Router (nn.Linear)
      • elementwise multiplication (router_scores, MOE input hidden_state)
      • Experts layer # compare input hidden state here
        • w1
        • w2
        • w3

@tianyu-l (Contributor) left a comment:

Thanks for the hard work! No big concern from me. I'll let @wwwjn stamp.



… hard code combination plan and splitting for readability
@wesleytruong (Contributor Author):

@wwwjn I addressed the comments; if you can take another look, thanks!

@wwwjn (Contributor) left a comment:

Very solid work with solid verification! Thank you for making this happen!

@wesleytruong merged commit 298bf48 into main on Aug 30, 2025
4 checks passed
@tianyu-l deleted the llama4_hf_conversion branch on August 31, 2025 at 01:47