@@ -117,16 +117,14 @@ def forward(self, x):
 class Block(tp.Module):
     def __init__(self, config):
         super().__init__()
-        self.ln_1 = tp.LayerNorm(config.embedding_size)
+        self.ln_1 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = tp.LayerNorm(config.embedding_size)
+        self.ln_2 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
         self.mlp = MLP(config)
 
     def forward(self, x):
-        x_ln1 = tp.cast(self.ln_1(tp.cast(x, self.ln_1.dtype)), x.dtype)
-        x = x + self.attn(x_ln1)
-        x_ln2 = tp.cast(self.ln_2(tp.cast(x, self.ln_2.dtype)), x.dtype)
-        x = x + self.mlp(x_ln2)
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
         return x
 
 
@@ -137,15 +135,15 @@ def __init__(self, config):
         self.wte = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
         self.wpe = tp.Embedding(config.block_size, config.embedding_size, dtype=config.dtype)
         self.h = tp.Sequential(*[Block(config) for _ in range(config.num_layers)])
-        self.ln_f = tp.LayerNorm(config.embedding_size)
+        self.ln_f = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
 
     def forward(self, idx):
         tok_emb = self.wte(idx)  # token embeddings of shape (batch_size, seq_len, embedding_size)
         pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
         pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, embedding_size)
         x = tok_emb + pos_emb  # (batch_size, seq_len, embedding_size)
         x = self.h(x)
-        x = tp.cast(self.ln_f(tp.cast(x, self.ln_f.dtype)), x.dtype)
+        x = self.ln_f(x)
         return x
 
 
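For reference, the dtype-at-construction pattern used above can be exercised in isolation. The following is a minimal, illustrative sketch (not part of the commit): it assumes tp.ones and tp.float16 behave as in current Tripy releases, and the sizes and names are made up for the example.

import tripy as tp  # in newer releases the package is imported as nvtripy

embedding_size = 768                                   # illustrative value only
x = tp.ones((1, 4, embedding_size), dtype=tp.float16)  # stand-in for fp16 activations

# LayerNorm parameters are created directly in the activation dtype, so the
# residual add stays in a single dtype and no tp.cast round-trip is needed:
ln = tp.LayerNorm(embedding_size, dtype=tp.float16)
y = x + ln(x)

# With a default (float32) LayerNorm, the activations would instead have to be
# cast into and back out of the norm's dtype, as in the removed lines:
#   y = x + tp.cast(ln_fp32(tp.cast(x, ln_fp32.dtype)), x.dtype)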