
Commit c22dc6a

Author: N1kSt4r
Commit message: initially add nano-rwkv rnn mode
Parent: 381773e

File tree: model.py, sample.py

2 files changed: +125 -16 lines

model.py

Lines changed: 107 additions & 11 deletions
@@ -20,6 +20,31 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
+
+@dataclass
+class GPTConfig:
+    block_size: int = 1024
+    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
+    n_layer: int = 12
+    n_head: int = 12
+    n_embd: int = 768
+    dropout: float = 0.0
+    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+
+
+class LayerState:
+    # the recurrent neural network (RNN) state for a layer of RWKV5.2
+    def __init__(self, x, cfg:GPTConfig):
+        # B, T, C, H, K = x.size(0), x.size(1), cfg.n_embed, cfg.n_heads, cfg.n_embed // cfg.n_heads
+        B, T, C, H, K = x.size(0), x.size(1), cfg.n_embd, cfg.n_head, cfg.n_embd // cfg.n_head
+        V = K
+        # a (B,C) size tensor holding the latest token embedding processed by the time mixer
+        self.time_mixer_x_state = torch.zeros(B,C,dtype=x.dtype,device=x.device)
+        # a (B,H,K,V) size tensor holding a decaying token embedding memory for each head, where H=number_of_heads, K=key_dim_per_head, V=value_dim_per_head
+        self.kv_state = torch.zeros(B,H,K,V,dtype=x.dtype,device=x.device)
+        # a (B,C) size tensor holding the latest token embedding processed by the channel mixer
+        self.channel_mixer_x_state = torch.zeros(B,C,dtype=x.dtype,device=x.device)
+
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
 
@@ -138,6 +163,52 @@ def forward(self, x):
         y = self.dropout(self.output(y))
         return y
 
+    def forward_step(self, x, state, kv_state):
+        print('time mix forward_step x.shape:', x.shape)
+        B, T = x.size(0), 1
+        C = 128 # ToDo(fix)
+        H, N = self.n_head, self.head_size
+        #
+        # we divide a block into chunks to speed up computation & save vram.
+        # you can try to find the optimal chunk_len for your GPU.
+        # avoid going below 128 if you are using bf16 (otherwise time_decay might be less accurate).
+        #
+
+        xx = state - x
+        xk = x + xx * self.time_maa_k
+        xv = x + xx * self.time_maa_v
+        xr = x + xx * self.time_maa_r
+        xg = x + xx * self.time_maa_g
+        r = self.receptance(xr).view(B, T, H, 1, N)
+        k = self.key(xk).view(B, T, H, N, 1)
+        v = self.value(xv).view(B, T, H, 1, N)
+        g = F.silu(self.gate(xg)) # extra gate
+
+        w = torch.exp(-torch.exp(self.time_decay.float())).unsqueeze(-1) # time_decay
+        u = self.time_faaaa.float().unsqueeze(-1) # time_first
+
+        y = torch.empty(B, T, H, N, dtype=x.dtype, device=x.device)
+        for t in range(T):
+            y[:,t], kv_state = self.single_timestep(r[:,t], k[:,t], v[:,t], u, w, kv_state)
+
+        y = y.transpose(1, 2).contiguous().view(B * T, C)
+        y = self.ln_x(y).view(B, T, C) * g
+
+        # output projection
+        y = self.dropout(self.output(y))
+        return y, x, kv_state
+
+    @staticmethod
+    def single_timestep(r, k, v, u, w, kv_state):
+        y = kv_state # BHKV
+        y = y + (k @ v) * u # BHKV * HK1 + BHKV = BHKV
+        out = r @ y # BH1K @ BHKV = BH1V
+
+        kv_state = kv_state * w # BHKV * HK1 = BHKV
+        kv_state = kv_state + (k @ v) # BHKV + BHKV = BHKV
+
+        return out.squeeze(-2), kv_state # BHV, BHKV
+
 class RWKV_ChannelMix_x051a(nn.Module):
 
     def __init__(self, config, layer_id):
@@ -169,6 +240,19 @@ def forward(self, x):
         x = self.dropout(x)
         return x
 
+    def forward_step(self, x, state):
+        xx = state - x
+        xk = x + xx * self.time_maa_k
+        xr = x + xx * self.time_maa_r
+
+        out = self.key(xk)
+        out = torch.relu(out) ** 2
+        out = self.value(out)
+        out = torch.sigmoid(self.receptance(xr)) * out
+        out = self.dropout(out)
+        return out, x
+
+
 class Block(nn.Module):
 
     def __init__(self, config, layer_id):
@@ -183,15 +267,15 @@ def forward(self, x):
         x = x + self.cmix(self.ln_2(x))
         return x
 
-@dataclass
-class GPTConfig:
-    block_size: int = 1024
-    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
-    n_layer: int = 12
-    n_head: int = 12
-    n_embd: int = 768
-    dropout: float = 0.0
-    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+    def forward_step(self, x, s: LayerState):
+        out, s.time_mixer_x_state, s.kv_state = \
+            self.tmix.forward_step(self.ln_1(x), s.time_mixer_x_state, s.kv_state)
+        x = x + out
+        out, s.channel_mixer_x_state = \
+            self.cmix.forward_step(self.ln_2(x), s.channel_mixer_x_state)
+        x = x + out
+        return x, s
+
 
 class GPT(nn.Module):
 
@@ -253,11 +337,13 @@ def forward(self, idx, targets=None):
 
         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
+        #pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+        x = self.transformer.drop(tok_emb) # + pos_emb)
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
+        return self.lm_head(x), None
+
 
         if targets is not None:
             # if we are given some desired targets also calculate the loss
@@ -270,6 +356,16 @@ def forward(self, idx, targets=None):
 
         return logits, loss
 
+    def forward_step(self, x, s):
+        tok_emb = self.transformer.wte(x) # token embeddings of shape (b, n_embd)
+        #pos_emb = self.transformer.wpe(pos) # position embeddings of shape (n_embd)
+        x = self.transformer.drop(tok_emb) # + pos_emb)
+        for layer_id, block in enumerate(self.transformer.h): # run each rwkv block
+            x, s[layer_id] = block.forward_step(x, s[layer_id])
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x)
+        return logits, s
+
     def crop_block_size(self, block_size):
         # model surgery to decrease the block size if necessary
         # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
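The new single_timestep above is the core of the RNN mode: for each token the per-head kv_state is read with a one-token "bonus" u via the receptance r, then decayed by w and refreshed with the new key/value outer product. The standalone sketch below replays that update outside the class with made-up sizes; it is not part of the commit, and only the shapes follow the B, H, K=V comments in the diff.

    import torch

    B, H, K = 1, 12, 64            # batch, heads, key dim per head (made-up sizes)
    V = K
    r = torch.randn(B, H, 1, K)    # receptance for the current token
    k = torch.randn(B, H, K, 1)    # key for the current token
    v = torch.randn(B, H, 1, V)    # value for the current token
    w = torch.rand(H, K, 1)        # per-channel decay in (0, 1)
    u = torch.randn(H, K, 1)       # "time_first" bonus applied only to the current token
    kv_state = torch.zeros(B, H, K, V)

    out = (r @ (kv_state + (k @ v) * u)).squeeze(-2)  # BH1K @ BHKV -> BH1V -> BHV
    kv_state = kv_state * w + (k @ v)                 # decay the old memory, add the new outer product

Repeating the last two lines once per token is effectively the loop that GPT.forward_step drives through Block.forward_step and the time mixer's forward_step.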

sample.py

Lines changed: 18 additions & 5 deletions
@@ -6,7 +6,7 @@
 from contextlib import nullcontext
 import torch
 import tiktoken
-from model import GPTConfig, GPT
+from model import GPTConfig, GPT, LayerState
 
 # -----------------------------------------------------------------------------
 init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
@@ -78,13 +78,26 @@
     with open(start[5:], 'r', encoding='utf-8') as f:
         start = f.read()
 start_ids = encode(start)
-x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+test_seq = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
+x = (torch.tensor(test_seq, dtype=torch.long, device=device)[None, ...])
 
 # run generation
 with torch.no_grad():
     with ctx:
         for k in range(num_samples):
             print('(note: this is using "GPT-mode" for inference (very slow), so we limit it to 100 characters. The much faster "RNN-mode" for inference is coming soon)')
-            y = model.generate(x, 100, temperature=temperature, top_k=top_k)
-            print(decode(y[0].tolist()))
-            print('---------------')
+            # y = model.generate(x, 1, temperature=temperature, top_k=top_k)
+            gt, _ = model.forward(x)
+
+            states = [LayerState(x, gptconf) for _ in range(gptconf.n_layer)]
+            for i, test_token in enumerate(test_seq):
+                x = torch.tensor([test_token], dtype=torch.long, device=device)
+                x, states = model.forward_step(x, states)
+                assert torch.allclose(gt[:, i], x[:, 0, :], rtol=1e-2), i
+
+            # print(y)
+            # print(model.forward(y))
+            # print(decode(y[0].tolist()))
+            # print('---------------')
+            break
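The modified sample.py only checks that RNN-mode logits agree with the GPT-mode forward pass on the fixed test_seq; it does not sample text yet. A possible token-by-token generation loop on top of the same forward_step interface is sketched below. It is not part of this commit: max_new_tokens and the softmax/multinomial sampling are assumptions, while LayerState, forward_step, start_ids, gptconf, temperature, decode and device come from the code above.

    # inside the torch.no_grad() / ctx blocks, after loading the model as above
    prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]  # (1, T) prompt batch
    states = [LayerState(prompt, gptconf) for _ in range(gptconf.n_layer)]
    tokens = list(start_ids)
    for t in tokens[:-1]:  # prime the recurrent state on all but the last prompt token
        _, states = model.forward_step(torch.tensor([t], dtype=torch.long, device=device), states)
    cur = torch.tensor([tokens[-1]], dtype=torch.long, device=device)
    max_new_tokens = 100   # assumed budget, matching the old GPT-mode generate call
    for _ in range(max_new_tokens):
        logits, states = model.forward_step(cur, states)
        probs = torch.nn.functional.softmax(logits.view(-1) / temperature, dim=-1)
        cur = torch.multinomial(probs, num_samples=1)
        tokens.append(int(cur))
    print(decode(tokens))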
