New medium track WR: Second input embedding (1412 seconds); includes #119 #124
Conversation
This accomplishes a very cool dynamic. Basically you have gotten rid of the concept of a token embedding vector: instead of representing a token as a fixed point in d_model-dimensional space, you have added a degree of freedom where a token embedding is represented by interpolating between two points defined by embed1 and embed2, and the interpolation point varies by layer. Just a thought: the optimal interpolation point may also vary by token, so it could be worthwhile to extend lambdas[1] from size 1 to size vocab_size.
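A minimal sketch of what that extension could look like, with one learned scalar per token stored in an embedding table (names and sizes are placeholders, not the repo's actual code):

```python
import torch
import torch.nn as nn

vocab_size, B, T = 50304, 4, 64  # assumed sizes

# Instead of a single learned scalar shared by every token...
lambda_scalar = nn.Parameter(torch.zeros(1))

# ...one learned scalar per token, looked up like an embedding.
lambda_per_token = nn.Embedding(vocab_size, 1)
nn.init.zeros_(lambda_per_token.weight)

tokens = torch.randint(vocab_size, (B, T))
lam = lambda_per_token(tokens)  # shape (B, T, 1); broadcasts against (B, T, d_model)
```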
Thanks! Interesting thought. You mean replacing the scalar lambdas with a 1D embedding layer? I'll have to think about that a bit more; I might try it out next week, though (or you can, if you want to :)). About the dynamic that developed in at least some of my runs, I have looked into it a bit more. From what I can tell, x00 is interpreted by the lm head as next-token predictions (must be 1-gram statistics), even though at the input it's equal to x. The nice thing about that is that the input vector needs only minimal editing (x01 is suppressed in the early layers); the model only has to make slight corrections to x00 in order to make its prediction! I'll release a blog post on it soon!
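A rough way to probe that claim, using stand-ins for the model's pieces (embed1 and lm_head below are untrained placeholders, purely to show the shape of the check; with the trained weights, high agreement would support the reading above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model, B, T = 50304, 1024, 4, 64        # assumed sizes
embed1 = nn.Embedding(vocab_size, d_model)            # stand-in for the first input embedding
lm_head = nn.Linear(d_model, vocab_size, bias=False)  # stand-in for the output head
tokens = torch.randint(vocab_size, (B, T))

with torch.no_grad():
    x00 = embed1(tokens)                           # equals x at the input of the first layer
    logits = lm_head(F.rms_norm(x00, (d_model,)))  # read x00 directly through the lm head
    top1 = logits.argmax(-1)
    # If x00 already encodes next-token (1-gram-like) statistics, the top-1 prediction
    # at position t should often match the actual token at position t+1.
    agreement = (top1[:, :-1] == tokens[:, 1:]).float().mean().item()
    print(f"next-token top-1 agreement from x00 alone: {agreement:.3f}")
```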
Yes, with a 1D embedding layer. I tested both this PR and a simple lambda-vocab embedding on the short track with 8×H100s, and the runtime increase was too high to justify the addition. Strangely, the token-lambda embedding gave a 1% speedup on an A100 in Google Colab, but this didn't translate to the H100.
I'm not surprised about the runtime issue; I actually experimented with adding two, three, four, or five additional embedding layers. With each one, the per-step loss fell, but less and less, so that timing-wise only one additional embedding was worth it (and just barely). I imagine that with the smaller (and thus shallower) model, the embeddings cause higher relative overhead than in the medium track. The lambda-embeddings might have the same effect: very slight loss improvement, bigger time penalty. If I have the time, I'll still try it on the medium track just to see. One question I have is why you initialize all lambdas to 1. In my experiments, I noticed that it's really important to initialize all of them but the one for x to 0. Otherwise, the model will simply learn to route around the transformer layers and use embedding + lm head as the predictor (approximately). |
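To make the initialization point concrete, a tiny sketch of the two choices being discussed (three lambdas assumed: one for the running residual x and one for each input-embedding stream):

```python
import torch
import torch.nn as nn

# All streams active from the start (the initialization the question refers to).
lambdas_all_ones = nn.Parameter(torch.ones(3))

# Only the residual-x term active at init; the x00 / x01 contributions have to be
# learned, which avoids the "route around the transformer layers" shortcut.
lambdas_x_only = nn.Parameter(torch.tensor([1.0, 0.0, 0.0]))
```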
I kept the existing lambdas in the block class; only self.lambda_embeds was initialized to 1.

I was primarily proposing this as an extension of the additional-input-embedding idea you had because it has a high expression-to-compute ratio, and it can be applied to one of x00 or x01 to modulate the x00:x01 ratio at the token level. It also follows the same intuition of moving toward more token-level embeddings on the medium track, as you did here.
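Purely as an illustration of that per-token modulation (hypothetical names, not this PR's or the commenter's actual code), the looked-up scalar could gate just the x01 term of the block's input mix:

```python
import torch
import torch.nn as nn

vocab_size, d_model, B, T = 50304, 1024, 4, 64  # assumed sizes
lambda_embeds = nn.Embedding(vocab_size, 1)
nn.init.ones_(lambda_embeds.weight)             # initialized to 1, as described above
lambdas = torch.tensor([1.0, 0.0, 0.0])         # existing per-block scalars kept as-is

tokens = torch.randint(vocab_size, (B, T))
x, x00, x01 = (torch.randn(B, T, d_model) for _ in range(3))

lam = lambda_embeds(tokens)                     # (B, T, 1), one scalar per token
x = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * lam * x01
```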
Ah, that makes sense. |
Add additional input embedding
Includes PR #119.
Previously, modded-nanogpt medium added x0 to the residual at the input of every layer:

x = lambdas[0] * x + lambdas[1] * x0

where lambdas are learned scalars. This update adds a second embedding module, whose output x01 is added at every layer as well (with the first embedding's output now called x00):

x = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * x01
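A rough sketch of the block-level change, under the naming used in this thread (x00/x01 for the two embedding streams); the modules are placeholders, not the repo's exact code:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        # One scalar for the running residual x and one for each embedding stream;
        # only the x term is active at initialization.
        self.lambdas = nn.Parameter(torch.tensor([1.0, 0.0, 0.0]))
        self.attn = nn.Identity()  # placeholder for the real attention sub-block
        self.mlp = nn.Identity()   # placeholder for the real MLP sub-block

    def forward(self, x, x00, x01):
        x = self.lambdas[0] * x + self.lambdas[1] * x00 + self.lambdas[2] * x01
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x

block = Block()
x = x00 = x01 = torch.randn(4, 64, 1024)  # at the input, x equals x00
out = block(x, x00, x01)                  # same shape as x
```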
While this slows down each training step, it increases learning per step, allowing the step count to be reduced to 5690.
Here are the resulting final validation losses over 19 runs:
And these are the basic stats:
And t-test results:
{
  'n': 19,
  'sample_mean': 2.9194117368421053,
  'sample_std': 0.0006300479560324045,
  't_stat': -4.069816643193549,
  'p_value': 0.00035946114919240566,
  'alpha': 0.05,
  'decision': 'REJECT H0 (mean < threshold)',
  'upper_conf_bound_mean': 2.919662383449226,
  'threshold': 2.92
}

The final loss is below 2.92 with >99% likelihood.
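For reference, a small sketch of how a one-sided test against the 2.92 threshold like this can be computed with SciPy; the losses array below is synthetic placeholder data, not the actual 19 run results:

```python
import numpy as np
from scipy import stats

# Synthetic stand-in for the 19 per-run final validation losses reported above.
losses = np.random.default_rng(0).normal(loc=2.9194, scale=0.00063, size=19)
threshold, alpha = 2.92, 0.05

# One-sample, one-sided t-test: H0 mean >= threshold vs. H1 mean < threshold.
t_stat, p_value = stats.ttest_1samp(losses, popmean=threshold, alternative="less")

# One-sided (1 - alpha) upper confidence bound on the mean.
n = len(losses)
upper = losses.mean() + stats.t.ppf(1 - alpha, df=n - 1) * losses.std(ddof=1) / np.sqrt(n)

decision = "REJECT H0 (mean < threshold)" if p_value < alpha else "FAIL TO REJECT H0"
print(t_stat, p_value, upper, decision)
```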
Here are the corresponding run-times in seconds:
Leading to the following stats:
The mean time is ~1412.5 seconds, or 23.54 minutes.