@@ -171,16 +171,16 @@ def build(self, var_list):
171171 self .muon_velocities = {}
172172
173173 for var in var_list :
174- var .muon_path_id = self ._var_key (var )
174+ var ._muon_path_id = self ._var_key (var )
175175 if not self ._overwrite_variable_with_gradient (var ):
176- self .adam_momentums [var .muon_path_id ] = (
176+ self .adam_momentums [var ._muon_path_id ] = (
177177 self .add_variable_from_reference (
178178 reference_variable = var , name = "momentum"
179179 )
180180 )
181181 var ._muon_use_adam_flag = self ._should_use_adamw (var )
182182 if var ._muon_use_adam_flag :
183- self .adam_velocities [var .muon_path_id ] = (
183+ self .adam_velocities [var ._muon_path_id ] = (
184184 self .add_variable_from_reference (
185185 reference_variable = var , name = "velocity"
186186 )
@@ -196,7 +196,7 @@ def update_step(self, gradient, variable, learning_rate):
196196 self ._muon_update_step (gradient , variable , learning_rate )
197197
198198 def _muon_update_step (self , gradient , variable , lr ):
199- m = self .adam_momentums [variable .muon_path_id ]
199+ m = self .adam_momentums [variable ._muon_path_id ]
200200 self .assign_add (m , ops .add (gradient , m * (self .momentum - 1 )))
201201 if self .nesterov :
202202 g = ops .add (gradient , self .momentum * m )
@@ -221,8 +221,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
221221 ops .cast (self .adam_beta_2 , variable .dtype ), local_step
222222 )
223223
224- m = self .adam_momentums [variable .muon_path_id ]
225- v = self .adam_velocities [variable .muon_path_id ]
224+ m = self .adam_momentums [variable ._muon_path_id ]
225+ v = self .adam_velocities [variable ._muon_path_id ]
226226
227227 alpha = lr * ops .sqrt (1 - adam_beta_2_power ) / (1 - adam_beta_1_power )
228228
0 commit comments