[HF] Llama4 Text State Dict Adapter #1662
Conversation
If the mean value is small, the 20.54% shift can be very small as well.
Is the first transformer layer a dense layer or an MoE layer? If possible, you could dump the intermediate states within a transformer layer, e.g. after the attention module and before the MoE module, to check which submodule caused the discrepancy.
Nice work!
}

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
    to_hf_map = {v: k for k, v in self.from_hf_map.items()}
qq: there are several items with value = None in self.from_hf_map. If we reverse the mapping here, what does to_hf_map look like?
Yes, when we reverse the mapping, any key/val pair containing a None value is dropped. I mainly keep these mappings here to denote which layers may be unmapped between HF and torchtitan, but I can remove the vision layers since they won't be used.
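For illustration, a minimal sketch (my own, not the PR's exact code) of one way the reversal could drop the `None` placeholders described above; the helper name is hypothetical:

```python
from typing import Any

# Hypothetical sketch: invert the HF -> torchtitan fqn mapping, dropping entries
# whose torchtitan side is None (i.e. HF layers with no torchtitan counterpart).
def build_to_hf_map(from_hf_map: dict[str, Any]) -> dict[str, Any]:
    return {tt_key: hf_key for hf_key, tt_key in from_hf_map.items() if tt_key is not None}
```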
if key in to_hf_map:
    # do direct mapping
    if key in "layers.{}.moe.experts.w2":
Could you add a note here on why we need the transpose() to get the correct shape?
Some context: in torchtitan we define experts as [num_experts, output_dim, input_dim] for efficiency concerns.
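A hedged shape sketch of why the direct w2 mapping needs a transpose, assuming the HF expert weights are stored as [num_experts, input_dim, output_dim] (laid out for a batched `x @ W`); sizes are toy values:

```python
import torch

# Assumed HF layout for the expert down projection: [num_experts, in_dim, out_dim].
# torchtitan expects [num_experts, out_dim, in_dim], so converting w2 swaps the last two dims.
hf_down_proj = torch.randn(4, 512, 256)
tt_w2 = hf_down_proj.transpose(-1, -2).contiguous()  # -> [4, 256, 512]
print(tt_w2.shape)
```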
# handle splitting values
split_vals = value.chunk(2, dim=-1)
split_vals = [val.transpose(-1, -2) for val in split_vals]
for new_key, split_val in zip(self.combination_plan[key], split_vals):
So the first half is w1 and the second half is w3? I was thinking we could even remove self.combination_plan and directly hardcode the split/concat and the corresponding torchtitan fqns in from_hf(), since we only have 2 titan fqns to split, not that many. Right now we still need to hardcode the abstract_key (e.g., "language_model.model.layers.{}.feed_forward.experts.gate_up_proj") anyway.
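A hedged sketch of the hardcoded split being suggested, assuming (as asked above) that the first half of gate_up_proj is the gate (w1) and the second half the up projection (w3); the function name and titan fqns are illustrative:

```python
import torch

# Hypothetical handling of the fused HF expert weight:
# gate_up_proj: [num_experts, in_dim, 2 * hidden_dim] -> w1, w3: [num_experts, hidden_dim, in_dim]
def split_gate_up(value: torch.Tensor, layer_id: int) -> dict[str, torch.Tensor]:
    gate, up = value.chunk(2, dim=-1)  # assumed order: gate first, up second
    return {
        f"layers.{layer_id}.moe.experts.w1": gate.transpose(-1, -2),
        f"layers.{layer_id}.moe.experts.w3": up.transpose(-1, -2),
    }
```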
Yes, I agree that since the mean value is small (~1e-5), the actual shift is also small in magnitude. However, my argument is that we have shifted the probability distribution of our tensor by a significant proportion, 20.54%. The tensor that diverges is the input to the experts layer, which plays a major role in the final distribution. If every value that we multiply in the experts layer differs by 20% on average, then I think we can reason that this difference will propagate.
I used the 17Bx16E model as the base, which has an MoE layer in every transformer layer. Also, yes, I dumped the intermediate states as you suggested after many of the important multiplication operations within the MoE layer. What I found is that even when I copy the hidden state after the attention layer, the two hidden states diverge after only 2 floating point operations: the MoE router FFN, and the elementwise multiplication of the router scores with the original MoE hidden state. I hypothesize the difference is due to how the router scores (range 0-1) and hidden states (mean value -2.9585115044028498e-05) multiply together and cause more imprecision in bfloat16 than in float32.
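A toy illustration (not the actual model code) of how such a dtype-induced shift can be measured: multiply small-magnitude hidden states by router scores in bfloat16 and float32 and compare the means. The sizes and magnitudes are made up; the size of the shift depends heavily on the real tensors, in particular on how close their mean is to zero.

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(4096, 5120) * 3e-5   # small-magnitude activations, toy stand-in
scores = torch.rand(4096, 1)              # router scores in [0, 1]

out_fp32 = hidden * scores
out_bf16 = (hidden.bfloat16() * scores.bfloat16()).float()

# Relative shift of the mean introduced purely by bfloat16 rounding; when the mean is
# tiny compared to the per-element magnitude, this ratio can look large.
rel_shift = (out_bf16.mean() - out_fp32.mean()).abs() / out_fp32.mean().abs()
print(f"relative mean shift: {rel_shift.item():.2%}")
```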
Thanks for the detailed reply!
I agree that numerics could shift due to precision differences. If the answers to the previous 2 questions are yes, these 2 parts are not simply 2 floating point operations; they calculate the scores of each expert for each token. Small early numeric differences (e.g. …
Sorry, yes, I meant the nn.Linear gate in the router. Also, in order to address the potentially different routing, I did:
I think after these changes, the remaining divergence reduces to numeric instability. To clarify the semantics, this is the model architecture of the portion I'm comparing:
Thanks for the hard work! No big concern from me. I'll let @wwwjn stamp.
@wwwjn I addressed the comments; if you can take another look, thanks!
Very solid work with solid verification! Thank you for making this happen!
Llama4 State Dict Adapter
To verify the correctness of my conversion, I truncated the Llama4 model by loading the HF model, popping all but the first transformer layer, modifying the config file, saving it, and modifying my `torchtitan` model_args in order to perform a forward pass this way as well. My KL divergence results using `top_k=16`, disabling HF `qk_norm`, and disabling TT `load_balance_coeff` were as follows:

This KL divergence loss is not low enough to absolutely guarantee correctness, so I dissected the model's hidden states within the transformer block to compare the corresponding intermediate states. However, it falls within a range that could plausibly be caused by numerical imprecision. I previously tested converting HF->TT->HF on a 2D-sharded state dict and found a similar loss, but the sanity check (greedy decoding) still matched perfectly.
Although I wasn't able to find a configuration or transformation that reduces this KL divergence further, I found that even when properly aligning the hidden states (saving the hidden state at the beginning of TT's MoE input and copying it to HF's MoE input during its forward pass), after only a single identical FFN layer plus an elementwise multiplication, the mean of the hidden state shifts by 20.54%. Since the input had been exactly aligned and these operations properly isolated, this leads me to believe that the difference must be due to numerical instability or imprecision from the TorchTitan and HuggingFace dtypes.
Additionally, the TorchTitan and HuggingFace Llama4 implementations are semantically identical. The only difference is that TorchTitan uses a sparsely represented input to calculate its MoE forward pass efficiently, whereas HuggingFace uses a dense input plus a mask.
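For reference, a sketch of the comparison metric described above, assuming both models' logits are available as `[batch, seq, vocab]` tensors; the variable names are illustrative, not the exact test script:

```python
import torch
import torch.nn.functional as F

def kl_divergence(hf_logits: torch.Tensor, tt_logits: torch.Tensor) -> torch.Tensor:
    """KL(HF || TT) between the per-token output distributions of the two models."""
    hf_log_probs = F.log_softmax(hf_logits.float(), dim=-1)
    tt_log_probs = F.log_softmax(tt_logits.float(), dim=-1)
    # F.kl_div takes log-probabilities as input; with log_target=True the target is
    # also log-probabilities. 'batchmean' sums over all elements and divides by batch size.
    return F.kl_div(tt_log_probs, hf_log_probs, log_target=True, reduction="batchmean")

# Hypothetical usage: logits from forward passes of the truncated HF model and the
# torchtitan model on the same token ids.
# kl = kl_divergence(hf_out.logits, tt_logits)
```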