
Conversation

@snimu (Contributor) commented Sep 11, 2025

Add additional input embedding

Includes PR#119.

Previously, modded-nanogpt medium added x0 to the residual at the input of every layer:

# GPT
def forward(self, input_seq, ...):
    ...

    x = x0 = norm(self.embed(input_seq[None]))

    ...

    for i in range(len(self.blocks)):
        ...
        x = self.blocks[i](x, x0, lambdas, ...)
        ...

# Block
def forward(self, x, x0, lambdas, ...):
    x = lambdas[0] * x + lambdas[1] * x0

Where lambdas are learned scalars.

This update adds another embedding module and adds it at every layer:

# GPT
def forward(self, input_seq, ...):
    ...

    x = x00 = norm(self.embed1(input_seq[None]))
    x01 = norm(self.embed2(input_seq[None]))

    ...

    for i in range(len(self.blocks)):
        ...
        x = self.blocks[i](x, x00, x01, lambdas, ...)
        ...

# Block
def forward(self, x, x00, x01, lambdas, ...):
    x = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * x01
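
For concreteness, here is a minimal, self-contained sketch of the scheme (not the actual modded-nanogpt code: the stand-in blocks, the norm helper, and the [1.0, 0.0, 0.0] lambda initialization are illustrative assumptions):

# Toy sketch of the dual-embedding mixing described above.
import torch
import torch.nn as nn

def norm(x):
    # RMS-style normalization; the real code uses its own norm helper
    return x / x.norm(dim=-1, keepdim=True).clamp_min(1e-6) * x.size(-1) ** 0.5

class DualEmbedGPT(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers):
        super().__init__()
        self.embed1 = nn.Embedding(vocab_size, d_model)
        self.embed2 = nn.Embedding(vocab_size, d_model)
        # one (lambda_x, lambda_x00, lambda_x01) triple per layer; only the weight on x starts nonzero
        self.lambdas = nn.Parameter(torch.tensor([[1.0, 0.0, 0.0]]).repeat(num_layers, 1))
        self.blocks = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])  # stand-in blocks

    def forward(self, input_seq):
        x = x00 = norm(self.embed1(input_seq))
        x01 = norm(self.embed2(input_seq))
        for i, block in enumerate(self.blocks):
            lam = self.lambdas[i]
            # re-inject both token embeddings into the residual stream at every layer
            x = lam[0] * x + lam[1] * x00 + lam[2] * x01
            x = x + block(norm(x))
        return x

Starting the x00/x01 weights at zero matches the initialization point discussed further down in this thread.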

While this slows down each step, it increases learning per step, allowing us to reduce the step count to 5690.

Here are the resulting final validation losses over 19 runs:

[2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827, 2.918785, 2.918519, 2.919297, 2.920061, 2.918938, 2.919342, 2.918186, 2.920546, 2.91954, 2.919093, 2.918951, 2.919599, 2.919956]

And these are the basic stats:

  • Mean: 2.9194117368421053
  • Median: 2.919342
  • Std: 0.000613243648848653
  • Min: 2.918186
  • Max: 2.920582

And t-test results:

{
    'n': 19,
    'sample_mean': 2.9194117368421053,
    'sample_std': 0.0006300479560324045,
    't_stat': -4.069816643193549,
    'p_value': 0.00035946114919240566,
    'alpha': 0.05,
    'decision': 'REJECT H0 (mean < threshold)',
    'upper_conf_bound_mean': 2.919662383449226,
    'threshold': 2.92
}

The mean final loss is below the 2.92 threshold with >99% likelihood.
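
For reference, a one-sided test like this can be reproduced with scipy (a sketch, not necessarily the exact script used to produce the numbers above):

# One-sided t-test of the mean loss against the 2.92 threshold.
import numpy as np
from scipy import stats

losses = np.array([2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827,
                   2.918785, 2.918519, 2.919297, 2.920061, 2.918938, 2.919342,
                   2.918186, 2.920546, 2.91954, 2.919093, 2.918951, 2.919599,
                   2.919956])
threshold = 2.92

# H0: mean >= threshold, H1: mean < threshold
t_stat, p_value = stats.ttest_1samp(losses, threshold, alternative='less')

# one-sided 95% upper confidence bound on the mean
n = len(losses)
se = losses.std(ddof=1) / np.sqrt(n)
upper_bound = losses.mean() + stats.t.ppf(0.95, df=n - 1) * se

print(t_stat, p_value, upper_bound)  # should be close to the values reported above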

Here are the corresponding run-times in seconds:

[1414.299, 1412.033, 1411.668, 1421.735, 1411.998, 1411.094, 1412.637, 1410.047, 1410.509, 1412.048, 1411.574, 1415.299, 1411.649, 1412.94, 1412.508, 1410.912, 1415.296, 1410.778, 1407.511]

Leading to the following stats:

  • Mean: 1412.4492105263155
  • Median: 1411.998
  • Std: 2.8062021268488864
  • Min: 1407.511
  • Max: 1421.735

The mean time is ~1412.5 seconds, or 23.54 minutes.

@ClassicLarry (Collaborator)

This accomplishes a very cool dynamic. Basically you have gotten rid of the concept of a token embedding vector. Instead of representing a token as a fixed point in d_model dimensional space, you have added a degree of freedom where a token embedding is represented by interpolating between two points defined by embed1 and embed2, and the interpolation point varies by layer.

Just a thought: the optimal interpolation point may vary by token, so it could be worthwhile to extend lambdas[1] from size 1 to size vocab_size.

@snimu (Contributor, Author) commented Sep 12, 2025

Thanks!

Interesting thought. You mean replacing the scalar lambdas with a 1D embedding layer? I'll have to think about that a bit more, might try it out next week though (or you can if you want to :)).

About the dynamic that developed in (at least some of) my runs: I have looked into this a bit more. From what I can tell, x00 is interpreted by the lm head as a next-token prediction (presumably 1-gram statistics), even though at the input it's equal to x. The nice thing about that is that the input vector then needs only minimal edits (x01 is suppressed in the early layers); the model only has to make slight corrections to x00 in order to make its prediction!
x01 is interpreted as a distribution over synonyms of the input token, basically a list of tokens that mean the same thing (" agent"; " Agent"; ...).
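
Roughly, the kind of probe I mean looks like the sketch below (the attribute names embed1, embed2, lm_head and the tokenizer handle are placeholders, not the actual model API):

# Logit-lens-style probe: push a raw input embedding through the lm head and
# look at the top tokens. All object names here are placeholders.
import torch

@torch.no_grad()
def top_tokens_for_embedding(model, tokenizer, token_str, embed, k=5):
    token_id = tokenizer.encode(token_str)[0]
    e = embed.weight[token_id]                              # (d_model,)
    e = e / e.norm().clamp_min(1e-6) * e.numel() ** 0.5     # RMS-normalize, as at the model input
    logits = model.lm_head(e)                               # (vocab_size,)
    return [tokenizer.decode([i]) for i in logits.topk(k).indices.tolist()]

# e.g. compare what the two embedding tables "mean" to the lm head:
# top_tokens_for_embedding(model, tokenizer, " agent", model.embed1)
# top_tokens_for_embedding(model, tokenizer, " agent", model.embed2)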

I'll release a blog post on it soon!

@ClassicLarry (Collaborator)

Yes, with a 1D embedding layer. I tested both this PR and a simple lambda-vocab embedding on the short track with 8xH100s, and the runtime increase was too high to justify the addition.

# in __init__: one per-token scale per layer, initialized to 1
self.lambda_embeds = nn.ModuleList([nn.Embedding(vocab_size, 1) for _ in range(num_layers)])
for lambda_embed in self.lambda_embeds:
    nn.init.ones_(lambda_embed.weight)

# in forward: scale x0 per token before passing it into the block
x0_active = x0 * self.lambda_embeds[i](input_seq)
x = self.blocks[i](x, x0_active, lambdas[i], attn_args)
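
(For clarity, a toy shape check of how the (seq_len, 1) per-token scale broadcasts against the x0 stream; the sizes below are made up:)

# Toy broadcast check, not the real config.
import torch
import torch.nn as nn

vocab_size, d_model, seq_len = 50257, 8, 4
lambda_embed = nn.Embedding(vocab_size, 1)
nn.init.ones_(lambda_embed.weight)

input_seq = torch.randint(0, vocab_size, (seq_len,))
x0 = torch.randn(seq_len, d_model)

x0_active = x0 * lambda_embed(input_seq)   # (seq_len, d_model) * (seq_len, 1)
print(x0_active.shape)                     # torch.Size([4, 8])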

Strangely, the token lambda embedding gave a 1% speedup on an A100 in Google Colab, but this didn't translate to the H100.

@snimu (Contributor, Author) commented Sep 12, 2025

I'm not surprised about the runtime issue; I actually experimented with adding two, three, four, or five additional embedding layers. With each one, the per-step loss fell, but less and less, so that timing-wise only one additional embedding was worth it (and just barely). I imagine that with the smaller (and thus shallower) model, the embeddings cause higher relative overhead than in the medium track. The lambda-embeddings might have the same effect: very slight loss improvement, bigger time penalty. If I have the time, I'll still try it on the medium track just to see.

One question I have is why you initialize all lambdas to 1. In my experiments, I noticed that it's really important to initialize all of them but the one for x to 0. Otherwise, the model will simply learn to route around the transformer layers and use embedding + lm head as the predictor (approximately).
Also, how did you handle the lambdas for x? Did you re-use the lambdas for x0? Or simply not multiply x with any lambdas?

@ClassicLarry (Collaborator)

Also, how did you handle the lambdas for x? Did you re-use the lambdas for x0? Or simply not multiply x with any lambdas?

I kept the existing lambdas in the block class:
x = lambdas[0] * x + lambdas[1] * x0
with lambdas[1] initialized to 0:
*[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas

Only self.lambda_embeds was initialized to 1, so the full initial formula becomes x = 1*x + 1*0*x0, where the 1 in 1*0*x0 is a per-token, per-layer param and the 0 is a per-layer param. It's important not to have 0*0 products, to avoid a zero-gradient trap. I still want the scalar lambdas[1] term in lambdas[1]*x0 so the x0 stream as a whole can be weighted correctly, even for rare tokens.
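
The zero-gradient trap is easy to see with a toy example (illustrative only, not the training code):

# If both factors start at 0, neither gets a gradient; starting the per-token
# factor at 1 keeps the per-layer lambda trainable.
import torch

a = torch.tensor(0.0, requires_grad=True)  # per-token factor
b = torch.tensor(0.0, requires_grad=True)  # per-layer lambda
(a * b).backward()
print(a.grad, b.grad)                      # tensor(0.), tensor(0.)  -> stuck

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
(a * b).backward()
print(a.grad, b.grad)                      # tensor(0.), tensor(1.)  -> b can still learn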

I was primarily proposing this as an extension of the additional-input-embedding idea you had, because it has a high expression:compute ratio and can be applied to one of x00 or x01 to modulate the x00:x01 ratio at the token level. It also follows a similar intuition of moving in the direction of more token-level embeddings on the medium track, as you did here:
ve = [ve[0], ve[1], ve[2], ve[3], ve[4]] + [None] * (len(self.blocks) - 10) + [ve[0], ve[1], ve[2], ve[3], ve[4]]
But since my test above on the small track gave a negative result, I don't want to oversell this idea. Most ideas I have don't show improvements.

@snimu (Contributor, Author) commented Sep 12, 2025

Ah, that makes sense.

@ClassicLarry merged commit 34f5696 into KellerJordan:master on Dec 2, 2025