import torch.nn as nn
from torch.nn import functional as F

+
+@dataclass
+class GPTConfig:
+    block_size: int = 1024
+    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
+    n_layer: int = 12
+    n_head: int = 12
+    n_embd: int = 768
+    dropout: float = 0.0
+    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+
+
+class LayerState:
+    # the recurrent neural network (RNN) state for a layer of RWKV5.2
+    def __init__(self, x, cfg: GPTConfig):
+        B, T, C, H, K = x.size(0), x.size(1), cfg.n_embd, cfg.n_head, cfg.n_embd // cfg.n_head
+        V = K
+        # a (B,C) tensor holding the last token embedding processed by the time mixer
+        self.time_mixer_x_state = torch.zeros(B, C, dtype=x.dtype, device=x.device)
+        # a (B,H,K,V) tensor holding a decaying token embedding memory for each head, where H=number_of_heads, K=key_dim_per_head, V=value_dim_per_head
+        self.kv_state = torch.zeros(B, H, K, V, dtype=x.dtype, device=x.device)
+        # a (B,C) tensor holding the last token embedding processed by the channel mixer
+        self.channel_mixer_x_state = torch.zeros(B, C, dtype=x.dtype, device=x.device)
+
class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

@@ -138,6 +163,52 @@ def forward(self, x):
        y = self.dropout(self.output(y))
        return y

+    def forward_step(self, x, state, kv_state):
+        B, T = x.size(0), 1
+        C = x.size(-1) # n_embd
+        H, N = self.n_head, self.head_size
+        #
+        # unlike forward(), this processes a single incoming token embedding
+        # against the running recurrent state, so there is no chunking here
+        # and the recurrence loop below runs exactly once.
+        #
+
+        xx = state - x
+        xk = x + xx * self.time_maa_k
+        xv = x + xx * self.time_maa_v
+        xr = x + xx * self.time_maa_r
+        xg = x + xx * self.time_maa_g
+        r = self.receptance(xr).view(B, T, H, 1, N)
+        k = self.key(xk).view(B, T, H, N, 1)
+        v = self.value(xv).view(B, T, H, 1, N)
+        g = F.silu(self.gate(xg)) # extra gate
+
+        w = torch.exp(-torch.exp(self.time_decay.float())).unsqueeze(-1) # time_decay
+        u = self.time_faaaa.float().unsqueeze(-1) # time_first
+
+        y = torch.empty(B, T, H, N, dtype=x.dtype, device=x.device)
+        for t in range(T):
+            y[:, t], kv_state = self.single_timestep(r[:, t], k[:, t], v[:, t], u, w, kv_state)
+
+        y = y.transpose(1, 2).contiguous().view(B * T, C)
+        y = self.ln_x(y).view(B, T, C) * g
+
+        # output projection
+        y = self.dropout(self.output(y))
+        return y, x, kv_state
+
+    @staticmethod
+    def single_timestep(r, k, v, u, w, kv_state):
+        y = kv_state                   # BHKV
+        y = y + (k @ v) * u            # BHKV * HK1 + BHKV = BHKV
+        out = r @ y                    # BH1K @ BHKV = BH1V
+
+        kv_state = kv_state * w       # BHKV
+        kv_state = kv_state + (k @ v) # BHKV + BHKV = BHKV
+
+        return out.squeeze(-2), kv_state # BHV, BHKV
+
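For reference, `single_timestep` implements the RWKV-5 per-head recurrence. Writing $S$ for `kv_state`, $r_t$ as a $1\times K$ row vector, and $k_t v_t^\top$ for the $K\times V$ outer product formed by `k @ v`, the step above computes

$$
\mathrm{out}_t = r_t\left(S_{t-1} + \operatorname{diag}(u)\,k_t v_t^\top\right),
\qquad
S_t = \operatorname{diag}(w)\,S_{t-1} + k_t v_t^\top,
$$

where $u$ (`time_faaaa`) is the per-key-channel bonus applied only to the current token, and $w = \exp(-\exp(\texttt{time\_decay}))$ is the per-key-channel decay applied to the memory.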
class RWKV_ChannelMix_x051a(nn.Module):

    def __init__(self, config, layer_id):
@@ -169,6 +240,19 @@ def forward(self, x):
        x = self.dropout(x)
        return x

+    def forward_step(self, x, state):
+        xx = state - x
+        xk = x + xx * self.time_maa_k
+        xr = x + xx * self.time_maa_r
+
+        out = self.key(xk)
+        out = torch.relu(out) ** 2
+        out = self.value(out)
+        out = torch.sigmoid(self.receptance(xr)) * out
+        out = self.dropout(out)
+        return out, x
+
+
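In the same spirit, the single-token channel-mix step above is, with $x_{\text{prev}}$ denoting the incoming `state` and $\mu_k, \mu_r$ the learned `time_maa_k`, `time_maa_r` mixing factors:

$$
\mathrm{out} = \sigma\!\big(W_r\,\mathrm{lerp}(x, x_{\text{prev}}, \mu_r)\big)\odot W_v\,\mathrm{ReLU}\!\big(W_k\,\mathrm{lerp}(x, x_{\text{prev}}, \mu_k)\big)^2,
$$

where $\mathrm{lerp}(x, x_{\text{prev}}, \mu) = x + (x_{\text{prev}} - x)\,\mu$; the current $x$ is returned so that it becomes the state for the next token.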
class Block(nn.Module):

    def __init__(self, config, layer_id):
@@ -183,15 +267,15 @@ def forward(self, x):
        x = x + self.cmix(self.ln_2(x))
        return x

-@dataclass
-class GPTConfig:
-    block_size: int = 1024
-    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
-    n_layer: int = 12
-    n_head: int = 12
-    n_embd: int = 768
-    dropout: float = 0.0
-    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+    def forward_step(self, x, s: LayerState):
+        out, s.time_mixer_x_state, s.kv_state = \
+            self.tmix.forward_step(self.ln_1(x), s.time_mixer_x_state, s.kv_state)
+        x = x + out
+        out, s.channel_mixer_x_state = \
+            self.cmix.forward_step(self.ln_2(x), s.channel_mixer_x_state)
+        x = x + out
+        return x, s
+

class GPT(nn.Module):

@@ -253,11 +337,13 @@ def forward(self, idx, targets=None):

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
+        # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+        x = self.transformer.drop(tok_emb) # + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
+        return self.lm_head(x), None # early return: the targets/loss path below is now unreachable
+

        if targets is not None:
            # if we are given some desired targets also calculate the loss
@@ -270,6 +356,16 @@ def forward(self, idx, targets=None):

        return logits, loss

+    def forward_step(self, x, s):
+        tok_emb = self.transformer.wte(x) # token embeddings of shape (b, n_embd)
+        # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (n_embd)
+        x = self.transformer.drop(tok_emb) # + pos_emb)
+        for layer_id, block in enumerate(self.transformer.h): # run each rwkv block
+            x, s[layer_id] = block.forward_step(x, s[layer_id])
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x)
+        return logits, s
+
    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
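Below is a minimal, hypothetical usage sketch (not part of this commit) showing how the new recurrent path could be driven for batch-size-1 generation. `GPTConfig`, `GPT`, `LayerState`, and `forward_step` come from the diff above; the starting token id, the greedy decoding loop, and the untrained weights are illustrative assumptions only.

```python
import torch

cfg = GPTConfig()                    # defaults: 12 layers, 12 heads, 768 dims
model = GPT(cfg).eval()              # weights are random here; load a checkpoint in practice

B = 1                                # the step path is simplest with batch size 1
idx = torch.tensor([0], dtype=torch.long)   # (B,) current token id; 0 is an arbitrary start
dummy = torch.zeros(B, 1, cfg.n_embd)       # LayerState only reads shape/dtype/device from this
states = [LayerState(dummy, cfg) for _ in range(cfg.n_layer)]

generated = []
with torch.no_grad():
    for _ in range(16):
        # one token in, one next-token distribution out; per-layer states carried forward
        logits, states = model.forward_step(idx, states)
        idx = logits.view(B, -1).argmax(dim=-1)   # greedy pick; sampling also works
        generated.append(idx.item())
print(generated)
```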