
Commit 39020ea

modify code .
1 parent 69e0f48

2 files changed: +7 -7 lines changed

keras/src/backend/tensorflow/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
         def _prepare_var(v):
             new_v = v.value if isinstance(v, backend.Variable) else v
             new_v._muon_use_adam_flag = v._muon_use_adam_flag
-            new_v.muon_path_id = v.muon_path_id
+            new_v._muon_path_id = v._muon_path_id
             return new_v
 
         trainable_variables = [_prepare_var(v) for v in trainable_variables]
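This keeps the TensorFlow backend in step with the rename in muon.py below: both the flag and the path id are now read through private (underscore-prefixed) names. For context, a minimal standalone sketch of the attribute-copy pattern, using plain-Python stand-ins for backend.Variable and its underlying value (the stand-in class names are illustrative, not Keras API):

class RawValue:
    """Stand-in for the underlying tf.Variable handle."""
    pass

class WrappedVariable:
    """Stand-in for keras.src.backend.Variable: wraps a raw value."""
    def __init__(self, path_id, use_adam):
        self.value = RawValue()
        self._muon_path_id = path_id        # set in Muon.build()
        self._muon_use_adam_flag = use_adam

def _prepare_var(v):
    # Copy the Muon bookkeeping attributes onto the raw value so the
    # backend update step can still look them up after unwrapping.
    new_v = v.value if isinstance(v, WrappedVariable) else v
    new_v._muon_use_adam_flag = v._muon_use_adam_flag
    new_v._muon_path_id = v._muon_path_id  # the renamed private attribute
    return new_v

var = WrappedVariable(path_id="dense/kernel", use_adam=False)
raw = _prepare_var(var)
assert raw._muon_path_id == "dense/kernel"

The underscore prefix marks these attributes as internal optimizer bookkeeping rather than public variable API.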

keras/src/optimizers/muon.py

Lines changed: 6 additions & 6 deletions
@@ -171,16 +171,16 @@ def build(self, var_list):
         self.muon_velocities = {}
 
         for var in var_list:
-            var.muon_path_id = self._var_key(var)
+            var._muon_path_id = self._var_key(var)
             if not self._overwrite_variable_with_gradient(var):
-                self.adam_momentums[var.muon_path_id] = (
+                self.adam_momentums[var._muon_path_id] = (
                     self.add_variable_from_reference(
                         reference_variable=var, name="momentum"
                     )
                 )
                 var._muon_use_adam_flag = self._should_use_adamw(var)
                 if var._muon_use_adam_flag:
-                    self.adam_velocities[var.muon_path_id] = (
+                    self.adam_velocities[var._muon_path_id] = (
                         self.add_variable_from_reference(
                             reference_variable=var, name="velocity"
                         )
@@ -196,7 +196,7 @@ def update_step(self, gradient, variable, learning_rate):
         self._muon_update_step(gradient, variable, learning_rate)
 
     def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.muon_path_id]
+        m = self.adam_momentums[variable._muon_path_id]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
         if self.nesterov:
             g = ops.add(gradient, self.momentum * m)
@@ -221,8 +221,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
             ops.cast(self.adam_beta_2, variable.dtype), local_step
        )
 
-        m = self.adam_momentums[variable.muon_path_id]
-        v = self.adam_velocities[variable.muon_path_id]
+        m = self.adam_momentums[variable._muon_path_id]
+        v = self.adam_velocities[variable._muon_path_id]
 
         alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
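The commit's core change: _muon_path_id (previously muon_path_id) is the stable key under which build() registers each variable's slot state, and which _muon_update_step and _adamw_update_step later use to look it up. A minimal sketch of that keying pattern, assuming a simplified _var_key and plain dicts in place of the optimizer's slot machinery (all names below are illustrative stand-ins, not the Keras implementation):

class FakeVariable:
    def __init__(self, path):
        self.path = path

def _var_key(var):
    # Simplified: Keras derives a unique key per variable; here we use
    # the variable's path string directly.
    return var.path

adam_momentums = {}

def build(var_list):
    # Mirror of the pattern in Muon.build(): stamp each variable with a
    # private key, then register its slot state under that key.
    for var in var_list:
        var._muon_path_id = _var_key(var)
        adam_momentums[var._muon_path_id] = 0.0  # slot placeholder

def muon_update_step(var, gradient, momentum=0.95):
    # Mirror of _muon_update_step's lookup: m = adam_momentums[key],
    # then m += gradient + m * (momentum - 1).
    m = adam_momentums[var._muon_path_id]
    m = m + gradient + m * (momentum - 1)
    adam_momentums[var._muon_path_id] = m
    return m

v = FakeVariable("dense/kernel")
build([v])
print(muon_update_step(v, gradient=0.1))  # 0.1 on the first step

Renaming the attribute at both the write site (build) and every read site keeps the lookups consistent; a partial rename would raise AttributeError in the update steps.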