
Commit 12995a1

Remove casts in nanoGPT's layernorms (#662)
1 parent 6c3162a commit 12995a1

File tree: 2 files changed (+8, -12 lines)

tripy/examples/nanogpt/model.py

Lines changed: 6 additions & 8 deletions
@@ -117,16 +117,14 @@ def forward(self, x):
 class Block(tp.Module):
     def __init__(self, config):
         super().__init__()
-        self.ln_1 = tp.LayerNorm(config.embedding_size)
+        self.ln_1 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = tp.LayerNorm(config.embedding_size)
+        self.ln_2 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
         self.mlp = MLP(config)

     def forward(self, x):
-        x_ln1 = tp.cast(self.ln_1(tp.cast(x, self.ln_1.dtype)), x.dtype)
-        x = x + self.attn(x_ln1)
-        x_ln2 = tp.cast(self.ln_2(tp.cast(x, self.ln_2.dtype)), x.dtype)
-        x = x + self.mlp(x_ln2)
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
         return x

@@ -137,15 +135,15 @@ def __init__(self, config):
         self.wte = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
         self.wpe = tp.Embedding(config.block_size, config.embedding_size, dtype=config.dtype)
         self.h = tp.Sequential(*[Block(config) for _ in range(config.num_layers)])
-        self.ln_f = tp.LayerNorm(config.embedding_size)
+        self.ln_f = tp.LayerNorm(config.embedding_size, dtype=config.dtype)

     def forward(self, idx):
         tok_emb = self.wte(idx)  # token embeddings of shape (batch_size, seq_len, embedding_size)
         pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
         pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, embedding_size)
         x = tok_emb + pos_emb  # (batch_size, seq_len, embedding_size)
         x = self.h(x)
-        x = tp.cast(self.ln_f(tp.cast(x, self.ln_f.dtype)), x.dtype)
+        x = self.ln_f(x)
         return x
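The net effect in model.py: each tp.LayerNorm is now constructed in the model's dtype (config.dtype), so the forward passes no longer need tp.cast round-trips on the activations. A minimal sketch of the pattern, assuming the example's `tp` alias for Tripy; the class name NormSketch is illustrative, not part of the example:

import tripy as tp  # assumed import alias; matches the "tp." prefix used in the diff

class NormSketch(tp.Module):
    """Illustrative only: the cast-free LayerNorm pattern this commit adopts."""

    def __init__(self, embedding_size, dtype):
        super().__init__()
        # The norm is built directly in the model dtype (e.g. tp.float16),
        # so its parameters and output already match the activations.
        self.ln = tp.LayerNorm(embedding_size, dtype=dtype)

    def forward(self, x):
        # Old pattern removed by this commit:
        #   tp.cast(self.ln(tp.cast(x, self.ln.dtype)), x.dtype)
        # New pattern: call the norm directly, no casts.
        return self.ln(x)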

tripy/examples/nanogpt/weight_loader.py

Lines changed: 2 additions & 4 deletions
@@ -51,8 +51,7 @@ def load_weights_from_hf(model, model_type, dtype):
         if any(key.endswith(w) for w in transposed):
             with torch.no_grad():
                 weight = hf_state_dict[key].t().contiguous()
-            if "ln" not in key:
-                weight = weight.to(torch_dtype)
+            weight = weight.to(torch_dtype)
             param = tp.Tensor(weight)
         tripy_state_dict[key] = param

@@ -112,8 +111,7 @@ def get_submodule(module, attr_name):
         key, _ = key.split("quantizer._amax")
         key += "scale"

-        if "ln" not in key:
-            weight = weight.to(torch_dtype)
+        weight = weight.to(torch_dtype)
         param = tp.Tensor(weight.contiguous())
         assert key in expected_keys
         tripy_state_dict[key] = param
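In weight_loader.py the old `if "ln" not in key` guard skipped the dtype conversion for layernorm weights, which the cast-based forward pass relied on; with the norms now built in config.dtype, every weight is converted uniformly. A small sketch of the resulting conversion step, assuming the same torch/Tripy objects the loader already uses; the helper name convert_weight is illustrative:

import torch
import tripy as tp  # assumed import alias

def convert_weight(weight: torch.Tensor, torch_dtype: torch.dtype) -> tp.Tensor:
    # After this commit, all weights, including the "ln" (layernorm) parameters,
    # go through the same dtype conversion before being wrapped in a tp.Tensor.
    weight = weight.to(torch_dtype)
    return tp.Tensor(weight.contiguous())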
