
@acutkosky

Smoothed Muon Updates (1% improvement, includes PR 124 changes, runs PR 124 and 128 as parallel baselines)

This submission adds a small EMA filter to the Muon update:

original_muon_update = NS(EMA(gradients))
final_update = EMA(original_muon_update)
param = param - lr * final_update

Various notions of update smoothing are common in the optimization literature; this submission tests the most minimal version I could think of, one that can easily be "tuned" to be close to the no-smoothing baseline.

The EMA weight is rather small, starting at 0.5 and decaying to 0.2 over training.
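
For concreteness, here is a minimal PyTorch sketch of the step described above. The Newton-Schulz coefficients are the standard Muon ones, but the single-tensor structure, the momentum handling, and the linear schedule for the EMA weight are my own reading of the description, not the PR's actual code:

import torch

def newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration used by Muon to approximately
    # orthogonalize the (2D) momentum matrix; bf16 and other speed tricks
    # are omitted here for clarity.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

@torch.no_grad()
def smoothed_muon_step(param, grad, momentum_buf, update_buf,
                       lr, step, total_steps, beta=0.95):
    # original Muon update: Newton-Schulz applied to the gradient EMA (momentum)
    momentum_buf.lerp_(grad, 1 - beta)
    raw_update = newton_schulz(momentum_buf)
    # extra EMA filter on the update itself; the weight on the new update
    # decays linearly from 0.5 to 0.2 over training (assumed schedule shape)
    ema_weight = 0.5 + (0.2 - 0.5) * (step / total_steps)
    update_buf.lerp_(raw_update, ema_weight)
    param.add_(update_buf, alpha=-lr)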

Other changes:

  • Decay the lr to 0.01 * peak lr, rather than to 0.0 (a sketch follows this list).
  • Increase muon lr to 0.03.
  • Decrease iterations to 5610.
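
To illustrate the first bullet, a hypothetical cooldown multiplier that ends at 0.01 of peak lr instead of 0 (the actual schedule shape and cooldown fraction in the code may differ):

def lr_multiplier(step: int, total_steps: int,
                  cooldown_frac: float = 0.4, floor: float = 0.01) -> float:
    # hypothetical schedule: flat at peak lr, then a linear cooldown that
    # ends at floor * peak instead of decaying all the way to 0
    x = step / total_steps
    if x < 1 - cooldown_frac:
        return 1.0
    w = (1 - x) / cooldown_frac  # goes 1 -> 0 over the cooldown phase
    return floor + (1 - floor) * w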

See the README.md file for more detailed info and stats.

Overall, the final time on my machine is 23.52 min / 23.72 min = 99.1% of PR124's time, and marginally faster than PR128's time (23.52 min vs 23.59 min). However, PR128's validation loss was not sufficiently low in my replication (see stats below).

Note that the times on my machine are slightly slower than those reported by the previous baselines.
I don't know why this is, but my final reported time is slightly faster than the reported baseline time, so I believe the speed improvement should be robust.

Simple stats:

I ran 80 trials for the baselines and ablations. For the update smoothing change, I ran four sets of only 40 runs each to check that the p-value has a reasonable chance of being small after a moderate number of runs.
That said, while the p-value has a good chance of being <0.01 after 40 trials, it is not an extremely high chance; I think it is quite likely to be small after 80 runs. I don't know what the standards for making reproducibility easy should be here. I did test a run with 5630 iters, which I think will be extremely reliable in this regard.
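
For reference, the numbers in the blocks below can be reproduced from a list of per-run val losses with something along these lines (my guess at the analysis: a one-sided one-sample t-test against 2.92 and a Student-t confidence interval; the author's actual script may differ):

import numpy as np
from scipy import stats

def summarize(val_losses, target=2.92, conf=0.99):
    x = np.asarray(val_losses, dtype=float)
    n, mean, sd = len(x), x.mean(), x.std(ddof=1)
    # one-sided one-sample t-test: small p supports "mean val loss < target"
    p = stats.ttest_1samp(x, target, alternative="less").pvalue
    half = stats.t.ppf(0.5 + conf / 2, n - 1) * sd / np.sqrt(n)
    print(f"mean: {mean:.6f}  std: {sd:.6f}")
    print(f"{int(conf * 100)}% confidence interval: ({mean - half:.6f} - {mean + half:.6f})")
    print(f"t-test p={p:.6f} (small means <{target})")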

Baselines:

PR124:

Processed 80 files.
val loss:
mean: 	2.919458
std:  	0.000737
val loss 99% confidence interval: (2.919241 - 2.919676)
val_loss t-test p=0.000000 (small means <2.92)
train time (minutes): mean=23.7152, std=0.1746
train time 99% confidence interval: (23.6637 - 23.7667)

PR128:

Processed 80 files.
val loss:
mean: 	2.919946
std:  	0.000839
val loss 99% confidence interval: (2.919698 - 2.920193)
val_loss t-test p=0.282210 (small means <2.92)
train time (minutes): mean=23.5897, std=0.1843
train time 99% confidence interval: (23.5353 - 23.6441)

This p-value is pretty high. I'm not sure what's wrong and I haven't investigated. The method itself seems reasonable and somewhat similar to the one proposed here.

Update smoothing data

With EMA update smoothing (4 replicates of 40 runs each, to ensure the p-value has a reasonable chance of being small):

Replicate 1:
Processed 40 files.
val loss:
mean: 	2.919700
std:  	0.000786
val loss 99% confidence interval: (2.919364 - 2.920036)
val_loss t-test p=0.010251 (small means <2.92)
train time (minutes): mean=23.4630, std=0.1956
train time 99% confidence interval: (23.3793 - 23.5466)

Replicate 2:
Processed 40 files.
val loss:
mean: 	2.919769
std:  	0.000719
val loss 99% confidence interval: (2.919462 - 2.920076)
val_loss t-test p=0.024538 (small means <2.92)
train time (minutes): mean=23.5620, std=0.1834
train time 99% confidence interval: (23.4835 - 23.6404)

Replicate 3:
Processed 40 files.
val loss:
mean: 	2.919879
std:  	0.000684
val loss 99% confidence interval: (2.919587 - 2.920172)
val_loss t-test p=0.135301 (small means <2.92)
train time (minutes): mean=23.5111, std=0.1905
train time 99% confidence interval: (23.4296 - 23.5925)

Replicate 4:
Processed 40 files.
val loss:
mean: 	2.919653
std:  	0.000880
val loss 99% confidence interval: (2.919277 - 2.920030)
val_loss t-test p=0.008553 (small means <2.92)
train time (minutes): mean=23.5281, std=0.1638
train time 99% confidence interval: (23.4580 - 23.5981)

So, one replicate has a high p-value (0.135), two are very close to 0.01 (0.0103 and 0.0086), and one has a moderate value (0.0245). If we group the runs into two sets of 80, the p-values are 0.0125 and 0.00025.

Full stats over all 160 runs (4 replicates pooled)

Processed 160 files.
val loss:
mean: 	2.919750
std:  	0.000768
val loss 99% confidence interval: (2.919592 - 2.919909)
val_loss t-test p=0.000032 (small means <2.92)
train time (minutes): mean=23.5160, std=0.1855
train time 99% confidence interval: (23.4778 - 23.5542)

Slightly longer run for a more reliable p-value:

To get a more reliable p-value, I increased the iterations to 5630 (60 fewer than PR124). I also restored the muon lr to 0.025. After 40 runs, this yields:

Processed 40 files.
val loss:
mean: 	2.919379
std:  	0.000693
val loss 99% confidence interval: (2.919083 - 2.919675)
val_loss t-test p=0.000001 (small means <2.92)
train time (minutes): mean=23.5919, std=0.1811
train time 99% confidence interval: (23.5144 - 23.6693)

So, a slightly slower run, but with much higher confidence.

Simple Ablation

increase iters to 5940, remove update smoothing, keep other changes the same:

Processed 80 files.
val loss:
mean: 	2.920077
std:  	0.000808
val loss 99% confidence interval: (2.919839 - 2.920316)
val_loss t-test p=0.802983 (small means <2.92)
train time (minutes): mean=23.60, std=0.19
train time 99% confidence interval: (23.55 - 23.66)

So, this seems a little slower and doesn't hit the baseline loss. Not necessarily conclusive (better lr tuning might fix it), but it is at least suggestive that the smoothing is helpful.

PyTorch/CUDA info

As copied from the output file:

Running Python 3.13.5 (main, Jul 23 2025, 00:37:22) [Clang 20.1.4 ]
Running PyTorch 2.8.0+cu128 compiled for CUDA 12.8
Tue Sep 16 02:32:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |

@YouJiacheng
Contributor

Oh cool.
EMA on updates is similar to EMA on weights: it has an LR-decay-like effect (see https://arxiv.org/abs/2507.17634).
From the perspective of the last update, the effective lr of the previous steps is: 0.2, 0.36, 0.488, ..., → 1.
In that sense, maybe the cooldown fraction is worth tuning (it can be smaller).
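
(A quick check of that sequence, under the assumption of a fixed final EMA weight of 0.2: the cumulative weight an update has received after k further EMA steps is 1 - (1 - beta)^(k + 1).)

beta = 0.2
print([round(1 - (1 - beta) ** (k + 1), 3) for k in range(5)])
# -> [0.2, 0.36, 0.488, 0.59, 0.672], approaching 1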

Simple Ablation
increase iters to 5940, remove update smoothing, keep other changes the same
So, seems a little slower and doesn't hit the baseline. Not necessarily conclusive (better lr tuning might fix it), but at least this is suggestive that smoothing is helpful.

but the baseline (PR 124) is 5690?

@acutkosky
Author

acutkosky commented Sep 22, 2025

Oh cool. EMA on updates is similar to EMA on weights: it has an LR-decay-like effect (see https://arxiv.org/abs/2507.17634). From the perspective of the last update, the effective lr of the previous steps is: 0.2, 0.36, 0.488, ..., → 1. In that sense, maybe the cooldown fraction is worth tuning (it can be smaller).

Yeah, it's definitely a possibility. I can try a few more ablations.

but the baseline (PR 124) is 5690?

Oops, I meant to write 5640 - basically I just added a few extra iterations to roughly make up for the EMA's increased time-per-step.

