diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..f1d8449d8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + types: [python] + - id: trailing-whitespace + + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black \ No newline at end of file diff --git a/RWKV-v1/src/model.py b/RWKV-v1/src/model.py index 1eeb0868a..9ea17885c 100644 --- a/RWKV-v1/src/model.py +++ b/RWKV-v1/src/model.py @@ -7,52 +7,69 @@ import torch import torch.nn as nn from torch.nn import functional as F + logger = logging.getLogger(__name__) ######################################################################################################## # RWKV: RWKV Time-mix + RWKV Channel-mix ######################################################################################################## -def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module + +def RWKV_Init( + module, config +): # fancy initialization of all lin & emb layer in the module for m in module.modules(): if not isinstance(m, (nn.Linear, nn.Embedding)): continue with torch.no_grad(): - name = '[unknown weight]' - for name, parameter in module.named_parameters(): # find the name of the weight + name = "[unknown weight]" + for ( + name, + parameter, + ) in module.named_parameters(): # find the name of the weight if id(m.weight) == id(parameter): break shape = m.weight.data.shape gain = 1.0 # positive: gain for orthogonal, negative: std for normal - scale = 1.0 # extra scale for gain + scale = 1.0 # extra scale for gain if isinstance(m, nn.Linear): if m.bias is not None: m.bias.data.zero_() if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? + if ( + shape[0] == config.vocab_size and shape[1] == config.n_embd + ): # final projection? scale = config.rwkv_emb_scale if isinstance(m, nn.Embedding): gain = math.sqrt(max(shape[0], shape[1])) - if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? + if ( + shape[0] == config.vocab_size and shape[1] == config.n_embd + ): # token emb? 
scale = config.rwkv_emb_scale - if hasattr(m, 'scale_init'): + if hasattr(m, "scale_init"): scale = m.scale_init - print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name) + print( + str(shape[0]).ljust(5), + str(shape[1]).ljust(5), + f"{round(scale,2):g}".ljust(4), + name, + ) gain *= scale if gain == 0: - nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices + nn.init.zeros_(m.weight) # zero init is great for some RWKV matrices elif gain > 0: nn.init.orthogonal_(m.weight, gain=gain) else: nn.init.normal_(m.weight, mean=0, std=-gain) + class RWKV_TimeMix(nn.Module): def __init__(self, config, layer_id): super().__init__() @@ -62,12 +79,16 @@ def __init__(self, config, layer_id): self.n_head = config.n_head self.head_size = config.n_attn // config.n_head - with torch.no_grad(): # initial time_w curves for better convergence + with torch.no_grad(): # initial time_w curves for better convergence ww = torch.ones(config.n_head, config.ctx_len) - curve = torch.tensor([-(config.ctx_len - 1 - i) for i in range(config.ctx_len)]) # the distance + curve = torch.tensor( + [-(config.ctx_len - 1 - i) for i in range(config.ctx_len)] + ) # the distance for h in range(config.n_head): if h < config.n_head - 1: - decay_speed = math.pow(config.ctx_len, -(h+1)/(config.n_head-1)) + decay_speed = math.pow( + config.ctx_len, -(h + 1) / (config.n_head - 1) + ) else: decay_speed = 0.0 ww[h] = torch.exp(curve * decay_speed) @@ -77,8 +98,8 @@ def __init__(self, config, layer_id): self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len)) self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1)) self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1)) - - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) @@ -99,10 +120,10 @@ def forward(self, x): w = F.pad(self.time_w, (0, TT)) w = torch.tile(w, [TT]) w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1) - w = w[:, :, TT-1:] # w is now a circulant matrix + w = w[:, :, TT - 1 :] # w is now a circulant matrix w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :] - x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) + x = torch.cat([self.time_shift(x[:, :, : C // 2]), x[:, :, C // 2 :]], dim=-1) # if hasattr(self, 'tiny_att'): # tiny_att = self.tiny_att(x, self.mask) @@ -110,13 +131,13 @@ def forward(self, x): v = self.value(x) r = self.receptance(x) - k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13 + k = torch.clamp(k, max=30, min=-60) # clamp extreme values. 
e^30 = 10^13 k = torch.exp(k) sum_k = torch.cumsum(k, dim=1) kv = (k * v).view(B, T, self.n_head, self.head_size) - wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, -1) + wkv = (torch.einsum("htu,buhc->bthc", w, kv)).contiguous().view(B, T, -1) rwkv = torch.sigmoid(r) * wkv / sum_k @@ -126,13 +147,16 @@ def forward(self, x): return rwkv * self.time_gamma[:T, :] + class RWKV_ChannelMix(nn.Module): def __init__(self, config, layer_id): super().__init__() self.layer_id = layer_id - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) - - hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + hidden_sz = ( + 5 * config.n_ffn // 2 + ) # can use smaller hidden_sz because of receptance gating self.key = nn.Linear(config.n_embd, hidden_sz) self.value = nn.Linear(config.n_embd, hidden_sz) self.weight = nn.Linear(hidden_sz, config.n_embd) @@ -143,19 +167,20 @@ def __init__(self, config, layer_id): def forward(self, x): B, T, C = x.size() - - x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) + + x = torch.cat([self.time_shift(x[:, :, : C // 2]), x[:, :, C // 2 :]], dim=-1) k = self.key(x) v = self.value(x) r = self.receptance(x) - - wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu + + wkv = self.weight(F.mish(k) * v) # i find mish is a bit better than gelu rwkv = torch.sigmoid(r) * wkv return rwkv -class RWKV_TinyAttn(nn.Module): # extra tiny attention + +class RWKV_TinyAttn(nn.Module): # extra tiny attention def __init__(self, config): super().__init__() self.d_attn = config.rwkv_tiny_attn @@ -168,32 +193,44 @@ def __init__(self, config): def forward(self, x, mask): B, T, C = x.size() qkv = self.qkv(x) - q, k, v = qkv.chunk(3, dim = -1) + q, k, v = qkv.chunk(3, dim=-1) if self.n_head > 1: - q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - - qk = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size)) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) - qk = qk.masked_fill(mask == 0, float('-inf')) - qk = F.softmax(qk, dim = -1) - qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) + q = q.view(B, T, self.n_head, self.head_size).transpose( + 1, 2 + ) # (B, T, C) -> (B, nh, T, hs) + k = k.view(B, T, self.n_head, self.head_size).transpose( + 1, 2 + ) # (B, T, C) -> (B, nh, T, hs) + v = v.view(B, T, self.n_head, self.head_size).transpose( + 1, 2 + ) # (B, T, C) -> (B, nh, T, hs) + + qk = (q @ k.transpose(-2, -1)) * ( + 1.0 / math.sqrt(self.head_size) + ) # (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) + qk = qk.masked_fill(mask == 0, float("-inf")) + qk = F.softmax(qk, dim=-1) + qkv = qk @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) if self.n_head > 1: - qkv = qkv.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) - + qkv = ( + qkv.transpose(1, 2).contiguous().view(B, T, -1) + ) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) + return self.out(qkv) + ######################################################################################################## # MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN ######################################################################################################## + class 
RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000): super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) self.seq_len_cached = None self.cos_cached = None self.sin_cached = None @@ -202,23 +239,26 @@ def forward(self, x, seq_len=None): if seq_len != self.seq_len_cached: self.seq_len_cached = seq_len t = torch.arange(seq_len, device=x.device) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.cos_cached = emb.cos() self.sin_cached = emb.sin() return self.cos_cached, self.sin_cached + def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), -1) + @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin): - cos, sin = cos[...,:q.shape[-2],:], sin[...,:q.shape[-2],:] + cos, sin = cos[..., : q.shape[-2], :], sin[..., : q.shape[-2], :] return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + class MHA_rotary(nn.Module): - def __init__(self, config, layer_id, time_shift = False): + def __init__(self, config, layer_id, time_shift=False): super().__init__() self.layer_id = layer_id assert config.n_attn % config.n_head == 0 @@ -227,14 +267,16 @@ def __init__(self, config, layer_id, time_shift = False): self.head_size = config.n_attn // config.n_head if time_shift: - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.query = nn.Linear(config.n_embd, config.n_attn) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) - self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len))) - + self.register_buffer( + "mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)) + ) + self.rotary_ndims = int(self.head_size * 0.5) self.rotary_emb = RotaryEmbedding(self.rotary_ndims) @@ -243,37 +285,50 @@ def __init__(self, config, layer_id, time_shift = False): def forward(self, x): B, T, C = x.size() - if hasattr(self, 'time_shift'): - x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) - - q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - - q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:] - k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:] + if hasattr(self, "time_shift"): + x = torch.cat( + [self.time_shift(x[:, :, : C // 2]), x[:, :, C // 2 :]], dim=-1 + ) + + q = ( + self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) + ) # (B, T, C) -> (B, nh, T, hs) + k = ( + self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) + ) # (B, T, C) -> (B, nh, T, hs) + v = ( + self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) + ) # (B, T, C) -> (B, nh, T, hs) + + q, query_pass = q[..., : self.rotary_ndims], q[..., self.rotary_ndims :] + k, key_pass = k[..., : self.rotary_ndims], k[..., self.rotary_ndims :] cos, sin = self.rotary_emb(q, seq_len=T) - q, k = 
apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding + q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding q = torch.cat((q, query_pass), dim=-1) k = torch.cat((k, key_pass), dim=-1) - - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) - att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask - att = F.softmax(att, dim = -1) # softmax - x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) - x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) + att = (q @ k.transpose(-2, -1)) * ( + 1.0 / math.sqrt(k.size(-1)) + ) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) + att = att.masked_fill(self.mask[:T, :T] == 0, float("-inf")) # causal mask + att = F.softmax(att, dim=-1) # softmax + + x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) + x = ( + x.transpose(1, 2).contiguous().view(B, T, -1) + ) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) x = self.output(x) return x + class GeGLU(torch.nn.Module): - def __init__(self, config, layer_id, time_shift = False): + def __init__(self, config, layer_id, time_shift=False): super().__init__() self.layer_id = layer_id if time_shift: - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) hidden_sz = 3 * config.n_ffn self.key = nn.Linear(config.n_embd, hidden_sz) @@ -282,18 +337,22 @@ def __init__(self, config, layer_id, time_shift = False): def forward(self, x): B, T, C = x.size() - if hasattr(self, 'time_shift'): - x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) - + if hasattr(self, "time_shift"): + x = torch.cat( + [self.time_shift(x[:, :, : C // 2]), x[:, :, C // 2 :]], dim=-1 + ) + k = self.key(x) - v = self.value(x) + v = self.value(x) y = self.weight(F.gelu(k) * v) return y + ######################################################################################################## # MHA_pro: with more tricks ######################################################################################################## + class MHA_pro(nn.Module): def __init__(self, config, layer_id): super().__init__() @@ -307,17 +366,21 @@ def __init__(self, config, layer_id): self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len)) self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1)) self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1)) - self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len))) + self.register_buffer( + "mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)) + ) - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.query = nn.Linear(config.n_embd, config.n_attn) self.key = nn.Linear(config.n_embd, config.n_attn) self.value = nn.Linear(config.n_embd, config.n_attn) - + self.rotary_ndims = int(self.head_size * 0.5) self.rotary_emb = RotaryEmbedding(self.rotary_ndims) - self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads + self.head_mix = nn.Conv2d( + self.n_head, self.n_head, kernel_size=1, bias=False + ) # talking heads self.output = nn.Linear(config.n_attn, config.n_embd) @@ -327,41 +390,55 @@ def forward(self, x): w = F.pad(self.time_w, (0, TT)) w = torch.tile(w, [TT]) w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1) - w = w[:, :, TT-1:] # w is now a circulant matrix + w = w[:, :, TT - 1 :] # w is now a circulant matrix w = w[:, :T, :T] * 
self.time_alpha[:, :, :T] * self.time_beta[:, :T, :] - x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing - q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs) - - q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:] - k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:] + x = torch.cat( + [self.time_shift(x[:, :, : C // 2]), x[:, :, C // 2 :]], dim=-1 + ) # time-shift mixing + q = ( + self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) + ) # (B, T, C) -> (B, nh, T, hs) + k = ( + self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) + ) # (B, T, C) -> (B, nh, T, hs) + v = ( + self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) + ) # (B, T, C) -> (B, nh, T, hs) + + q, query_pass = q[..., : self.rotary_ndims], q[..., self.rotary_ndims :] + k, key_pass = k[..., : self.rotary_ndims], k[..., self.rotary_ndims :] cos, sin = self.rotary_emb(q, seq_len=T) - q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding + q, k = apply_rotary_pos_emb(q, k, cos, sin) # rotary encoding q = torch.cat((q, query_pass), dim=-1) - k = torch.cat((k, key_pass), dim=-1) - - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) - att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf')) # causal mask - att = F.softmax(att, dim = -1) # softmax - att = att * w # time-weighting - att = self.head_mix(att) # talking heads + k = torch.cat((k, key_pass), dim=-1) - x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) - x = x.transpose(1, 2).contiguous().view(B, T, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) + att = (q @ k.transpose(-2, -1)) * ( + 1.0 / math.sqrt(k.size(-1)) + ) # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T) + att = att.masked_fill(self.mask[:T, :T] == 0, float("-inf")) # causal mask + att = F.softmax(att, dim=-1) # softmax + att = att * w # time-weighting + att = self.head_mix(att) # talking heads + + x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs) + x = ( + x.transpose(1, 2).contiguous().view(B, T, -1) + ) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C) x = self.output(x) * self.time_gamma[:T, :] return x + ######################################################################################################## # The GPT Model with our blocks ######################################################################################################## + class RMSNorm(nn.Module): def __init__(self, d): super().__init__() - self.dd = d ** (-1. / 2) + self.dd = d ** (-1.0 / 2) self.weight = nn.Parameter(torch.ones(d)) def forward(self, x): @@ -369,25 +446,29 @@ def forward(self, x): x_normed = x / (norm_x * self.dd + 1e-12) return self.weight * x_normed + class FixedNorm(nn.Module): def __init__(self, d): super().__init__() - self.dd = d ** (-1. 
/ 2) + self.dd = d ** (-1.0 / 2) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) x_normed = x / (norm_x * self.dd + 1e-12) return x_normed + ######################################################################################################## + class GPTConfig: def __init__(self, vocab_size, ctx_len, **kwargs): self.vocab_size = vocab_size self.ctx_len = ctx_len - for k,v in kwargs.items(): + for k, v in kwargs.items(): setattr(self, k, v) + class Block(nn.Module): def __init__(self, config, layer_id): super().__init__() @@ -396,21 +477,21 @@ def __init__(self, config, layer_id): self.ln1 = nn.LayerNorm(config.n_embd) self.ln2 = nn.LayerNorm(config.n_embd) - if config.model_type == 'RWKV': + if config.model_type == "RWKV": # self.ln1 = FixedNorm(config.n_embd) # self.ln2 = FixedNorm(config.n_embd) self.attn = RWKV_TimeMix(config, layer_id) self.mlp = RWKV_ChannelMix(config, layer_id) - elif config.model_type == 'MHA_rotary': + elif config.model_type == "MHA_rotary": self.attn = MHA_rotary(config, layer_id) self.mlp = GeGLU(config, layer_id) - - elif config.model_type == 'MHA_shift': + + elif config.model_type == "MHA_shift": self.attn = MHA_rotary(config, layer_id, time_shift=True) self.mlp = GeGLU(config, layer_id, time_shift=True) - - elif config.model_type == 'MHA_pro': + + elif config.model_type == "MHA_pro": self.attn = MHA_pro(config, layer_id) self.mlp = RWKV_ChannelMix(config, layer_id) @@ -418,9 +499,10 @@ def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) - + return x + class GPT(nn.Module): def __init__(self, config): super().__init__() @@ -431,23 +513,29 @@ def __init__(self, config): self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd) - self.time_out = nn.Parameter(torch.ones(1,config.ctx_len,1)) # reduce confidence of early tokens + self.time_out = nn.Parameter( + torch.ones(1, config.ctx_len, 1) + ) # reduce confidence of early tokens self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.head_q = nn.Linear(config.n_embd, 256) self.head_q.scale_init = 0.01 self.head_k = nn.Linear(config.n_embd, 256) self.head_k.scale_init = 0.01 - self.register_buffer("copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)) + ) self.ctx_len = config.ctx_len - if self.config.model_type == 'RWKV': + if self.config.model_type == "RWKV": RWKV_Init(self, config) else: self.apply(self._init_weights) - logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + logger.info( + "number of parameters: %e", sum(p.numel() for p in self.parameters()) + ) def get_ctx_len(self): return self.ctx_len @@ -463,32 +551,48 @@ def configure_optimizers(self, train_config): decay = set() no_decay = set() - whitelist_weight_modules = (nn.Linear, ) + whitelist_weight_modules = (nn.Linear,) blacklist_weight_modules = (RMSNorm, nn.LayerNorm, nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name - if pn.endswith('bias') or ('time' in fpn) or ('head' in fpn): + if pn.endswith("bias") or ("time" in fpn) or ("head" in fpn): no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): decay.add(fpn) - elif 
pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): no_decay.add(fpn) # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" % ( + str(param_dict.keys() - union_params), + ) optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": train_config.weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, ] - optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps) + optimizer = torch.optim.AdamW( + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas, + eps=train_config.eps, + ) return optimizer def forward(self, idx, targets=None): @@ -501,13 +605,13 @@ def forward(self, idx, targets=None): x = self.ln_f(x) - q = self.head_q(x)[:,:T,:] - k = self.head_k(x)[:,:T,:] + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] c = (q @ k.transpose(-2, -1)) * (1.0 / 256) - c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0) - c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float() + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).float() - x = x * self.time_out[:, :T, :] # reduce confidence of early tokens + x = x * self.time_out[:, :T, :] # reduce confidence of early tokens x = self.head(x) + c loss = None diff --git a/RWKV-v1/src/trainer.py b/RWKV-v1/src/trainer.py index 5f88fcc70..959d8a3f2 100644 --- a/RWKV-v1/src/trainer.py +++ b/RWKV-v1/src/trainer.py @@ -6,11 +6,13 @@ import torch.optim as optim from torch.optim.lr_scheduler import LambdaLR from torch.utils.data.dataloader import DataLoader + logger = logging.getLogger(__name__) # print('logging to wandb... 
(comment it if you don\'t have wandb)') # import wandb # comment this if you don't have wandb + class TrainerConfig: max_epochs = 10 batch_size = 64 @@ -19,19 +21,19 @@ class TrainerConfig: eps = 1e-8 grad_norm_clip = 1.0 weight_decay = 0.01 - lr_decay = False # linear warmup followed by cosine decay - warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper - final_tokens = 260e9 # at which point do we reach lr_final + lr_decay = False # linear warmup followed by cosine decay + warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper + final_tokens = 260e9 # at which point do we reach lr_final epoch_save_frequency = 0 - epoch_save_path = 'trained-' - num_workers = 0 # for DataLoader + epoch_save_path = "trained-" + num_workers = 0 # for DataLoader def __init__(self, **kwargs): - for k,v in kwargs.items(): + for k, v in kwargs.items(): setattr(self, k, v) -class Trainer: +class Trainer: def __init__(self, model, train_dataset, test_dataset, config): self.model = model self.train_dataset = train_dataset @@ -40,21 +42,38 @@ def __init__(self, model, train_dataset, test_dataset, config): self.avg_loss = -1 self.steps = 0 - if 'wandb' in sys.modules: + if "wandb" in sys.modules: cfg = model.config for k in config.__dict__: - setattr(cfg, k, config.__dict__[k]) # combine cfg - wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) - - self.device = 'cpu' - if torch.cuda.is_available(): # take over whatever gpus are on the system + setattr(cfg, k, config.__dict__[k]) # combine cfg + wandb.init( + project="RWKV-LM", + name=self.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), + config=cfg, + save_code=False, + ) + + self.device = "cpu" + if torch.cuda.is_available(): # take over whatever gpus are on the system self.device = torch.cuda.current_device() self.model = torch.nn.DataParallel(self.model).to(self.device) def get_run_name(self): raw_model = self.model.module if hasattr(self.model, "module") else self.model cfg = raw_model.config - run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) + run_name = ( + str(cfg.vocab_size) + + "-" + + str(cfg.ctx_len) + + "-" + + cfg.model_type + + "-" + + str(cfg.n_layer) + + "-" + + str(cfg.n_embd) + ) return run_name def train(self): @@ -63,52 +82,82 @@ def train(self): optimizer = raw_model.configure_optimizers(config) def run_epoch(split): - is_train = split == 'train' + is_train = split == "train" model.train(is_train) data = self.train_dataset if is_train else self.test_dataset - loader = DataLoader(data, shuffle=True, pin_memory=True, - batch_size=config.batch_size, - num_workers=config.num_workers) + loader = DataLoader( + data, + shuffle=True, + pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + pbar = ( + tqdm( + enumerate(loader), + total=len(loader), + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", + ) + if is_train + else enumerate(loader) + ) - pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) - for it, (x, y) in pbar: - x = x.to(self.device) # place data on the correct device + x = x.to(self.device) # place data on the correct device y = y.to(self.device) - + with torch.set_grad_enabled(is_train): - _, loss = model(x, y) # forward the model - loss = loss.mean() # collapse all losses if they are 
scattered on multiple gpus + _, loss = model(x, y) # forward the model + loss = ( + loss.mean() + ) # collapse all losses if they are scattered on multiple gpus - if is_train: # backprop and update the parameters + if is_train: # backprop and update the parameters model.zero_grad() loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) + torch.nn.utils.clip_grad_norm_( + model.parameters(), config.grad_norm_clip + ) optimizer.step() - - if config.lr_decay: # decay the learning rate based on our progress - self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) + + if config.lr_decay: # decay the learning rate based on our progress + self.tokens += ( + y >= 0 + ).sum() # number of tokens processed this step (i.e. label is not -100) lr_final_factor = config.lr_final / config.learning_rate if self.tokens < config.warmup_tokens: # linear warmup - lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens) + lr_mult = lr_final_factor + (1 - lr_final_factor) * float( + self.tokens + ) / float(config.warmup_tokens) progress = 0 else: # cosine learning rate decay - progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + progress = float( + self.tokens - config.warmup_tokens + ) / float( + max(1, config.final_tokens - config.warmup_tokens) + ) # progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR - lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1 + lr_mult = (0.5 + lr_final_factor / 2) + ( + 0.5 - lr_final_factor / 2 + ) * math.cos( + math.pi * progress + ) # better 1.0 ~ 0.1 lr = config.learning_rate * lr_mult for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr else: lr = config.learning_rate - now_loss = loss.item() # report progress - - if 'wandb' in sys.modules: - wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size) + now_loss = loss.item() # report progress + + if "wandb" in sys.modules: + wandb.log( + {"loss": now_loss}, step=self.steps * self.config.batch_size + ) self.steps += 1 if self.avg_loss < 0: @@ -116,15 +165,28 @@ def run_epoch(split): else: # factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1)) factor = 1 / (it + 1) - self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor - pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") + self.avg_loss = ( + self.avg_loss * (1.0 - factor) + now_loss * factor + ) + pbar.set_description( + f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}" + ) while True: - self.tokens = 0 # counter used for learning rate decay + self.tokens = 0 # counter used for learning rate decay for epoch in range(config.max_epochs): - run_epoch('train') - - if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): - raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module - torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth') + run_epoch("train") + + if ( + self.config.epoch_save_frequency > 0 + and epoch % self.config.epoch_save_frequency == 0 + ) or (epoch == config.max_epochs - 1): + raw_model = ( + 
self.model.module + if hasattr(self.model, "module") + else self.model + ) # DataParallel wrappers keep raw model object in .module + torch.save( + raw_model, self.config.epoch_save_path + str(epoch + 1) + ".pth" + ) diff --git a/RWKV-v1/src/utils.py b/RWKV-v1/src/utils.py index 5f9bb650d..810e76e90 100644 --- a/RWKV-v1/src/utils.py +++ b/RWKV-v1/src/utils.py @@ -4,12 +4,14 @@ import torch.nn as nn from torch.nn import functional as F + def top_k_logits(logits, k): v, ix = torch.topk(logits, k) out = logits.clone() - out[out < v[:, [-1]]] = -float('Inf') + out[out < v[:, [-1]]] = -float("Inf") return out + def top_p_probs(probs, p): out = probs.clone() @@ -17,32 +19,42 @@ def top_p_probs(probs, p): cumulative_probs = torch.cumsum(sorted_probs, dim=-1) sorted_indices_to_remove = cumulative_probs > p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 + sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices[sorted_indices_to_remove] out[indices_to_remove] = 0 return out + # top-p + top-k + pow&ratio sampling -def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None): +def sample_logits( + logits, + pos, + temperature=1.0, + top_k=None, + top_p=None, + min_p_pow=None, + min_p_ratio=None, +): logits = logits[:, pos, :] / temperature probs = F.softmax(logits, dim=-1) - + if min_p_ratio is not None: limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio - logits[probs < limit] = -float('Inf') - + logits[probs < limit] = -float("Inf") + if top_k is not None: logits = top_k_logits(logits, top_k) - + probs = F.softmax(logits, dim=-1) - + if top_p is not None: probs[0] = top_p_probs(probs[0], top_p) - + ix = torch.multinomial(probs, num_samples=1) return ix[0][0].cpu() + def set_seed(seed): random.seed(seed) np.random.seed(seed) diff --git a/RWKV-v1/train.py b/RWKV-v1/train.py index ab370e1b2..ceda5a841 100644 --- a/RWKV-v1/train.py +++ b/RWKV-v1/train.py @@ -12,18 +12,22 @@ set_seed(42) np.set_printoptions(precision=4, suppress=True, linewidth=200) -logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) # RWKV : our new model - fastest when ctx_len is long - VRAM friendly - good performance # MHA_rotary : usual MultiheadAttention+Rotary+GeGLU - not as good # MHA_shift : with time-shift - good performance # MHA_pro : slow (lots of tricks) - VRAM hungry - very good performance -model_type = 'RWKV' +model_type = "RWKV" # datafile = u"V:\\NLP\\text8" # datafile = u"V:\\NLP\\enwik8" -datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" -datafile_encoding = 'utf-8' +datafile = "V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" +datafile_encoding = "utf-8" # datafile = u"D:\\NLP-Data\\ww100M.txt" # datafile = u"D:\\NLP-Data\\__2019.txt" # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt" @@ -32,59 +36,71 @@ # datafile = u"V:\\NLP\\simplebooks-shift-utf32.word" # datafile_encoding = 'utf-32' -datafile_type = 0 # use 0 for char-level english. use 1 for chinese. only affects some RWKV hyperparametrs +datafile_type = 0 # use 0 for char-level english. use 1 for chinese. 
only affects some RWKV hyperparameters

 #################################### VERY IMPORTANT ####################################
-epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
-epoch_save_path = 'trained-'
+epoch_save_frequency = 10  # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
+epoch_save_path = "trained-"

-batch_size = 32 # if you see "CUDA out of memory", reduce this.
-                # if you have good GPU, increase this.
-                # use GPU-Z to find the highest value for your VRAM.
+batch_size = 32  # if you see "CUDA out of memory", reduce this.
+# if you have good GPU, increase this.
+# use GPU-Z to find the highest value for your VRAM.

-n_epoch = 100 # the 'epoch' here is actually very short (and of fixed length)
+n_epoch = 100  # the 'epoch' here is actually very short (and of fixed length)
 ########################################################################################

-model_level = 'character' # 'character' (recommended) or 'word'
+model_level = "character"  # 'character' (recommended) or 'word'

-ctx_len = 256 # context length, try 512 or 1024 if you have good GPU
-n_layer = 6 # try 12 for 100M, 24 for 300M
-n_head = 8 # try 12 for 100M, 16 for 300M
+ctx_len = 256  # context length, try 512 or 1024 if you have good GPU
+n_layer = 6  # try 12 for 100M, 24 for 300M
+n_head = 8  # try 12 for 100M, 16 for 300M
 n_embd = n_head * 64
 n_attn = n_embd
 n_ffn = n_embd

-lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
+lr_init = (
+    6e-4 if model_type == "RWKV" else 4e-4
+)  # RWKV can use higher lr. 8e-4 = 0.0008 4e-4 = 0.0004
 lr_final = 4e-5
-betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
+betas = (0.9, 0.99) if model_type == "RWKV" else (0.9, 0.99)
 eps = 4e-9
-weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
+weight_decay = (
+    0 if model_type == "RWKV" else 0.01
+)  # wd is not useful when we have enough data

-epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
+epoch_length_fixed = (
+    10000  # make an 'epoch' very short, so we can see the training progress
+)

 ######## special hyperparameters for RWKV model ########
-rwkv_emb_scale = 0.4 # scale of initial embedding. 0.4 is a good choice
-rwkv_tiny_attn = 0#64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
-rwkv_tiny_head = 1 # 1 is good enough. 8 is slow
-# n_side_proj = 512 # extra 'side projection', quite useful for BPE models
+rwkv_emb_scale = 0.4  # scale of initial embedding. 0.4 is a good choice
+rwkv_tiny_attn = 0  # 64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english
+rwkv_tiny_head = 1  # 1 is good enough. 8 is slow
+# n_side_proj = 512 # extra 'side projection', quite useful for BPE models

 ########################################################################################################
 # Load data
 ########################################################################################################

-print('loading data... ' + datafile)
+print("loading data... 
" + datafile) + class Dataset(Dataset): def __init__(self, data, model_level, ctx_len): - print('building token list...', end=' ') - if model_level == 'word': + print("building token list...", end=" ") + if model_level == "word": import re - data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data) - data = re.sub(' +',' ',data) - print('splitting token...') - data = data.lower().split(' ') + + data = re.sub( + r"(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)", + r" \g<0> ", + data, + ) + data = re.sub(" +", " ", data) + print("splitting token...") + data = data.lower().split(" ") unique = sorted(list(set(data))) # print() # for u in unique: @@ -96,13 +112,13 @@ def __init__(self, data, model_level, ctx_len): for u in unique: xxObj[xx] = u xx += 1 - with open('vocab.json', "w", encoding="utf-16") as vocab_file: + with open("vocab.json", "w", encoding="utf-16") as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) data_size, vocab_size = len(data), len(unique) - print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size)) - self.stoi = { ch:i for i,ch in enumerate(unique) } - self.itos = { i:ch for i,ch in enumerate(unique) } + print("data has %d %ss, %d unique." % (data_size, model_level, vocab_size)) + self.stoi = {ch: i for i, ch in enumerate(unique)} + self.itos = {i: ch for i, ch in enumerate(unique)} self.ctx_len = ctx_len self.vocab_size = vocab_size self.data = data @@ -111,32 +127,94 @@ def __len__(self): return epoch_length_fixed def __getitem__(self, idx): - i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # cheat: pick a random spot in dataset - chunk = self.data[i:i+self.ctx_len+1] + i = np.random.randint( + 0, len(self.data) - (self.ctx_len + 1) + ) # cheat: pick a random spot in dataset + chunk = self.data[i : i + self.ctx_len + 1] dix = [self.stoi[s] for s in chunk] x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) return x, y -train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_len) + +train_dataset = Dataset( + open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_len +) ######################################################################################################## # Train model ######################################################################################################## -model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, - rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head, - n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn)) +model = GPT( + GPTConfig( + train_dataset.vocab_size, + train_dataset.ctx_len, + model_type=model_type, + rwkv_emb_scale=rwkv_emb_scale, + rwkv_tiny_attn=rwkv_tiny_attn, + rwkv_tiny_head=rwkv_tiny_head, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + n_attn=n_attn, + n_ffn=n_ffn, + ) +) # load a trained model # model.load_state_dict(torch.load('trained-xxx.pth').state_dict()) -print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn) -tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay, - learning_rate=lr_init, lr_decay=True, 
lr_final=lr_final, betas=betas, eps=eps, - warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) +print( + "model", + model_type, + "epoch", + n_epoch, + "batchsz", + batch_size, + "betas", + betas, + "eps", + eps, + "wd", + weight_decay, + "ctx", + ctx_len, + "layer", + n_layer, + "head", + n_head, + "embd", + n_embd, + "attn", + n_attn, + "ffn", + n_ffn, +) +tconf = TrainerConfig( + model_type=model_type, + max_epochs=n_epoch, + batch_size=batch_size, + weight_decay=weight_decay, + learning_rate=lr_init, + lr_decay=True, + lr_final=lr_final, + betas=betas, + eps=eps, + warmup_tokens=0, + final_tokens=n_epoch * len(train_dataset) * ctx_len, + num_workers=0, + epoch_save_frequency=epoch_save_frequency, + epoch_save_path=epoch_save_path, +) trainer = Trainer(model, train_dataset, None, tconf) trainer.train() -torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') +torch.save( + model, + "trained-" + + trainer.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + + ".pth", +) diff --git a/RWKV-v2-RNN/run.py b/RWKV-v2-RNN/run.py index a6ee6a2bc..839e8694c 100644 --- a/RWKV-v2-RNN/run.py +++ b/RWKV-v2-RNN/run.py @@ -12,6 +12,7 @@ from torch.nn import functional as F from src.utils import TOKENIZER, Dataset from src.model_run import RWKV_RNN + torch.backends.cudnn.benchmark = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True @@ -22,11 +23,11 @@ ctx_len = 1024 n_layer = 6 n_embd = 512 -model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' +model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' # your trained model -MODEL_NAME = 'trained-31' -WORD_NAME = 'vocab' # the .json vocab (generated by train.py +MODEL_NAME = "trained-31" +WORD_NAME = "vocab" # the .json vocab (generated by train.py # ########## Uncomment these to test my 27M params enwik8 model ########## # MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13' @@ -36,14 +37,14 @@ # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <-- # --> all unknown tokens in your context will be denoted by it <-- -UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity +UNKNOWN_CHAR = " " # here we just set it to [space] for simplicity -RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda' +RUN_DEVICE = "cpu" # 'cpu' (already very fast) or 'cuda' DEBUG_DEBUG = False # True False - show softmax output ### Step 2: set context ################################################################################ -context = "\nIn the" # ==> this is your prompt +context = "\nIn the" # ==> this is your prompt NUM_TRIALS = 999 LENGTH_PER_TRIAL = 500 @@ -54,56 +55,60 @@ ######################################################################################################## -print(f'Loading {MODEL_NAME}...') +print(f"Loading {MODEL_NAME}...") model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) ######################################################################################################## -if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals(): - print('Evaluating on ' + EVAL_DATA + ' ...') +if "EVAL_DATA" in vars() or "EVAL_DATA" in globals(): + print("Evaluating on " + EVAL_DATA + " ...") - data = open(EVAL_DATA, "r", encoding='utf-8').read() + data = open(EVAL_DATA, "r", encoding="utf-8").read() loss_table 
= np.zeros(ctx_len) N_SAMPLE = 1000 for iii in range(N_SAMPLE): - pos = np.random.randint(0, len(data) - ctx_len-1) - context = data[pos:pos+ctx_len+1] + pos = np.random.randint(0, len(data) - ctx_len - 1) + context = data[pos : pos + ctx_len + 1] ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] model.clear() - for i in range(1, ctx_len+1): + for i in range(1, ctx_len + 1): x = ctx[:i] out = model.run(x) prob = F.softmax(torch.tensor(out), dim=-1) - loss_table[i-1] += -math.log(prob[ctx[i]]) + loss_table[i - 1] += -math.log(prob[ctx[i]]) - print(f'Tested {iii+1} samples: avg_loss over ctx_len =', - np.mean(loss_table) / (iii+1)) + print( + f"Tested {iii+1} samples: avg_loss over ctx_len =", + np.mean(loss_table) / (iii + 1), + ) exit(0) ######################################################################################################## context = tokenizer.refine_context(context) -print('\nYour prompt has ' + str(len(context)) + ' tokens.') -print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n') +print("\nYour prompt has " + str(len(context)) + " tokens.") +print( + "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n" +) for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): t_begin = time.time_ns() src_len = len(context) ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] - print(('-' * 30) + context, end='') + print(("-" * 30) + context, end="") model.clear() if TRIAL == 0: init_state = types.SimpleNamespace() for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] if i == src_len - 1: init_state.out = model.run(x) else: @@ -113,7 +118,7 @@ model.load(init_state) for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): - x = ctx[:i+1] + x = ctx[: i + 1] x = x[-ctx_len:] if i == src_len: @@ -121,13 +126,18 @@ else: out = model.run(x) if DEBUG_DEBUG: - print('model', np.array(x), '==>', np.array( - out), np.max(out), np.min(out)) - - char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, - top_p_usual=top_p, top_p_newline=top_p_newline) + print("model", np.array(x), "==>", np.array(out), np.max(out), np.min(out)) + + char = tokenizer.sample_logits( + out, + x, + ctx_len, + temperature=TEMPERATURE, + top_p_usual=top_p, + top_p_newline=top_p_newline, + ) char = char.item() - print(tokenizer.itos[int(char)], end='', flush=True) + print(tokenizer.itos[int(char)], end="", flush=True) ctx += [char] t_end = time.time_ns() - print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') + print("\n----------", round((t_end - t_begin) / (10**9), 2), end="s ") diff --git a/RWKV-v2-RNN/src/model.py b/RWKV-v2-RNN/src/model.py index afa67ed27..ff76d420e 100644 --- a/RWKV-v2-RNN/src/model.py +++ b/RWKV-v2-RNN/src/model.py @@ -9,18 +9,29 @@ import torch import torch.nn as nn from torch.nn import functional as F + logger = logging.getLogger(__name__) ######################################################################################################## # CUDA Kernel ######################################################################################################## -T_MAX = 1024 # increase this if your ctx_len > 1024 -B_GROUP_FORWARD = 4 # set to 8 for best performance +T_MAX = 1024 # increase this if your ctx_len > 1024 +B_GROUP_FORWARD = 4 # set to 8 for best performance 
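
For readers skimming this hunk: the timex kernel loaded just below computes, for every channel independently, a causal weighted sum of the keys over the last T positions. As a reading aid, here is a minimal pure-PyTorch sketch of that operation, assuming (from the call signature w: (C, T), k: (B, C, T) and the eps argument) that it matches a depthwise causal 1-D convolution; the helper name is mine, an illustration rather than code from this commit:

import torch
import torch.nn.functional as F

def timex_reference(w: torch.Tensor, k: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # w: (C, T) per-channel time weights, current position last
    # k: (B, C, T) keys
    # out[b, c, t] = sum over u <= t of w[c, T - 1 - (t - u)] * k[b, c, u], plus eps
    B, C, T = k.shape
    assert w.shape == (C, T)
    k_pad = F.pad(k, (T - 1, 0))  # left-pad in time so the sum never sees the future
    return F.conv1d(k_pad, w.unsqueeze(1), groups=C) + eps  # depthwise causal conv

A grouped conv1d like this should be numerically equivalent but is much slower for long sequences, which is presumably why the repo ships a hand-written kernel with the Tmax / BF / BB compile-time constants above.
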
B_GROUP_BACKWARD = 2 # set to 2 for best performance -timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"], - verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}']) +timex_cuda = load( + name="timex", + sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"], + verbose=True, + extra_cuda_cflags=[ + "--use_fast_math", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + f"-DBF={B_GROUP_FORWARD}", + f"-DBB={B_GROUP_BACKWARD}", + ], +) class TimeX(torch.autograd.Function): @@ -29,27 +40,40 @@ def forward(ctx, w, k, B, C, T, eps): ctx.B = B ctx.C = C ctx.T = T - assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0 + assert ( + ctx.T % 4 == 0 + and ctx.T <= T_MAX + and ctx.B % B_GROUP_FORWARD == 0 + and ctx.B % B_GROUP_BACKWARD == 0 + ) w = w.contiguous() k = k.contiguous() ctx.save_for_backward(w, k) - wk = torch.empty((B, C, T), device='cuda', - memory_format=torch.contiguous_format) + wk = torch.empty( + (B, C, T), device="cuda", memory_format=torch.contiguous_format + ) timex_cuda.forward(w, k, wk, eps, B, C, T) return wk @staticmethod def backward(ctx, gwk): - assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0 + assert ( + ctx.T % 4 == 0 + and ctx.T <= T_MAX + and ctx.B % B_GROUP_FORWARD == 0 + and ctx.B % B_GROUP_BACKWARD == 0 + ) w, k = ctx.saved_tensors - gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda', - memory_format=torch.contiguous_format) - gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda', - memory_format=torch.contiguous_format) - timex_cuda.backward(w, k, gwk.contiguous(), gw, - gk, ctx.B, ctx.C, ctx.T) + gw = torch.empty( + (ctx.B, ctx.C, ctx.T), device="cuda", memory_format=torch.contiguous_format + ) + gk = torch.empty( + (ctx.B, ctx.C, ctx.T), device="cuda", memory_format=torch.contiguous_format + ) + timex_cuda.backward(w, k, gwk.contiguous(), gw, gk, ctx.B, ctx.C, ctx.T) return (gw.sum(dim=0), gk, None, None, None, None) + ######################################################################################################## # RWKV: RWKV Time-mix + RWKV Channel-mix ######################################################################################################## @@ -60,13 +84,18 @@ def backward(ctx, gwk): RWKV_HEAD_QK_DIM = 256 -def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module +def RWKV_Init( + module, config +): # fancy initialization of all lin & emb layer in the module for m in module.modules(): if not isinstance(m, (nn.Linear, nn.Embedding)): continue with torch.no_grad(): - name = '[unknown weight]' - for name, parameter in module.named_parameters(): # find the name of the weight + name = "[unknown weight]" + for ( + name, + parameter, + ) in module.named_parameters(): # find the name of the weight if id(m.weight) == id(parameter): break @@ -76,7 +105,9 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in if isinstance(m, nn.Embedding): gain = math.sqrt(max(shape[0], shape[1])) - if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? + if ( + shape[0] == config.vocab_size and shape[1] == config.n_embd + ): # token emb? 
scale = 1e-4 else: scale = 0 @@ -86,10 +117,12 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in m.bias.data.zero_() if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? + if ( + shape[0] == config.vocab_size and shape[1] == config.n_embd + ): # final projection? scale = 0.5 - if hasattr(m, 'scale_init'): + if hasattr(m, "scale_init"): scale = m.scale_init # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name) @@ -124,21 +157,26 @@ def __init__(self, config, layer_id): decay_speed = torch.ones(attn_sz, 1) first_sa_layer_id = 1 for h in range(attn_sz): - f1 = f1_begin + (layer_id-first_sa_layer_id) / \ - (config.n_layer-1-first_sa_layer_id) * (f1_end - f1_begin) - f2 = f2_begin + (layer_id-first_sa_layer_id) / \ - (config.n_layer-1-first_sa_layer_id) * (f2_end - f2_begin) + f1 = f1_begin + (layer_id - first_sa_layer_id) / ( + config.n_layer - 1 - first_sa_layer_id + ) * (f1_end - f1_begin) + f2 = f2_begin + (layer_id - first_sa_layer_id) / ( + config.n_layer - 1 - first_sa_layer_id + ) * (f2_end - f2_begin) if layer_id == first_sa_layer_id: f1 += 0.5 - if layer_id == config.n_layer-2: + if layer_id == config.n_layer - 2: f2 = 0.4 - if layer_id == config.n_layer-1: + if layer_id == config.n_layer - 1: f2 = 0.37 - decay_speed[h][0] = math.pow(f2, h / (attn_sz-1) * 7) * f1 - self.time_decay = nn.Parameter(torch.log(decay_speed)) # will use exp(self.time_decay) to ensure time_decay > 0 + decay_speed[h][0] = math.pow(f2, h / (attn_sz - 1) * 7) * f1 + self.time_decay = nn.Parameter( + torch.log(decay_speed) + ) # will use exp(self.time_decay) to ensure time_decay > 0 self.time_curve = torch.tensor( - [-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0) - self.time_curve = self.time_curve.to('cuda') + [-(config.ctx_len - 2 - i) for i in range(config.ctx_len - 1)] + ).unsqueeze(0) + self.time_curve = self.time_curve.to("cuda") self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3)) ############################################################################# @@ -174,7 +212,8 @@ def forward(self, x): kv = k * v self.time_w = torch.cat( - [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1) + [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1 + ) w = torch.exp(self.time_w) wkv = TimeX.apply(w, kv, B, C, T, 0) @@ -217,6 +256,7 @@ def forward(self, x): rkv = torch.sigmoid(self.receptance(x)) * kv return rkv + ######################################################################################################## # The GPT Model with our blocks ######################################################################################################## @@ -239,8 +279,8 @@ def __init__(self, config, layer_id): self.ln1 = nn.LayerNorm(config.n_embd) self.ln2 = nn.LayerNorm(config.n_embd) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': - self.ffnPre = RWKV_ChannelMix(config, layer_id+1000) + if self.layer_id == 0 and self.config.model_type == "RWKV-ffnPre": + self.ffnPre = RWKV_ChannelMix(config, layer_id + 1000) else: self.att = RWKV_TimeMix(config, layer_id) @@ -248,7 +288,7 @@ def __init__(self, config, layer_id): def forward(self, x): x = self.ln1(x) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + if self.layer_id == 0 and self.config.model_type == "RWKV-ffnPre": x = x + self.ffnPre(x) # better in some cases else: x = x + 
self.att(x) @@ -265,8 +305,7 @@ def __init__(self, config): self.emb = nn.Embedding(config.vocab_size, config.n_embd) - self.blocks = nn.Sequential(*[Block(config, i) - for i in range(config.n_layer)]) + self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)]) self.ln_out = nn.LayerNorm(config.n_embd) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) @@ -275,15 +314,17 @@ def __init__(self, config): self.head_q.scale_init = 0 self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(config.ctx_len, config.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)) + ) self.ctx_len = config.ctx_len RWKV_Init(self, config) - logger.info("number of parameters: %e", sum(p.numel() - for p in self.parameters())) + logger.info( + "number of parameters: %e", sum(p.numel() for p in self.parameters()) + ) def get_ctx_len(self): return self.ctx_len @@ -303,24 +344,34 @@ def configure_optimizers(self, train_config): for mn, m in self.named_modules(): # here we disable weight_decay for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name no_decay.add(fpn) param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len( - inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" 
% ( + str(param_dict.keys() - union_params), + ) optim_groups = [ - {"params": [param_dict[pn] - for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, ] optimizer = torch.optim.Adam( - optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps) + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas, + eps=train_config.eps, + ) return optimizer diff --git a/RWKV-v2-RNN/src/model_run.py b/RWKV-v2-RNN/src/model_run.py index ecb459e57..8ac44cce2 100644 --- a/RWKV-v2-RNN/src/model_run.py +++ b/RWKV-v2-RNN/src/model_run.py @@ -7,10 +7,10 @@ RWKV_K_EPS = 1e-16 RWKV_HEAD_QK_DIM = 256 -DEBUG_TIME = False # True False - show trained time-coeffs +DEBUG_TIME = False # True False - show trained time-coeffs -class RWKV_RNN(): +class RWKV_RNN: def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): self.RUN_DEVICE = RUN_DEVICE self.model_type = model_type @@ -20,19 +20,18 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) self.w = types.SimpleNamespace() - w = torch.load(MODEL_NAME + '.pth', - map_location=torch.device(RUN_DEVICE)) + w = torch.load(MODEL_NAME + ".pth", map_location=torch.device(RUN_DEVICE)) for x in w.keys(): - if '.time_' in x: + if ".time_" in x: w[x] = w[x].squeeze() - if '.time_decay' in x: + if ".time_decay" in x: w[x] = torch.exp(-torch.exp(w[x])) - if '.time_first' in x: + if ".time_first" in x: w[x] = torch.exp(w[x]) - if DEBUG_TIME and '.time_' in x: + if DEBUG_TIME and ".time_" in x: print(x, w[x].squeeze().cpu().numpy()) - xx = x.split('.') + xx = x.split(".") here = self.w for i in range(len(xx)): if xx[i].isdigit(): @@ -44,7 +43,7 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) if i == len(xx) - 1: setattr(here, xx[i], w[x]) elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): + if xx[i + 1].isdigit(): setattr(here, xx[i], {}) else: setattr(here, xx[i], types.SimpleNamespace()) @@ -114,22 +113,21 @@ def run(self, ctx): for i in range(self.n_layer): x = self.LN(x, w.blocks[i].ln1) - if i == 0 and self.model_type == 'RWKV-ffnPre': - x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}') + if i == 0 and self.model_type == "RWKV-ffnPre": + x = x + self.FF(x, w.blocks[i].ffnPre, f"ffnPre.{i}") else: - x = x + self.SA(x, w.blocks[i].att, f'att.{i}') + x = x + self.SA(x, w.blocks[i].att, f"att.{i}") x = self.LN(x, w.blocks[i].ln2) - x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}') + x = x + self.FF(x, w.blocks[i].ffn, f"ffn.{i}") x = self.LN(x, w.ln_out) if self.hk == None: self.hk = (w.head_k.weight @ x).unsqueeze(0) else: - self.hk = torch.cat( - [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) + self.hk = torch.cat([self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) if self.hk.shape[0] > self.ctx_len: - self.hk = self.hk[-self.ctx_len:, :] + self.hk = self.hk[-self.ctx_len :, :] q = w.head_q.weight @ x diff --git a/RWKV-v2-RNN/src/trainer.py b/RWKV-v2-RNN/src/trainer.py index 19ea1d8e2..1afa9d31b 100644 --- a/RWKV-v2-RNN/src/trainer.py +++ b/RWKV-v2-RNN/src/trainer.py @@ -38,7 +38,7 @@ class TrainerConfig: warmup_tokens = 0 final_tokens = 0 epoch_save_frequency = 0 - epoch_save_path = 'trained-' + epoch_save_path = "trained-" num_workers = 0 # for DataLoader def __init__(self, **kwargs): @@ -47,7 +47,6 @@ def __init__(self, **kwargs): class Trainer: - def __init__(self, model, train_dataset, test_dataset, config): self.model = 
model self.train_dataset = train_dataset @@ -56,23 +55,37 @@ def __init__(self, model, train_dataset, test_dataset, config): self.avg_loss = -1 self.steps = 0 - if 'wandb' in sys.modules: + if "wandb" in sys.modules: cfg = model.config for k in config.__dict__: setattr(cfg, k, config.__dict__[k]) # combine cfg - wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + - datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) - - self.device = 'cpu' + wandb.init( + project="RWKV-LM", + name=self.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), + config=cfg, + save_code=False, + ) + + self.device = "cpu" if torch.cuda.is_available(): # take over whatever gpus are on the system self.device = torch.cuda.current_device() def get_run_name(self): - raw_model = self.model.module if hasattr( - self.model, "module") else self.model + raw_model = self.model.module if hasattr(self.model, "module") else self.model cfg = raw_model.config - run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \ - cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) + run_name = ( + str(cfg.vocab_size) + + "-" + + str(cfg.ctx_len) + + "-" + + cfg.model_type + + "-" + + str(cfg.n_layer) + + "-" + + str(cfg.n_embd) + ) return run_name def train(self): @@ -81,21 +94,35 @@ def train(self): optimizer = raw_model.configure_optimizers(config) def run_epoch(split): - is_train = split == 'train' + is_train = split == "train" model.train(is_train) data = self.train_dataset if is_train else self.test_dataset if config.num_workers > 0: - loader = DataLoader(data, shuffle=False, pin_memory=True, - batch_size=config.batch_size, - num_workers=config.num_workers) + loader = DataLoader( + data, + shuffle=False, + pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) else: - loader = DataLoader(data, shuffle=False, - batch_size=config.batch_size, - num_workers=config.num_workers) - - pbar = tqdm(enumerate(loader), total=len( - loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) + loader = DataLoader( + data, + shuffle=False, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + pbar = ( + tqdm( + enumerate(loader), + total=len(loader), + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", + ) + if is_train + else enumerate(loader) + ) for it, (x, y) in pbar: x = x.to(self.device) # place data on the correct device @@ -110,7 +137,8 @@ def run_epoch(split): if config.grad_norm_clip > 0: torch.nn.utils.clip_grad_norm_( - model.parameters(), config.grad_norm_clip) + model.parameters(), config.grad_norm_clip + ) optimizer.step() @@ -120,51 +148,67 @@ def run_epoch(split): lr_final_factor = config.lr_final / config.learning_rate if self.tokens < config.warmup_tokens: # linear warmup - lr_mult = lr_final_factor + \ - (1 - lr_final_factor) * float(self.tokens) / \ - float(config.warmup_tokens) + lr_mult = lr_final_factor + (1 - lr_final_factor) * float( + self.tokens + ) / float(config.warmup_tokens) progress = 0 else: # cosine learning rate decay - progress = float(self.tokens - config.warmup_tokens) / float( - max(1, config.final_tokens - config.warmup_tokens)) - lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / - 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1 + progress = float( + self.tokens - config.warmup_tokens + ) / float( + max(1, config.final_tokens - config.warmup_tokens) + ) + lr_mult = (0.5 + lr_final_factor / 2) + ( + 0.5 - 
lr_final_factor / 2 + ) * math.cos( + math.pi * progress + ) # better 1.0 ~ 0.1 lr = config.learning_rate * lr_mult for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr else: lr = config.learning_rate now_loss = loss.item() # report progress self.lr = lr - if 'wandb' in sys.modules: - wandb.log({"loss": now_loss}, - step=self.steps * self.config.batch_size) + if "wandb" in sys.modules: + wandb.log( + {"loss": now_loss}, step=self.steps * self.config.batch_size + ) self.steps += 1 if self.avg_loss < 0: self.avg_loss = now_loss else: factor = 1 / (it + 1) - self.avg_loss = self.avg_loss * \ - (1.0 - factor) + now_loss * factor + self.avg_loss = ( + self.avg_loss * (1.0 - factor) + now_loss * factor + ) pbar.set_description( - f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") + f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}" + ) self.tokens = 0 # counter used for learning rate decay for epoch in range(config.max_epochs): - run_epoch('train') + run_epoch("train") log_file.write( - f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n') + f"{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n" + ) log_file.flush() - if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): + if ( + self.config.epoch_save_frequency > 0 + and epoch % self.config.epoch_save_frequency == 0 + ) or (epoch == config.max_epochs - 1): # DataParallel wrappers keep raw model object in .module - raw_model = self.model.module if hasattr( - self.model, "module") else self.model - torch.save(raw_model.state_dict(), - self.config.epoch_save_path + str(epoch+1) + '.pth') + raw_model = ( + self.model.module if hasattr(self.model, "module") else self.model + ) + torch.save( + raw_model.state_dict(), + self.config.epoch_save_path + str(epoch + 1) + ".pth", + ) diff --git a/RWKV-v2-RNN/src/utils.py b/RWKV-v2-RNN/src/utils.py index 480518f07..49d0afee5 100644 --- a/RWKV-v2-RNN/src/utils.py +++ b/RWKV-v2-RNN/src/utils.py @@ -15,7 +15,7 @@ class Dataset(Dataset): def __init__(self, data, ctx_len, epoch_length_fixed): - print('building token list...', end=' ') + print("building token list...", end=" ") unique = sorted(list(set(data))) # print() # for u in unique: @@ -27,11 +27,11 @@ def __init__(self, data, ctx_len, epoch_length_fixed): for u in unique: xxObj[xx] = u xx += 1 - with open('vocab.json', "w", encoding="utf-16") as vocab_file: + with open("vocab.json", "w", encoding="utf-16") as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) data_size, vocab_size = len(data), len(unique) - print('data has %d tokens, %d unique.' % (data_size, vocab_size)) + print("data has %d tokens, %d unique." 
% (data_size, vocab_size)) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} self.ctx_len = ctx_len @@ -45,18 +45,16 @@ def __len__(self): def __getitem__(self, idx): # cheat: pick a random spot in dataset i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) - chunk = self.data[i:i+self.ctx_len+1] + chunk = self.data[i : i + self.ctx_len + 1] dix = [self.stoi[s] for s in chunk] - x = torch.tensor(dix[:-1], dtype=torch.long, - device=torch.device('cuda')) - y = torch.tensor(dix[1:], dtype=torch.long, - device=torch.device('cuda')) + x = torch.tensor(dix[:-1], dtype=torch.long, device=torch.device("cuda")) + y = torch.tensor(dix[1:], dtype=torch.long, device=torch.device("cuda")) return x, y -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: self.word_table = json.load(result_file) self.vocab_size = len(self.word_table) @@ -67,24 +65,26 @@ def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] def refine_context(self, context): - context = context.strip().split('\n') + context = context.strip().split("\n") for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): # out[self.UNKNOWN_CHAR] = -float('Inf') lastChar = int(x[-1]) probs = F.softmax(torch.tensor(out), dim=-1) - if self.itos[lastChar] == '\n': + if self.itos[lastChar] == "\n": top_p = top_p_newline else: top_p = top_p_usual diff --git a/RWKV-v2-RNN/train.py b/RWKV-v2-RNN/train.py index e46c0ac0e..65780a8c6 100644 --- a/RWKV-v2-RNN/train.py +++ b/RWKV-v2-RNN/train.py @@ -10,6 +10,7 @@ from src.utils import Dataset import torch import numpy as np + torch.backends.cudnn.benchmark = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True @@ -17,17 +18,17 @@ ### Step 1: set training data ########################################################################## datafile = "enwik8" -datafile_encoding = 'utf-8' +datafile_encoding = "utf-8" # datafile_encoding = 'utf-16le' ### Step 2: set model size ############################################################################# -ctx_len = 1024 # ===> increase T_MAX in model.py if your ctx_len > 1024 +ctx_len = 1024 # ===> increase T_MAX in model.py if your ctx_len > 1024 n_layer = 6 n_embd = 512 # 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases) -model_type = 'RWKV' +model_type = "RWKV" ### Step 3: set batch size ############################################################################# @@ -44,7 +45,7 @@ n_epoch = 500 # 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc. 
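The `sample_logits` hunk above is cut off before the nucleus filtering itself, so for orientation: after choosing `top_p` (the newline-specific threshold right after a line break, the usual one otherwise), the method keeps only the smallest set of tokens whose cumulative probability exceeds `top_p` and samples from what survives. A minimal sketch of that step — names are illustrative, not code from this diff:

```python
import numpy as np
import torch
import torch.nn.functional as F


def top_p_filter_sample(logits, top_p=0.7, temperature=1.0):
    # Softmax over the raw logits, as sample_logits does above.
    probs = F.softmax(torch.tensor(logits, dtype=torch.float), dim=-1)
    # Cutoff = probability of the last token inside the top-p mass.
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = np.cumsum(sorted_probs.numpy())
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])
    probs[probs < cutoff] = 0
    # Optional temperature reshaping of the surviving mass.
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)
    # torch.multinomial accepts unnormalized non-negative weights.
    return int(torch.multinomial(probs, num_samples=1))


# e.g. top_p_filter_sample(np.random.randn(100).tolist(), top_p=0.7)
```

Since `top_p_newline` is typically set looser than `top_p_usual` (0.9 vs 0.7 in the v4 run script later in this diff), sampling directly after a newline stays more diverse, which lets the char-level model vary how new lines begin.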
epoch_save_frequency = 30 -epoch_save_path = 'trained-' +epoch_save_path = "trained-" epoch_length_fixed = 10000 @@ -54,8 +55,11 @@ # src.utils.set_seed(42) # remember to change seed if you load a model np.set_printoptions(precision=4, suppress=True, linewidth=200) -logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) grad_norm_clip = 1.0 warmup_tokens = 0 @@ -69,30 +73,75 @@ # Load data ######################################################################################################## -print('loading data... ' + datafile) -train_dataset = Dataset(open( - datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed) +print("loading data... " + datafile) +train_dataset = Dataset( + open(datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed +) ######################################################################################################## # Train model ######################################################################################################## -if __name__ == '__main__': - - model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, - n_layer=n_layer, n_embd=n_embd)).cuda() +if __name__ == "__main__": + + model = GPT( + GPTConfig( + train_dataset.vocab_size, + train_dataset.ctx_len, + model_type=model_type, + n_layer=n_layer, + n_embd=n_embd, + ) + ).cuda() # # # load a trained model. remember to change random seed # m2 = torch.load('trained-61.pth') # model.load_state_dict(m2) - print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', - betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, ) - tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, - learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip, - warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) + print( + "model", + model_type, + "epoch", + n_epoch, + "batchsz", + batch_size, + "betas", + betas, + "eps", + eps, + "ctx", + ctx_len, + "layer", + n_layer, + "embd", + n_embd, + ) + tconf = TrainerConfig( + model_type=model_type, + max_epochs=n_epoch, + batch_size=batch_size, + learning_rate=lr_init, + lr_decay=True, + lr_final=lr_final, + betas=betas, + eps=eps, + grad_norm_clip=grad_norm_clip, + warmup_tokens=warmup_tokens, + final_tokens=n_epoch * len(train_dataset) * ctx_len, + num_workers=num_workers, + epoch_save_frequency=epoch_save_frequency, + epoch_save_path=epoch_save_path, + ) trainer = Trainer(model, train_dataset, None, tconf) trainer.train() - torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() + - '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') + torch.save( + model.state_dict(), + "trained-" + + str(n_epoch) + + "-" + + trainer.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + + ".pth", + ) diff --git a/RWKV-v3/run.py b/RWKV-v3/run.py index c6862a540..0e2ecba75 100644 --- a/RWKV-v3/run.py +++ b/RWKV-v3/run.py @@ -11,6 +11,7 @@ from torch.nn import functional as F from src.utils import TOKENIZER, Dataset from src.model_run import RWKV_RNN + 
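The `TrainerConfig` assembled in the v2 `train.py` above hands `warmup_tokens` and `final_tokens` to the token-based learning-rate schedule reformatted in `trainer.py` earlier: linear warmup, then cosine decay. A self-contained sketch of that curve; the defaults are illustrative:

```python
import math


def lr_at(tokens, lr_init=8e-4, lr_final=1e-5, warmup_tokens=0, final_tokens=1):
    """Sketch of the v2 trainer schedule: linear warmup, then cosine decay."""
    factor = lr_final / lr_init
    if tokens < warmup_tokens:
        # Ramp linearly from lr_final up to lr_init across the warmup tokens.
        mult = factor + (1 - factor) * tokens / warmup_tokens
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        # Cosine from 1.0 down to lr_final/lr_init (the "better 1.0 ~ 0.1" comment).
        mult = (0.5 + factor / 2) + (0.5 - factor / 2) * math.cos(math.pi * progress)
    return lr_init * mult
```

At `progress = 0` the multiplier is 1.0 and at `progress = 1` it reaches `lr_final / lr_init`. Note that the v3 trainer further below swaps the cosine branch for an exponential decay, `lr_mult = exp(log(lr_final_factor) * progress)`, while keeping the same warmup branch.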
torch.backends.cudnn.benchmark = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True @@ -21,22 +22,22 @@ ctx_len = 1024 n_layer = 6 n_embd = 512 -model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' +model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' # your trained model -MODEL_NAME = 'trained-1' -WORD_NAME = 'vocab' # the .json vocab (generated by train.py +MODEL_NAME = "trained-1" +WORD_NAME = "vocab" # the .json vocab (generated by train.py # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <-- # --> all unknown tokens in your context will be denoted by it <-- -UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity +UNKNOWN_CHAR = " " # here we just set it to [space] for simplicity -RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda' +RUN_DEVICE = "cpu" # 'cpu' (already very fast) or 'cuda' DEBUG_DEBUG = False # True False - show softmax output ### Step 2: set context ################################################################################ -context = "\nIn the" # ==> this is your prompt +context = "\nIn the" # ==> this is your prompt NUM_TRIALS = 999 LENGTH_PER_TRIAL = 500 @@ -47,28 +48,30 @@ ######################################################################################################## -print(f'Loading {MODEL_NAME}...') +print(f"Loading {MODEL_NAME}...") model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) ######################################################################################################## context = tokenizer.refine_context(context) -print('\nYour prompt has ' + str(len(context)) + ' tokens.') -print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n') +print("\nYour prompt has " + str(len(context)) + " tokens.") +print( + "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. 
<--\n" +) for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): t_begin = time.time_ns() src_len = len(context) ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] - print(('-' * 30) + context, end='') + print(("-" * 30) + context, end="") model.clear() if TRIAL == 0: init_state = types.SimpleNamespace() for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] if i == src_len - 1: init_state.out = model.run(x) else: @@ -78,7 +81,7 @@ model.load(init_state) for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): - x = ctx[:i+1] + x = ctx[: i + 1] x = x[-ctx_len:] if i == src_len: @@ -86,13 +89,18 @@ else: out = model.run(x) if DEBUG_DEBUG: - print('model', np.array(x), '==>', np.array( - out), np.max(out), np.min(out)) - - char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, - top_p_usual=top_p, top_p_newline=top_p_newline) + print("model", np.array(x), "==>", np.array(out), np.max(out), np.min(out)) + + char = tokenizer.sample_logits( + out, + x, + ctx_len, + temperature=TEMPERATURE, + top_p_usual=top_p, + top_p_newline=top_p_newline, + ) char = char.item() - print(tokenizer.itos[int(char)], end='', flush=True) + print(tokenizer.itos[int(char)], end="", flush=True) ctx += [char] t_end = time.time_ns() - print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') + print("\n----------", round((t_end - t_begin) / (10**9), 2), end="s ") diff --git a/RWKV-v3/src/model.py b/RWKV-v3/src/model.py index 4275aa337..f1fca462b 100644 --- a/RWKV-v3/src/model.py +++ b/RWKV-v3/src/model.py @@ -9,23 +9,36 @@ import torch import torch.nn as nn from torch.nn import functional as F + logger = logging.getLogger(__name__) RWKV_K_CLAMP = 60 # e^60 = 1e26 RWKV_K_EPS = 1e-8 RWKV_HEAD_QK_DIM = 256 -print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') +print( + f"\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n" +) ######################################################################################################## # CUDA Kernel ######################################################################################################## -T_MAX = 1024 # increase this if your ctx_len > 1024 -B_GROUP_FORWARD = 4 # set to 8 for best performance +T_MAX = 1024 # increase this if your ctx_len > 1024 +B_GROUP_FORWARD = 4 # set to 8 for best performance B_GROUP_BACKWARD = 2 # set to 2 for best performance (sometimes 8 is faster) -timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"], - verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}']) +timex_cuda = load( + name="timex", + sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"], + verbose=True, + extra_cuda_cflags=[ + "--use_fast_math", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + f"-DBF={B_GROUP_FORWARD}", + f"-DBB={B_GROUP_BACKWARD}", + ], +) class TimeX(torch.autograd.Function): @@ -34,38 +47,57 @@ def forward(ctx, w, k, B, C, T, eps): ctx.B = B ctx.C = C ctx.T = T - assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0 + assert ( + ctx.T % 4 == 0 + and ctx.T <= T_MAX + and ctx.B % B_GROUP_FORWARD == 0 + and ctx.B % B_GROUP_BACKWARD == 0 + ) w = w.contiguous() k = k.contiguous() ctx.save_for_backward(w, k) - wk = torch.empty((B, C, T), device='cuda', - memory_format=torch.contiguous_format) + wk = torch.empty( 
+ (B, C, T), device="cuda", memory_format=torch.contiguous_format + ) timex_cuda.forward(w, k, wk, eps, B, C, T) return wk @staticmethod def backward(ctx, gwk): - assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0 + assert ( + ctx.T % 4 == 0 + and ctx.T <= T_MAX + and ctx.B % B_GROUP_FORWARD == 0 + and ctx.B % B_GROUP_BACKWARD == 0 + ) w, k = ctx.saved_tensors - gw = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda', - memory_format=torch.contiguous_format) - gk = torch.empty((ctx.B, ctx.C, ctx.T), device='cuda', - memory_format=torch.contiguous_format) - timex_cuda.backward(w, k, gwk.contiguous(), gw, - gk, ctx.B, ctx.C, ctx.T) + gw = torch.empty( + (ctx.B, ctx.C, ctx.T), device="cuda", memory_format=torch.contiguous_format + ) + gk = torch.empty( + (ctx.B, ctx.C, ctx.T), device="cuda", memory_format=torch.contiguous_format + ) + timex_cuda.backward(w, k, gwk.contiguous(), gw, gk, ctx.B, ctx.C, ctx.T) return (gw.sum(dim=0), gk, None, None, None, None) + ######################################################################################################## # RWKV: RWKV Time-mix + RWKV Channel-mix ######################################################################################################## -def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in the module + +def RWKV_Init( + module, config +): # fancy initialization of all lin & emb layer in the module for m in module.modules(): if not isinstance(m, (nn.Linear, nn.Embedding)): continue with torch.no_grad(): - name = '[unknown weight]' - for name, parameter in module.named_parameters(): # find the name of the weight + name = "[unknown weight]" + for ( + name, + parameter, + ) in module.named_parameters(): # find the name of the weight if id(m.weight) == id(parameter): break @@ -75,7 +107,9 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in if isinstance(m, nn.Embedding): gain = math.sqrt(max(shape[0], shape[1])) - if shape[0] == config.vocab_size and shape[1] == config.n_embd: # token emb? + if ( + shape[0] == config.vocab_size and shape[1] == config.n_embd + ): # token emb? scale = 1e-4 else: scale = 0 @@ -85,10 +119,12 @@ def RWKV_Init(module, config): # fancy initialization of all lin & emb layer in m.bias.data.zero_() if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - if shape[0] == config.vocab_size and shape[1] == config.n_embd: # final projection? + if ( + shape[0] == config.vocab_size and shape[1] == config.n_embd + ): # final projection? 
scale = 0.5 - if hasattr(m, 'scale_init'): + if hasattr(m, "scale_init"): scale = m.scale_init # print(str(shape[0]).ljust(5), str(shape[1]).ljust(5), f'{round(scale,2):g}'.ljust(4), name) @@ -114,32 +150,41 @@ def __init__(self, config, layer_id): attn_sz = config.n_embd - with torch.no_grad(): # fancy init - self.time_curve = torch.tensor([-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0) - self.time_curve = self.time_curve.to('cuda') + with torch.no_grad(): # fancy init + self.time_curve = torch.tensor( + [-(config.ctx_len - 2 - i) for i in range(config.ctx_len - 1)] + ).unsqueeze(0) + self.time_curve = self.time_curve.to("cuda") + + ratio_0_to_1 = layer_id / (config.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0 - ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1 - ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 - # fancy time_decay decay_speed = torch.ones(attn_sz, 1) for h in range(attn_sz): - decay_speed[h][0] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1) + decay_speed[h][0] = -5 + 8 * (h / (attn_sz - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) self.time_decay = nn.Parameter(decay_speed) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # fancy time_first - zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5).unsqueeze(1) - self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3) + zigzag) - + zigzag = ( + torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5 + ).unsqueeze(1) + self.time_first = nn.Parameter( + torch.ones(attn_sz, 1) * math.log(0.3) + zigzag + ) + # fancy time_mix x = torch.ones(1, 1, config.n_embd) for i in range(config.n_embd): x[0, 0, i] = i / config.n_embd self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_v = nn.Parameter( + torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0)) - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) @@ -154,10 +199,10 @@ def __init__(self, config, layer_id): self.output.scale_init = 0 def forward(self, x): - B, T, C = x.size() # x = (Batch,Time,Channel) + B, T, C = x.size() # x = (Batch,Time,Channel) # Mix x with the previous timestep to produce xk, xv, xr - xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1)) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -167,14 +212,15 @@ def forward(self, x): v = self.value(xv).transpose(-1, -2) r = self.receptance(xr) - # RWKV_K_CLAMP can be removed if the CUDA kernel substracts the correct k_max for each k (I will do this later) - k = torch.clamp(k, max=RWKV_K_CLAMP) # clamp k to avoid overflow + # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later) + k = torch.clamp(k, max=RWKV_K_CLAMP) # clamp k to avoid overflow k = torch.exp(k) kv = k * v # Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)] self.time_w = torch.cat( - [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1) + [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1 + ) w = torch.exp(self.time_w) # Use W 
to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero @@ -194,8 +240,8 @@ def __init__(self, config, layer_id): self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - with torch.no_grad(): # fancy init of time_mix - ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0 x = torch.ones(1, 1, config.n_embd) for i in range(config.n_embd): @@ -224,6 +270,7 @@ def forward(self, x): rkv = torch.sigmoid(self.receptance(xr)) * kv return rkv + ######################################################################################################## # The GPT Model with our blocks ######################################################################################################## @@ -249,8 +296,8 @@ def __init__(self, config, layer_id): if self.layer_id == 0: self.ln0 = nn.LayerNorm(config.n_embd) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': - self.ffnPre = RWKV_ChannelMix(config, layer_id+1000) + if self.layer_id == 0 and self.config.model_type == "RWKV-ffnPre": + self.ffnPre = RWKV_ChannelMix(config, layer_id + 1000) else: self.att = RWKV_TimeMix(config, layer_id) @@ -258,8 +305,8 @@ def __init__(self, config, layer_id): def forward(self, x): if self.layer_id == 0: - x = self.ln0(x) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + x = self.ln0(x) + if self.layer_id == 0 and self.config.model_type == "RWKV-ffnPre": x = x + self.ffnPre(self.ln1(x)) # better in some cases else: x = x + self.att(self.ln1(x)) @@ -275,8 +322,7 @@ def __init__(self, config): self.emb = nn.Embedding(config.vocab_size, config.n_embd) - self.blocks = nn.Sequential(*[Block(config, i) - for i in range(config.n_layer)]) + self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)]) self.ln_out = nn.LayerNorm(config.n_embd) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) @@ -286,15 +332,17 @@ def __init__(self, config): self.head_q.scale_init = 0 self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(config.ctx_len, config.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)) + ) self.ctx_len = config.ctx_len RWKV_Init(self, config) - logger.info("number of parameters: %e", sum(p.numel() - for p in self.parameters())) + logger.info( + "number of parameters: %e", sum(p.numel() for p in self.parameters()) + ) def get_ctx_len(self): return self.ctx_len @@ -314,24 +362,34 @@ def configure_optimizers(self, train_config): for mn, m in self.named_modules(): # here we disable weight_decay for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name no_decay.add(fpn) param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len( - inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" 
% (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" % ( + str(param_dict.keys() - union_params), + ) optim_groups = [ - {"params": [param_dict[pn] - for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, ] optimizer = torch.optim.Adam( - optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps) + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas, + eps=train_config.eps, + ) return optimizer diff --git a/RWKV-v3/src/model_run.py b/RWKV-v3/src/model_run.py index 44ea131ce..75fc66550 100644 --- a/RWKV-v3/src/model_run.py +++ b/RWKV-v3/src/model_run.py @@ -12,20 +12,23 @@ RWKV_K_CLAMP = 60 RWKV_K_EPS = 1e-8 RWKV_HEAD_QK_DIM = 256 -print(f'\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') +print( + f"\nRWKV_K_CLAMP {RWKV_K_CLAMP} RWKV_K_EPS {RWKV_K_EPS} RWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n" +) -DEBUG_TIME = False # True False - show trained time-coeffs +DEBUG_TIME = False # True False - show trained time-coeffs ############################################################################################################ RWKV_CFG = types.SimpleNamespace() + class RWKV_ChannelMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) @@ -42,22 +45,25 @@ def forward(self, x): k = self.key(xk) k = torch.square(torch.relu(k)) kv = self.value(k) - + rkv = torch.sigmoid(self.receptance(xr)) * kv return rkv + class RWKV_TimeMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1)) - self.time_curve = torch.tensor([-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len-1)]).unsqueeze(0) + self.time_curve = torch.tensor( + [-(RWKV_CFG.ctx_len - 2 - i) for i in range(RWKV_CFG.ctx_len - 1)] + ).unsqueeze(0) self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd, 1) * math.log(0.3)) - - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) - self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + self.time_mix_v = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) @@ -82,18 +88,25 @@ def forward(self, x): kv = k * v - self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1) + self.time_w = torch.cat( + [ + torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), + self.time_first, + ], + dim=-1, + ) w = torch.exp(self.time_w) - - w = w[:,-T:].unsqueeze(1) - wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C) - wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS + + w = w[:, -T:].unsqueeze(1) + wkv = 
F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C) + wk = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(k), w, groups=C) + RWKV_K_EPS rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2) - + rwkv = self.output(rwkv) return rwkv + class Block(nn.Module): def __init__(self, layer_id): super().__init__() @@ -104,8 +117,8 @@ def __init__(self, layer_id): if self.layer_id == 0: self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd) - if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': - self.ffnPre = RWKV_ChannelMix(layer_id+1000) + if self.layer_id == 0 and RWKV_CFG.model_type == "RWKV-ffnPre": + self.ffnPre = RWKV_ChannelMix(layer_id + 1000) else: self.att = RWKV_TimeMix(layer_id) @@ -114,15 +127,18 @@ def __init__(self, layer_id): def forward(self, x): if self.layer_id == 0: x = self.ln0(x) - if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': + if self.layer_id == 0 and RWKV_CFG.model_type == "RWKV-ffnPre": x = x + self.ffnPre(self.ln1(x)) else: x = x + self.att(self.ln1(x)) x = x + self.ffn(self.ln2(x)) return x + class RWKV_GPT(nn.Module): - def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len): + def __init__( + self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len + ): global RWKV_CFG super().__init__() @@ -133,7 +149,7 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_em RWKV_CFG.n_embd = n_embd RWKV_CFG.ctx_len = ctx_len - print('\nloading RWKV-GPT', MODEL_NAME) + print("\nloading RWKV-GPT", MODEL_NAME) self.emb = nn.Embedding(vocab_size, n_embd) @@ -147,18 +163,17 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_em self.head_q.scale_init = 0 self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(ctx_len, ctx_len))) + self.register_buffer("copy_mask", torch.tril(torch.ones(ctx_len, ctx_len))) self.ctx_len = ctx_len self.eval() - self.load_state_dict(torch.load(MODEL_NAME + '.pth')) + self.load_state_dict(torch.load(MODEL_NAME + ".pth")) self.eval() def forward(self, idx): B, T = idx.size() assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." 
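These `F.conv1d` calls with `groups=C` over a `(T-1, 0)` left zero-pad are what make the inference-time `RWKV_TimeMix` causal: each position t mixes `kv` (and `k`, plus `RWKV_K_EPS` against divide-by-zero) over positions 0..t, weighted by the tail of the per-channel W-curve. A small standalone check of that equivalence; shapes and names are illustrative:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

B, C, T = 2, 4, 8
k = torch.rand(B, C, T)   # stand-in for exp(key), always positive
w = torch.rand(C, 1, T)   # per-channel W-curve, newest position last
# Left-pad by T-1 so position t only sees positions <= t (causal).
wk = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(k), w, groups=C)
assert wk.shape == (B, C, T)
# Naive reference: weight k[0..t] by the matching tail of the curve.
ref = torch.stack(
    [(k[:, :, : t + 1] * w[:, 0, T - 1 - t :]).sum(-1) for t in range(T)],
    dim=-1,
)
assert torch.allclose(wk, ref, atol=1e-5)
```

The depthwise-convolution path needs no custom kernel, which appears to be why `model_run.py` can run on plain CPU PyTorch while training keeps the TimeX CUDA extension.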
- + x = self.emb(idx) x = self.blocks(x) x = self.ln_out(x) @@ -172,13 +187,15 @@ def forward(self, idx): c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).float() x = self.head(x) + c else: - x = self.head(x) + x = self.head(x) return x + ############################################################################################################ -class RWKV_RNN(): + +class RWKV_RNN: def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): self.RUN_DEVICE = RUN_DEVICE self.model_type = model_type @@ -188,19 +205,18 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) self.w = types.SimpleNamespace() - w = torch.load(MODEL_NAME + '.pth', - map_location=torch.device(RUN_DEVICE)) + w = torch.load(MODEL_NAME + ".pth", map_location=torch.device(RUN_DEVICE)) for x in w.keys(): - if '.time_' in x: + if ".time_" in x: w[x] = w[x].squeeze() - if '.time_decay' in x: + if ".time_decay" in x: w[x] = torch.exp(-torch.exp(w[x])) - if '.time_first' in x: + if ".time_first" in x: w[x] = torch.exp(w[x]) - if DEBUG_TIME and '.time_' in x: + if DEBUG_TIME and ".time_" in x: print(x, w[x].squeeze().cpu().numpy()) - xx = x.split('.') + xx = x.split(".") here = self.w for i in range(len(xx)): if xx[i].isdigit(): @@ -212,7 +228,7 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) if i == len(xx) - 1: setattr(here, xx[i], w[x]) elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): + if xx[i + 1].isdigit(): setattr(here, xx[i], {}) else: setattr(here, xx[i], types.SimpleNamespace()) @@ -287,11 +303,15 @@ def run(self, ctx): for i in range(self.n_layer): if i == 0: x = self.LN(x, w.blocks[i].ln0) - if i == 0 and self.model_type == 'RWKV-ffnPre': - x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') + if i == 0 and self.model_type == "RWKV-ffnPre": + x = x + self.FF( + self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f"ffnPre.{i}" + ) else: - x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}') - x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}') + x = x + self.SA( + self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f"att.{i}" + ) + x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f"ffn.{i}") x = self.LN(x, w.ln_out) @@ -300,9 +320,10 @@ def run(self, ctx): self.hk = (w.head_k.weight @ x).unsqueeze(0) else: self.hk = torch.cat( - [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) + [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0 + ) if self.hk.shape[0] > self.ctx_len: - self.hk = self.hk[-self.ctx_len:, :] + self.hk = self.hk[-self.ctx_len :, :] q = w.head_q.weight @ x diff --git a/RWKV-v3/src/trainer.py b/RWKV-v3/src/trainer.py index 418d72e02..52addd564 100644 --- a/RWKV-v3/src/trainer.py +++ b/RWKV-v3/src/trainer.py @@ -38,7 +38,7 @@ class TrainerConfig: warmup_tokens = 0 final_tokens = 0 epoch_save_frequency = 0 - epoch_save_path = 'trained-' + epoch_save_path = "trained-" num_workers = 0 # for DataLoader def __init__(self, **kwargs): @@ -47,7 +47,6 @@ def __init__(self, **kwargs): class Trainer: - def __init__(self, model, train_dataset, test_dataset, config): self.model = model self.train_dataset = train_dataset @@ -56,23 +55,37 @@ def __init__(self, model, train_dataset, test_dataset, config): self.avg_loss = -1 self.steps = 0 - if 'wandb' in sys.modules: + if "wandb" in sys.modules: cfg = model.config for k in config.__dict__: setattr(cfg, k, config.__dict__[k]) # combine cfg - wandb.init(project="RWKV-LM", 
name=self.get_run_name() + '-' + - datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) - - self.device = 'cpu' + wandb.init( + project="RWKV-LM", + name=self.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), + config=cfg, + save_code=False, + ) + + self.device = "cpu" if torch.cuda.is_available(): # take over whatever gpus are on the system self.device = torch.cuda.current_device() def get_run_name(self): - raw_model = self.model.module if hasattr( - self.model, "module") else self.model + raw_model = self.model.module if hasattr(self.model, "module") else self.model cfg = raw_model.config - run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \ - cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) + run_name = ( + str(cfg.vocab_size) + + "-" + + str(cfg.ctx_len) + + "-" + + cfg.model_type + + "-" + + str(cfg.n_layer) + + "-" + + str(cfg.n_embd) + ) return run_name def train(self): @@ -81,21 +94,35 @@ def train(self): optimizer = raw_model.configure_optimizers(config) def run_epoch(split): - is_train = split == 'train' + is_train = split == "train" model.train(is_train) data = self.train_dataset if is_train else self.test_dataset if config.num_workers > 0: - loader = DataLoader(data, shuffle=False, pin_memory=True, - batch_size=config.batch_size, - num_workers=config.num_workers) + loader = DataLoader( + data, + shuffle=False, + pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) else: - loader = DataLoader(data, shuffle=False, - batch_size=config.batch_size, - num_workers=config.num_workers) - - pbar = tqdm(enumerate(loader), total=len( - loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) + loader = DataLoader( + data, + shuffle=False, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + pbar = ( + tqdm( + enumerate(loader), + total=len(loader), + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", + ) + if is_train + else enumerate(loader) + ) for it, (x, y) in pbar: x = x.to(self.device) # place data on the correct device @@ -110,7 +137,8 @@ def run_epoch(split): if config.grad_norm_clip > 0: torch.nn.utils.clip_grad_norm_( - model.parameters(), config.grad_norm_clip) + model.parameters(), config.grad_norm_clip + ) optimizer.step() @@ -120,52 +148,68 @@ def run_epoch(split): lr_final_factor = config.lr_final / config.learning_rate if self.tokens < config.warmup_tokens: # linear warmup - lr_mult = lr_final_factor + \ - (1 - lr_final_factor) * float(self.tokens) / \ - float(config.warmup_tokens) + lr_mult = lr_final_factor + (1 - lr_final_factor) * float( + self.tokens + ) / float(config.warmup_tokens) progress = 0 else: # exponential learning rate decay - progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + progress = float( + self.tokens - config.warmup_tokens + ) / float( + max(1, config.final_tokens - config.warmup_tokens) + ) if progress >= 1: lr_mult = lr_final_factor else: - lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1)) + lr_mult = math.exp( + math.log(lr_final_factor) * pow(progress, 1) + ) lr = config.learning_rate * lr_mult for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr else: lr = config.learning_rate now_loss = loss.item() # report progress self.lr = lr - if 'wandb' in sys.modules: - wandb.log({"loss": now_loss}, - step=self.steps * self.config.batch_size) + if "wandb" in 
sys.modules: + wandb.log( + {"loss": now_loss}, step=self.steps * self.config.batch_size + ) self.steps += 1 if self.avg_loss < 0: self.avg_loss = now_loss else: factor = 1 / (it + 1) - self.avg_loss = self.avg_loss * \ - (1.0 - factor) + now_loss * factor + self.avg_loss = ( + self.avg_loss * (1.0 - factor) + now_loss * factor + ) pbar.set_description( - f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") + f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}" + ) self.tokens = 0 # counter used for learning rate decay for epoch in range(config.max_epochs): - run_epoch('train') + run_epoch("train") log_file.write( - f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n') + f"{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n" + ) log_file.flush() - if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): + if ( + self.config.epoch_save_frequency > 0 + and epoch % self.config.epoch_save_frequency == 0 + ) or (epoch == config.max_epochs - 1): # DataParallel wrappers keep raw model object in .module - raw_model = self.model.module if hasattr( - self.model, "module") else self.model - torch.save(raw_model.state_dict(), - self.config.epoch_save_path + str(epoch+1) + '.pth') + raw_model = ( + self.model.module if hasattr(self.model, "module") else self.model + ) + torch.save( + raw_model.state_dict(), + self.config.epoch_save_path + str(epoch + 1) + ".pth", + ) diff --git a/RWKV-v3/src/utils.py b/RWKV-v3/src/utils.py index 42e9f47b9..b43672c42 100644 --- a/RWKV-v3/src/utils.py +++ b/RWKV-v3/src/utils.py @@ -15,7 +15,7 @@ class Dataset(Dataset): def __init__(self, data, ctx_len, epoch_length_fixed): - print('building token list...', end=' ') + print("building token list...", end=" ") unique = sorted(list(set(data))) # print() # for u in unique: @@ -27,11 +27,11 @@ def __init__(self, data, ctx_len, epoch_length_fixed): for u in unique: xxObj[xx] = u xx += 1 - with open('vocab.json', "w", encoding="utf-16") as vocab_file: + with open("vocab.json", "w", encoding="utf-16") as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) data_size, vocab_size = len(data), len(unique) - print('data has %d tokens, %d unique.' % (data_size, vocab_size)) + print("data has %d tokens, %d unique." 
% (data_size, vocab_size)) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} self.ctx_len = ctx_len @@ -45,18 +45,16 @@ def __len__(self): def __getitem__(self, idx): # cheat: pick a random spot in dataset i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) - chunk = self.data[i:i+self.ctx_len+1] + chunk = self.data[i : i + self.ctx_len + 1] dix = [self.stoi[s] for s in chunk] - x = torch.tensor(dix[:-1], dtype=torch.long, - device=torch.device('cuda')) - y = torch.tensor(dix[1:], dtype=torch.long, - device=torch.device('cuda')) + x = torch.tensor(dix[:-1], dtype=torch.long, device=torch.device("cuda")) + y = torch.tensor(dix[1:], dtype=torch.long, device=torch.device("cuda")) return x, y -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: self.word_table = json.load(result_file) self.vocab_size = len(self.word_table) @@ -67,24 +65,26 @@ def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] def refine_context(self, context): - context = context.strip().split('\n') + context = context.strip().split("\n") for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): # out[self.UNKNOWN_CHAR] = -float('Inf') lastChar = int(x[-1]) probs = F.softmax(torch.tensor(out), dim=-1) - if self.itos[lastChar] == '\n': + if self.itos[lastChar] == "\n": top_p = top_p_newline else: top_p = top_p_usual diff --git a/RWKV-v3/train.py b/RWKV-v3/train.py index 1a07cad25..3a9154bc3 100644 --- a/RWKV-v3/train.py +++ b/RWKV-v3/train.py @@ -19,27 +19,30 @@ import numpy as np np.set_printoptions(precision=4, suppress=True, linewidth=200) -logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) torch.backends.cudnn.benchmark = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True ### Step 1: set training data ########################################################################## -datafile = "../data/enwik8" # your data -datafile_encoding = 'utf-8' +datafile = "../data/enwik8" # your data +datafile_encoding = "utf-8" # datafile_encoding = 'utf-16le' ### Step 2: set model size ############################################################################# # ----> test deeper models (n_layer at least 12) to see the advantage of RWKV-3 over RWKV-2 -ctx_len = 1024 # increase T_MAX in model.py if your ctx_len > 1024 +ctx_len = 1024 # increase T_MAX in model.py if your ctx_len > 1024 n_layer = 6 n_embd = 512 # 'RWKV' (better for English) or 
'RWKV-ffnPre' (better in some cases) -model_type = 'RWKV' +model_type = "RWKV" # ---> there is a RWKV_HEAD_QK_DIM in model.py and model_run.py # set it to 256, then it's using my headQK trick (similar to a tiny attention) to improve loss @@ -61,10 +64,10 @@ # 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run. # 3) Set lr_init = 8e-4, lr_final = 1e-5, warmup_tokens = ctx_len * batch_size * 50, betas = (0.9, 0.999). # 4) Search for "torch.load" here and modify it to load the partially-trained model. Continue the training. -# +# # For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4. -lr_init = 8e-4 # we can use larger lr because of preLN +lr_init = 8e-4 # we can use larger lr because of preLN lr_final = 1e-5 # the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens) @@ -73,7 +76,7 @@ # 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ... epoch_save_frequency = 10 -epoch_save_path = 'trained-' +epoch_save_path = "trained-" ######################################################################################################## @@ -89,30 +92,75 @@ # Load data ######################################################################################################## -print('loading data... ' + datafile) -train_dataset = Dataset(open( - datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed) +print("loading data... " + datafile) +train_dataset = Dataset( + open(datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed +) ######################################################################################################## # Train model ######################################################################################################## -if __name__ == '__main__': - - model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, - n_layer=n_layer, n_embd=n_embd)).cuda() +if __name__ == "__main__": + + model = GPT( + GPTConfig( + train_dataset.vocab_size, + train_dataset.ctx_len, + model_type=model_type, + n_layer=n_layer, + n_embd=n_embd, + ) + ).cuda() ### ---> load a trained model <--- # m2 = torch.load('trained-61.pth') # model.load_state_dict(m2) - print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', - betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, ) - tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, - learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip, - warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) + print( + "model", + model_type, + "epoch", + n_epoch, + "batchsz", + batch_size, + "betas", + betas, + "eps", + eps, + "ctx", + ctx_len, + "layer", + n_layer, + "embd", + n_embd, + ) + tconf = TrainerConfig( + model_type=model_type, + max_epochs=n_epoch, + batch_size=batch_size, + learning_rate=lr_init, + lr_decay=True, + lr_final=lr_final, + betas=betas, + eps=eps, + grad_norm_clip=grad_norm_clip, + warmup_tokens=warmup_tokens, + final_tokens=n_epoch * len(train_dataset) * ctx_len, + num_workers=num_workers, + epoch_save_frequency=epoch_save_frequency, + epoch_save_path=epoch_save_path, + ) trainer = Trainer(model, train_dataset, None, tconf) trainer.train() - torch.save(model.state_dict(), 
'trained-' + str(n_epoch) + '-' + trainer.get_run_name() + - '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') + torch.save( + model.state_dict(), + "trained-" + + str(n_epoch) + + "-" + + trainer.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + + ".pth", + ) diff --git a/RWKV-v3/verify.py b/RWKV-v3/verify.py index fd911510e..4f90eb341 100644 --- a/RWKV-v3/verify.py +++ b/RWKV-v3/verify.py @@ -5,11 +5,13 @@ # this is for verifying the results of different models and make sure they agree with each other import numpy as np + np.set_printoptions(precision=4, suppress=True, linewidth=200) import os + os.environ["CUDA_VISIBLE_DEVICES"] = "0" -RUN_DEVICE = 'cuda' +RUN_DEVICE = "cuda" import torch from src.model_run import RWKV_RNN, RWKV_GPT @@ -18,48 +20,63 @@ ctx_len = 1024 n_layer = 6 n_embd = 512 -model_type = 'RWKV' +model_type = "RWKV" -model_name = 'trained-1' +model_name = "trained-1" from src.utils import TOKENIZER -tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ') + +tokenizer = TOKENIZER("vocab", UNKNOWN_CHAR=" ") ######################################################################################################## -model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda() -print('loading ' + model_name) -m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE) +model_train = GPT( + GPTConfig( + tokenizer.vocab_size, + ctx_len, + model_type=model_type, + n_layer=n_layer, + n_embd=n_embd, + ) +).cuda() +print("loading " + model_name) +m2 = torch.load(model_name + ".pth", map_location=RUN_DEVICE) model_train.load_state_dict(m2) model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) -model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda() +model_gpt = RWKV_GPT( + model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len +).cuda() ######################################################################################################## -context = '\nIn a' +context = "\nIn a" ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] -print(f'input len {len(ctx)} data {ctx}') +print(f"input len {len(ctx)} data {ctx}") ######################################################################################################## -print('\nRWKV-GPT output') +print("\nRWKV-GPT output") out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy() print(out) -print('\nRWKV-RNN output') +print("\nRWKV-RNN output") model_rnn.clear() src_len = len(ctx) for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] out = model_rnn.run(x) if i < 3 or i >= src_len - 3: print(torch.tensor(out).detach().cpu().numpy()) if i == 2: - print('...') - -print('\nRWKV-train output') -ctx += [0] * (ctx_len - src_len) # pad to ctx_len -ctx = [ctx] * 4 # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD) -out = model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy() -print(out, '\n') + print("...") + +print("\nRWKV-train output") +ctx += [0] * (ctx_len - src_len) # pad to ctx_len +ctx = [ + ctx +] * 4 # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD) +out = ( + model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy() +) +print(out, "\n") diff --git a/RWKV-v4/run.py b/RWKV-v4/run.py index 1c1c2fbd4..1e176096d 100644 --- a/RWKV-v4/run.py +++ 
b/RWKV-v4/run.py @@ -10,6 +10,7 @@ import torch from torch.nn import functional as F from src.utils import TOKENIZER, Dataset + torch.backends.cudnn.benchmark = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True @@ -17,36 +18,39 @@ ######################################################################################################## # Step 1: set model -# +# # Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch. # # Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models. ######################################################################################################## -TOKEN_MODE = 'char' # char / bpe / pile +TOKEN_MODE = "char" # char / bpe / pile n_layer = 6 n_embd = 512 ctx_len = 1024 -if TOKEN_MODE == 'char': - MODEL_NAME = 'trained-500' # your trained model - WORD_NAME = 'vocab' # the .json vocab (generated by train.py) +if TOKEN_MODE == "char": + MODEL_NAME = "trained-500" # your trained model + WORD_NAME = "vocab" # the .json vocab (generated by train.py) # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it - UNKNOWN_CHAR = ' ' # here we just set it to ' ' for simplicity - -elif TOKEN_MODE == 'bpe': - MODEL_NAME = 'trained-500' # your trained model - WORD_NAME = ['model-vocab.json', 'model-merges.txt'] # [vocab, merge] for your BPE model + UNKNOWN_CHAR = " " # here we just set it to ' ' for simplicity + +elif TOKEN_MODE == "bpe": + MODEL_NAME = "trained-500" # your trained model + WORD_NAME = [ + "model-vocab.json", + "model-merges.txt", + ] # [vocab, merge] for your BPE model UNKNOWN_CHAR = None -elif TOKEN_MODE == 'pile': - WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json'] +elif TOKEN_MODE == "pile": + WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"] UNKNOWN_CHAR = None - #---> you can set MODEL_NAME to your fine-tuned model <--- + # ---> you can set MODEL_NAME to your fine-tuned model <--- - MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023' + MODEL_NAME = "RWKV-4-Pile-169M-20220807-8023" # MODEL_NAME = 'trained-11' n_layer = 12 n_embd = 768 @@ -60,11 +64,13 @@ # MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040' # n_layer = 24 # n_embd = 2048 - # ctx_len = 1024 + # ctx_len = 1024 -os.environ['RWKV_FLOAT_MODE'] = 'fp32' # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment) -os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda' -model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' +os.environ[ + "RWKV_FLOAT_MODE" +] = "fp32" # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment) +os.environ["RWKV_RUN_DEVICE"] = "cpu" # 'cpu' (already very fast) or 'cuda' +model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' ######################################################################################################## # Step 2: set prompt & sampling stuffs @@ -73,22 +79,25 @@ # context = 'A' # context = "\nIn the" # context = '\nSugar:' -context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.' +context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." 
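An aside on the knobs set just below: NUM_TRIALS, TEMPERATURE, top_p and top_p_newline feed tokenizer.sample_logits later in run.py, but the top-p core itself sits outside this hunk. A minimal, self-contained sketch of temperature-plus-nucleus sampling under the same conventions — the name sample_top_p and the filter-then-temperature order are illustrative assumptions, not the repo's exact API:

    import torch
    from torch.nn import functional as F

    def sample_top_p(out, temperature=1.0, top_p=0.7):
        # out: raw logits for one step, shape (vocab_size,)
        probs = F.softmax(torch.as_tensor(out, dtype=torch.float32), dim=-1)
        sorted_probs, _ = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        # probability of the least-likely token still inside the top-p nucleus
        idx = torch.searchsorted(cum, torch.tensor(top_p)).clamp(max=cum.numel() - 1)
        cutoff = sorted_probs[idx]
        probs = torch.where(probs >= cutoff, probs, torch.zeros_like(probs))
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)  # pow(1/T): T < 1 sharpens, T > 1 flattens
        return int(torch.multinomial(probs, num_samples=1))  # weights need not sum to 1

A char-mode caller would swap in top_p_newline whenever the previous character is '\n', matching the branch visible in TOKENIZER.sample_logits further down this diff.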
NUM_TRIALS = 999 LENGTH_PER_TRIAL = 333 TEMPERATURE = 1.0 top_p = 0.7 -top_p_newline = 0.9 # only used in TOKEN_MODE = char +top_p_newline = 0.9 # only used in TOKEN_MODE = char DEBUG_DEBUG = False # True False --> show softmax output ######################################################################################################## -print(f'Loading {MODEL_NAME}...') +print(f"Loading {MODEL_NAME}...") from src.model_run import RWKV_RNN -model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len) + +model = RWKV_RNN( + MODEL_NAME, os.environ["RWKV_RUN_DEVICE"], model_type, n_layer, n_embd, ctx_len +) tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) ######################################################################################################## @@ -101,18 +110,20 @@ src_len = len(ctx) src_ctx = ctx.copy() -print('\nYour prompt has ' + str(src_len) + ' tokens.') -print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n') +print("\nYour prompt has " + str(src_len) + " tokens.") +print( + "\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n" +) for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): t_begin = time.time_ns() - print(('-' * 30) + context, end='') + print(("-" * 30) + context, end="") ctx = src_ctx.copy() model.clear() if TRIAL == 0: init_state = types.SimpleNamespace() for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] if i == src_len - 1: init_state.out = model.run(x) else: @@ -122,7 +133,7 @@ model.load(init_state) for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): - x = ctx[:i+1] + x = ctx[: i + 1] x = x[-ctx_len:] if i == src_len: @@ -130,20 +141,25 @@ else: out = model.run(x) if DEBUG_DEBUG: - print('model', np.array(x), '==>', np.array( - out), np.max(out), np.min(out)) + print("model", np.array(x), "==>", np.array(out), np.max(out), np.min(out)) - if TOKEN_MODE == 'pile': + if TOKEN_MODE == "pile": out[0] = -999999999 # disable <|endoftext|> - char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, - top_p_usual=top_p, top_p_newline=top_p_newline) + char = tokenizer.sample_logits( + out, + x, + ctx_len, + temperature=TEMPERATURE, + top_p_usual=top_p, + top_p_newline=top_p_newline, + ) char = char.item() if tokenizer.charMode: - print(tokenizer.itos[int(char)], end='', flush=True) + print(tokenizer.itos[int(char)], end="", flush=True) else: - print(tokenizer.tokenizer.decode(int(char)), end='', flush=True) + print(tokenizer.tokenizer.decode(int(char)), end="", flush=True) ctx += [char] t_end = time.time_ns() - print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') + print("\n----------", round((t_end - t_begin) / (10**9), 2), end="s ") diff --git a/RWKV-v4/src/binidx.py b/RWKV-v4/src/binidx.py index ce6cfe272..97c34186e 100644 --- a/RWKV-v4/src/binidx.py +++ b/RWKV-v4/src/binidx.py @@ -7,6 +7,7 @@ from functools import lru_cache from itertools import accumulate + def print_rank_0(*message): """If distributed is initialized print only on rank 0.""" if torch.distributed.is_initialized(): @@ -15,12 +16,14 @@ def print_rank_0(*message): else: print(*message, flush=True) + def _warmup_mmap_file(path): pass # with open(path, "rb") as stream: # while stream.read(100 * 1024 * 1024): # pass + dtypes = { 1: np.uint8, 2: 
np.int8, @@ -32,18 +35,22 @@ def _warmup_mmap_file(path): 8: np.uint16, } + def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) + def index_file_path(prefix_path): return prefix_path + ".idx" + def data_file_path(prefix_path): return prefix_path + ".bin" + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b"MMIDIDX\x00\x00" @@ -164,8 +171,7 @@ def __getitem__(self, idx): elif isinstance(idx, slice): start, stop, step = idx.indices(len(self)) if step != 1: - raise ValueError( - "Slices into indexed_dataset must be contiguous") + raise ValueError("Slices into indexed_dataset must be contiguous") ptr = self._index._pointers[start] sizes = self._index._sizes[idx] offsets = list(accumulate(sizes)) diff --git a/RWKV-v4/src/model.py b/RWKV-v4/src/model.py index cbda40097..db995cb7f 100644 --- a/RWKV-v4/src/model.py +++ b/RWKV-v4/src/model.py @@ -8,21 +8,24 @@ import torch import torch.nn as nn from torch.nn import functional as F + try: from deepspeed.ops.adam import FusedAdam except: - pass # some poor windows users cant install deepspeed + pass # some poor windows users cant install deepspeed logger = logging.getLogger(__name__) RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') +print(f"\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n") + class L2Wrap(torch.autograd.Function): @staticmethod def forward(ctx, loss, y): ctx.save_for_backward(y) return loss + @staticmethod def backward(ctx, grad_output): y = ctx.saved_tensors[0] @@ -33,16 +36,30 @@ def backward(ctx, grad_output): gy.scatter_(-1, ids, maxx * factor) return (grad_output, gy) + ######################################################################################################## # CUDA Kernel ######################################################################################################## -T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] +T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] 
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice from torch.utils.cpp_extension import load -wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], - verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) + +wkv_cuda = load( + name="wkv", + sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + f"-DTmax={T_MAX}", + ], +) + class WKV(torch.autograd.Function): @staticmethod @@ -52,7 +69,7 @@ def forward(ctx, B, T, C, w, u, k, v): ctx.C = C assert T <= T_MAX assert B * C % min(C, 1024) == 0 - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: w = -torch.exp(w.contiguous()) u = u.contiguous() k = k.contiguous() @@ -63,13 +80,13 @@ def forward(ctx, B, T, C, w, u, k, v): k = k.float().contiguous() v = v.float().contiguous() ctx.save_for_backward(w, u, k, v) - y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format) wkv_cuda.forward(B, T, C, w, u, k, v, y) - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: return y - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return y.half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return y.bfloat16() @staticmethod @@ -80,33 +97,48 @@ def backward(ctx, gy): assert T <= T_MAX assert B * C % min(C, 1024) == 0 w, u, k, v = ctx.saved_tensors - gw = torch.zeros((B, C), device='cuda').contiguous() - gu = torch.zeros((B, C), device='cuda').contiguous() - gk = torch.zeros((B, T, C), device='cuda').contiguous() - gv = torch.zeros((B, T, C), device='cuda').contiguous() - if '32' in os.environ['RWKV_FLOAT_MODE']: + gw = torch.zeros((B, C), device="cuda").contiguous() + gu = torch.zeros((B, C), device="cuda").contiguous() + gk = torch.zeros((B, T, C), device="cuda").contiguous() + gv = torch.zeros((B, T, C), device="cuda").contiguous() + if "32" in os.environ["RWKV_FLOAT_MODE"]: wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) else: - wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv + ) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: return (None, None, None, gw, gu, gk, gv) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + return ( + None, + None, + None, + gw.bfloat16(), + gu.bfloat16(), + gk.bfloat16(), + gv.bfloat16(), + ) + def RUN_CUDA(B, T, C, w, u, k, v): return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + ######################################################################################################## # RWKV: RWKV Time-mix + RWKV Channel-mix ######################################################################################################## + def RWKV_Init(model, args): 
# fancy initialization of all lin & emb layer in the model print("\n[--> first run, init model params (very slow for large models) <--]") - print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n") + print( + "[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n" + ) for mm in model.modules(): if "RecursiveScriptModule" in str(type(mm)): @@ -123,7 +155,10 @@ def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in th ww = m.weight with torch.no_grad(): name = "[unknown weight]" - for name, parameter in model.named_parameters(): # find the name of the weight + for ( + name, + parameter, + ) in model.named_parameters(): # find the name of the weight if id(ww) == id(parameter): break @@ -133,7 +168,9 @@ def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in th if isinstance(m, nn.Embedding): gain = math.sqrt(max(shape[0], shape[1])) - if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb? + if ( + shape[0] == args.vocab_size and shape[1] == args.n_embd + ): # token emb? scale = 1e-4 else: scale = 0 @@ -141,7 +178,9 @@ def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in th if isinstance(m, nn.Linear): if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection? + if ( + shape[0] == args.vocab_size and shape[1] == args.n_embd + ): # final projection? scale = 0.5 if hasattr(m, "scale_init"): @@ -170,29 +209,33 @@ def __init__(self, config, layer_id): attn_sz = config.n_embd - with torch.no_grad(): # fancy init - ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1 - ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 - + with torch.no_grad(): # fancy init + ratio_0_to_1 = layer_id / (config.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0 + # fancy time_decay decay_speed = torch.ones(attn_sz) for h in range(attn_sz): - decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1) + decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) self.time_decay = nn.Parameter(decay_speed) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # fancy time_first - zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5) + zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5 self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag) - + # fancy time_mix x = torch.ones(1, 1, config.n_embd) for i in range(config.n_embd): x[0, 0, i] = i / config.n_embd self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_v = nn.Parameter( + torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0)) - + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.key = nn.Linear(config.n_embd, attn_sz, bias=False) @@ -223,7 +266,7 @@ def jit_func(self, x): return sr, k, v def forward(self, x): - B, T, C = x.size() # x = (Batch,Time,Channel) + B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v = self.jit_func(x) @@ -239,8 +282,8 @@ def __init__(self, config, layer_id): self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - with torch.no_grad(): # fancy init of time_mix - ratio_1_to_almost0 
= (1.0 - (layer_id / config.n_layer)) # 1 to ~0 + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0 x = torch.ones(1, 1, config.n_embd) for i in range(config.n_embd): @@ -270,6 +313,7 @@ def forward(self, x): rkv = torch.sigmoid(self.receptance(xr)) * kv return rkv + ######################################################################################################## # The GPT Model with our blocks ######################################################################################################## @@ -295,7 +339,7 @@ def __init__(self, config, layer_id): if self.layer_id == 0: self.ln0 = nn.LayerNorm(config.n_embd) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + if self.layer_id == 0 and self.config.model_type == "RWKV-ffnPre": self.ffnPre = RWKV_ChannelMix(config, 0) else: self.att = RWKV_TimeMix(config, layer_id) @@ -304,8 +348,8 @@ def __init__(self, config, layer_id): def forward(self, x): if self.layer_id == 0: - x = self.ln0(x) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + x = self.ln0(x) + if self.layer_id == 0 and self.config.model_type == "RWKV-ffnPre": x = x + self.ffnPre(self.ln1(x)) # better in some cases else: x = x + self.att(self.ln1(x)) @@ -321,8 +365,7 @@ def __init__(self, config): self.emb = nn.Embedding(config.vocab_size, config.n_embd) - self.blocks = nn.Sequential(*[Block(config, i) - for i in range(config.n_layer)]) + self.blocks = nn.Sequential(*[Block(config, i) for i in range(config.n_layer)]) self.ln_out = nn.LayerNorm(config.n_embd) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) @@ -332,19 +375,21 @@ def __init__(self, config): self.head_q.scale_init = 0 self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(config.ctx_len, config.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)) + ) self.ctx_len = config.ctx_len try: - if os.environ['RWKV_LOAD_MODEL'] == str(False): - RWKV_Init(self, config) + if os.environ["RWKV_LOAD_MODEL"] == str(False): + RWKV_Init(self, config) except: pass - logger.info("number of parameters: %e", sum(p.numel() - for p in self.parameters())) + logger.info( + "number of parameters: %e", sum(p.numel() for p in self.parameters()) + ) def get_ctx_len(self): return self.ctx_len @@ -362,20 +407,38 @@ def configure_optimizers(self, train_config): for mn, m in self.named_modules(): # here we disable weight_decay for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name no_decay.add(fpn) param_dict = {pn: p for pn, p in self.named_parameters()} optim_groups = [ - {"params": [param_dict[pn] - for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, ] try: - optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + optimizer = FusedAdam( + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas, + eps=train_config.eps, + bias_correction=True, + adam_w_mode=False, + weight_decay=0, + amsgrad=False, + ) except: - print('\n\nDeepSpeed not found. 
Using torch optimizer instead (probably slower)\n\n') - optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps) + print( + "\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n" + ) + optimizer = torch.optim.Adam( + optim_groups, + lr=train_config.learning_rate, + betas=train_config.betas, + eps=train_config.eps, + ) return optimizer @@ -395,12 +458,12 @@ def forward(self, idx, targets=None): k = self.head_k(x)[:, :T, :] c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) - - if '32' in os.environ['RWKV_FLOAT_MODE']: + + if "32" in os.environ["RWKV_FLOAT_MODE"]: c = c @ F.one_hot(idx, num_classes=self.config.vocab_size) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16() x = self.head(x) + c @@ -409,6 +472,8 @@ def forward(self, idx, targets=None): loss = None if targets is not None: - loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1)) + loss = F.cross_entropy( + x.view(-1, x.size(-1)), targets.to(x.device).view(-1) + ) return L2Wrap.apply(loss, x) diff --git a/RWKV-v4/src/model_run.py b/RWKV-v4/src/model_run.py index 16c7e5dff..4c5d88024 100644 --- a/RWKV-v4/src/model_run.py +++ b/RWKV-v4/src/model_run.py @@ -10,21 +10,33 @@ import torch.nn as nn RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') +print(f"\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n") -DEBUG_TIME = False # True False - show trained time-coeffs +DEBUG_TIME = False # True False - show trained time-coeffs ######################################################################################################## # CUDA Kernel ######################################################################################################## -if os.environ['RWKV_RUN_DEVICE'] == 'cuda': - T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] +if os.environ["RWKV_RUN_DEVICE"] == "cuda": + T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] 
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice from torch.utils.cpp_extension import load - wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], - verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) + + wkv_cuda = load( + name="wkv", + sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + f"-DTmax={T_MAX}", + ], + ) class WKV(torch.autograd.Function): @staticmethod @@ -34,7 +46,7 @@ def forward(ctx, B, T, C, w, u, k, v): ctx.C = C assert T <= T_MAX assert B * C % min(C, 1024) == 0 - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: w = -torch.exp(w.contiguous()) u = u.contiguous() k = k.contiguous() @@ -45,13 +57,15 @@ def forward(ctx, B, T, C, w, u, k, v): k = k.float().contiguous() v = v.float().contiguous() ctx.save_for_backward(w, u, k, v) - y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + y = torch.empty( + (B, T, C), device="cuda", memory_format=torch.contiguous_format + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: return y - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return y.half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return y.bfloat16() @staticmethod @@ -62,36 +76,48 @@ def backward(ctx, gy): assert T <= T_MAX assert B * C % min(C, 1024) == 0 w, u, k, v = ctx.saved_tensors - gw = torch.zeros((B, C), device='cuda').contiguous() - gu = torch.zeros((B, C), device='cuda').contiguous() - gk = torch.zeros((B, T, C), device='cuda').contiguous() - gv = torch.zeros((B, T, C), device='cuda').contiguous() - if '32' in os.environ['RWKV_FLOAT_MODE']: + gw = torch.zeros((B, C), device="cuda").contiguous() + gu = torch.zeros((B, C), device="cuda").contiguous() + gk = torch.zeros((B, T, C), device="cuda").contiguous() + gv = torch.zeros((B, T, C), device="cuda").contiguous() + if "32" in os.environ["RWKV_FLOAT_MODE"]: wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) else: - wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv + ) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: return (None, None, None, gw, gu, gk, gv) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + return ( + None, + None, + None, + gw.bfloat16(), + gu.bfloat16(), + gk.bfloat16(), + gv.bfloat16(), + ) def RUN_CUDA(B, T, C, w, u, k, v): return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + ############################################################################################################ RWKV_CFG = types.SimpleNamespace() + class RWKV_ChannelMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id - self.time_shift = 
nn.ZeroPad2d((0,0,1,-1)) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) @@ -108,21 +134,22 @@ def forward(self, x): k = self.key(xk) k = torch.square(torch.relu(k)) kv = self.value(k) - + rkv = torch.sigmoid(self.receptance(xr)) * kv return rkv + class RWKV_TimeMix(nn.Module): def __init__(self, layer_id): super().__init__() self.layer_id = layer_id self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd)) self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3)) - - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) - self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + self.time_mix_v = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) @@ -142,11 +169,14 @@ def forward(self, x): v = self.value(xv) r = self.receptance(xr) - rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) - + rwkv = torch.sigmoid(r) * RUN_CUDA( + B, T, C, self.time_decay, self.time_first, k, v + ) + rwkv = self.output(rwkv) return rwkv + class Block(nn.Module): def __init__(self, layer_id): super().__init__() @@ -157,8 +187,8 @@ def __init__(self, layer_id): if self.layer_id == 0: self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd) - if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': - self.ffnPre = RWKV_ChannelMix(layer_id+1000) + if self.layer_id == 0 and RWKV_CFG.model_type == "RWKV-ffnPre": + self.ffnPre = RWKV_ChannelMix(layer_id + 1000) else: self.att = RWKV_TimeMix(layer_id) @@ -167,15 +197,18 @@ def __init__(self, layer_id): def forward(self, x): if self.layer_id == 0: x = self.ln0(x) - if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': + if self.layer_id == 0 and RWKV_CFG.model_type == "RWKV-ffnPre": x = x + self.ffnPre(self.ln1(x)) else: x = x + self.att(self.ln1(x)) x = x + self.ffn(self.ln2(x)) return x + class RWKV_GPT(nn.Module): - def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len): + def __init__( + self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len + ): global RWKV_CFG super().__init__() @@ -186,7 +219,7 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_em RWKV_CFG.n_embd = n_embd RWKV_CFG.ctx_len = ctx_len - print('\nloading RWKV-GPT', MODEL_NAME) + print("\nloading RWKV-GPT", MODEL_NAME) self.emb = nn.Embedding(vocab_size, n_embd) @@ -200,18 +233,17 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_em self.head_q.scale_init = 0 self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(ctx_len, ctx_len))) + self.register_buffer("copy_mask", torch.tril(torch.ones(ctx_len, ctx_len))) self.ctx_len = ctx_len self.eval() - self.load_state_dict(torch.load(MODEL_NAME + '.pth')) + self.load_state_dict(torch.load(MODEL_NAME + ".pth")) self.eval() def forward(self, idx): B, T = idx.size() assert T <= self.ctx_len, "Cannot forward, because 
len(input) > model ctx_len." - + x = self.emb(idx) x = self.blocks(x) x = self.ln_out(x) @@ -222,22 +254,24 @@ def forward(self, idx): c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) - if '32' in os.environ['RWKV_FLOAT_MODE']: + if "32" in os.environ["RWKV_FLOAT_MODE"]: c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16() x = self.head(x) + c else: - x = self.head(x) + x = self.head(x) return x + ############################################################################################################ -class RWKV_RNN(): # this is running in FP32 at this moment + +class RWKV_RNN: # this is running in FP32 at this moment def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): self.RUN_DEVICE = RUN_DEVICE self.model_type = model_type @@ -247,18 +281,17 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) self.w = types.SimpleNamespace() - w = torch.load(MODEL_NAME + '.pth', - map_location=torch.device(RUN_DEVICE)) + w = torch.load(MODEL_NAME + ".pth", map_location=torch.device(RUN_DEVICE)) for x in w.keys(): w[x] = w[x].float() - if '.time_' in x: + if ".time_" in x: w[x] = w[x].squeeze() - if '.time_decay' in x: + if ".time_decay" in x: w[x] = -torch.exp(w[x]) - if DEBUG_TIME and '.time_' in x: + if DEBUG_TIME and ".time_" in x: print(x, w[x].squeeze().cpu().numpy()) - xx = x.split('.') + xx = x.split(".") here = self.w for i in range(len(xx)): if xx[i].isdigit(): @@ -270,7 +303,7 @@ def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) if i == len(xx) - 1: setattr(here, xx[i], w[x]) elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): + if xx[i + 1].isdigit(): setattr(here, xx[i], {}) else: setattr(here, xx[i], types.SimpleNamespace()) @@ -360,11 +393,15 @@ def run(self, ctx): for i in range(self.n_layer): if i == 0: x = self.LN(x, w.blocks[i].ln0) - if i == 0 and self.model_type == 'RWKV-ffnPre': - x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') + if i == 0 and self.model_type == "RWKV-ffnPre": + x = x + self.FF( + self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f"ffnPre.{i}" + ) else: - x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}') - x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}') + x = x + self.SA( + self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f"att.{i}" + ) + x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f"ffn.{i}") x = self.LN(x, w.ln_out) @@ -373,9 +410,10 @@ def run(self, ctx): self.hk = (w.head_k.weight @ x).unsqueeze(0) else: self.hk = torch.cat( - [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) + [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0 + ) if self.hk.shape[0] > self.ctx_len: - self.hk = self.hk[-self.ctx_len:, :] + self.hk = self.hk[-self.ctx_len :, :] q = w.head_q.weight @ x diff --git a/RWKV-v4/src/trainer.py b/RWKV-v4/src/trainer.py index 8025cd573..87ec796c9 100644 --- a/RWKV-v4/src/trainer.py +++ b/RWKV-v4/src/trainer.py @@ -3,8 +3,9 @@ ######################################################################################################## import os -NUM_GPUS = 
int(os.environ['RWKV_NUM_GPUS']) -USE_WANDB = (int(os.environ['USE_WANDB']) == 1) + +NUM_GPUS = int(os.environ["RWKV_NUM_GPUS"]) +USE_WANDB = int(os.environ["USE_WANDB"]) == 1 from torch.utils.data.dataloader import DataLoader import torch @@ -18,13 +19,14 @@ logger = logging.getLogger(__name__) torch.backends.cudnn.benchmark = True -if os.environ['RWKV_FLOAT_MODE'] == 'fp32': +if os.environ["RWKV_FLOAT_MODE"] == "fp32": torch.backends.cudnn.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False else: torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True + class TrainerConfig: batch_size = 64 learning_rate = 4e-4 @@ -34,35 +36,51 @@ class TrainerConfig: warmup_tokens = 0 final_tokens = 0 epoch_save_frequency = 0 - epoch_save_path = 'trained-' + epoch_save_path = "trained-" num_workers = 0 # for DataLoader def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) + from src.model import GPT, GPTConfig -class Trainer(LightningLite): +class Trainer(LightningLite): def get_run_name(self): - raw_model = self.model.module if hasattr( - self.model, "module") else self.model + raw_model = self.model.module if hasattr(self.model, "module") else self.model cfg = raw_model.config - run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \ - cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) + run_name = ( + str(cfg.vocab_size) + + "-" + + str(cfg.ctx_len) + + "-" + + cfg.model_type + + "-" + + str(cfg.n_layer) + + "-" + + str(cfg.n_embd) + ) return run_name def run(self, m_cfg, train_dataset, test_dataset, config): - self.cuda_id = int(str(self.device).strip('cuda:')) - print('[0]') - model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=m_cfg.model_type, - n_layer=m_cfg.n_layer, n_embd=m_cfg.n_embd)) - print('[1]') + self.cuda_id = int(str(self.device).strip("cuda:")) + print("[0]") + model = GPT( + GPTConfig( + train_dataset.vocab_size, + train_dataset.ctx_len, + model_type=m_cfg.model_type, + n_layer=m_cfg.n_layer, + n_embd=m_cfg.n_embd, + ) + ) + print("[1]") with torch.no_grad(): if m_cfg.LOAD_MODEL: - print('loading', m_cfg.MODEL_NAME) - m2 = torch.load(m_cfg.MODEL_NAME + '.pth', map_location='cpu') + print("loading", m_cfg.MODEL_NAME) + m2 = torch.load(m_cfg.MODEL_NAME + ".pth", map_location="cpu") model.load_state_dict(m2) del m2 model.to(self.device) @@ -74,51 +92,75 @@ def run(self, m_cfg, train_dataset, test_dataset, config): self.avg_loss = -1 self.EPOCH_BEGIN = m_cfg.EPOCH_BEGIN - self.steps = self.EPOCH_BEGIN * (len(self.train_dataset) // (config.batch_size // NUM_GPUS)) + self.steps = self.EPOCH_BEGIN * ( + len(self.train_dataset) // (config.batch_size // NUM_GPUS) + ) if self.cuda_id == 0: log_file = open("mylog.txt", "a") if USE_WANDB: - print('logging to wandb... (comment it if you don\'t have wandb)') - import wandb # comment this if you don't have wandb + print("logging to wandb... 
(comment it if you don't have wandb)") + import wandb # comment this if you don't have wandb + cfg = model.config for k in config.__dict__: - setattr(cfg, k, config.__dict__[k]) # combine cfg - wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) + setattr(cfg, k, config.__dict__[k]) # combine cfg + wandb.init( + project="RWKV-LM", + name=self.get_run_name() + + "-" + + datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"), + config=cfg, + save_code=False, + ) model, config = self.model, self.config raw_model = model.module if hasattr(self.model, "module") else model optimizer = raw_model.configure_optimizers(config) model, optimizer = self.setup(model, optimizer) - print('[3]') + print("[3]") def run_epoch(split): - is_train = split == 'train' + is_train = split == "train" model.train(is_train) data = self.train_dataset if is_train else self.test_dataset data.idx_begin = self.steps * config.batch_size + 1 data.cuda_id = self.cuda_id - + if config.num_workers > 0: - loader = DataLoader(data, shuffle=False, pin_memory=True, - batch_size=config.batch_size // NUM_GPUS, - num_workers=config.num_workers) + loader = DataLoader( + data, + shuffle=False, + pin_memory=True, + batch_size=config.batch_size // NUM_GPUS, + num_workers=config.num_workers, + ) else: - loader = DataLoader(data, shuffle=False, - batch_size=config.batch_size // NUM_GPUS, - num_workers=config.num_workers) - - pbar = tqdm(enumerate(loader), total=len( - loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) + loader = DataLoader( + data, + shuffle=False, + batch_size=config.batch_size // NUM_GPUS, + num_workers=config.num_workers, + ) + + pbar = ( + tqdm( + enumerate(loader), + total=len(loader), + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", + ) + if is_train + else enumerate(loader) + ) loader = self.setup_dataloaders(loader) gc.collect() torch.cuda.empty_cache() - + for it, (x, y) in pbar: with torch.set_grad_enabled(is_train): - loss = model(x, y) # forward the model + loss = model(x, y) # forward the model - if os.environ['RWKV_DEEPSPEED'] == '0': + if os.environ["RWKV_DEEPSPEED"] == "0": all_loss = [loss.clone()] else: all_loss = [loss.clone() for _ in range(NUM_GPUS)] @@ -133,55 +175,79 @@ def run_epoch(split): optimizer.step() # decay the learning rate based on our progress - self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) + self.tokens += ( + y >= 0 + ).sum() # number of tokens processed this step (i.e. 
label is not -100) lr_final_factor = config.lr_final / config.learning_rate if self.tokens < config.warmup_tokens: # linear warmup - lr_mult = lr_final_factor + \ - (1 - lr_final_factor) * float(self.tokens) / \ - float(config.warmup_tokens) + lr_mult = lr_final_factor + (1 - lr_final_factor) * float( + self.tokens + ) / float(config.warmup_tokens) progress = 0 else: # exponential learning rate decay - progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + progress = float(self.tokens - config.warmup_tokens) / float( + max(1, config.final_tokens - config.warmup_tokens) + ) if progress >= 1: lr_mult = lr_final_factor else: - lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1)) + lr_mult = math.exp( + math.log(lr_final_factor) * pow(progress, 1) + ) lr = config.learning_rate * lr_mult for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr self.lr = lr self.steps += 1 - + now_loss = 0 for gg in range(NUM_GPUS): now_loss += all_loss[gg].item() - now_loss = now_loss / NUM_GPUS # report progress + now_loss = now_loss / NUM_GPUS # report progress if USE_WANDB and self.cuda_id == 0: - wandb.log({"loss": now_loss}, step = self.steps) + wandb.log({"loss": now_loss}, step=self.steps) if self.avg_loss < 0: self.avg_loss = now_loss else: factor = 1 / (it + 1) - self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor + self.avg_loss = ( + self.avg_loss * (1.0 - factor) + now_loss * factor + ) - pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}") + pbar.set_description( + f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}" + ) self.tokens = 0 # counter used for learning rate decay for epoch in range(99999999): - run_epoch('train') + run_epoch("train") if math.isnan(self.avg_loss): exit(0) if self.cuda_id == 0: - log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n') + log_file.write( + f"{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n" + ) log_file.flush() - - if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): - raw_model = self.model.module if hasattr(self.model, "module") else self.model - torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth') + + if ( + self.config.epoch_save_frequency > 0 + and epoch % self.config.epoch_save_frequency == 0 + ) or (epoch == config.max_epochs - 1): + raw_model = ( + self.model.module + if hasattr(self.model, "module") + else self.model + ) + torch.save( + raw_model.state_dict(), + self.config.epoch_save_path + + str(epoch + 1 + self.EPOCH_BEGIN) + + ".pth", + ) diff --git a/RWKV-v4/src/utils.py b/RWKV-v4/src/utils.py index a73792c03..86fc0b747 100644 --- a/RWKV-v4/src/utils.py +++ b/RWKV-v4/src/utils.py @@ -3,8 +3,9 @@ ######################################################################################################## import os + try: - NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) + NUM_GPUS = int(os.environ["RWKV_NUM_GPUS"]) except: NUM_GPUS = 1 @@ -15,24 +16,25 @@ from torch.nn import functional as F from torch.utils.data import 
Dataset + class Dataset(Dataset): def __init__(self, data, ctx_len, epoch_length_fixed): self.ctx_len = ctx_len self.epoch_length_fixed = epoch_length_fixed self.data = data - if 'MMapIndexedDataset' in str(type(self.data)): - self.vocab_size = int(os.environ['VOCAB_SIZE']) - print('current vocab size =', self.vocab_size, "(make sure it's correct)") + if "MMapIndexedDataset" in str(type(self.data)): + self.vocab_size = int(os.environ["VOCAB_SIZE"]) + print("current vocab size =", self.vocab_size, "(make sure it's correct)") self.data_size = len(self.data._bin_buffer) // 2 - print(f'data has {self.data_size} tokens.') - elif 'numpy' in str(type(self.data)): - self.vocab_size = int(os.environ['VOCAB_SIZE']) - print('current vocab size =', self.vocab_size, "(make sure it's correct)") + print(f"data has {self.data_size} tokens.") + elif "numpy" in str(type(self.data)): + self.vocab_size = int(os.environ["VOCAB_SIZE"]) + print("current vocab size =", self.vocab_size, "(make sure it's correct)") self.data_size = len(self.data) - print(f'data has {self.data_size} tokens.') + print(f"data has {self.data_size} tokens.") else: - print('building token list...', end=' ') + print("building token list...", end=" ") unique = sorted(list(set(data))) self.vocab_size = len(unique) # print() @@ -45,10 +47,10 @@ def __init__(self, data, ctx_len, epoch_length_fixed): for u in unique: xxObj[xx] = u xx += 1 - with open('vocab.json', "w", encoding="utf-16") as vocab_file: + with open("vocab.json", "w", encoding="utf-16") as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) self.data_size = len(self.data) - print('data has %d tokens, %d unique.' % (self.data_size, self.vocab_size)) + print("data has %d tokens, %d unique." % (self.data_size, self.vocab_size)) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} @@ -60,32 +62,34 @@ def __getitem__(self, idx): # we are cheating: pick a random spot in dataset # i = np.random.randint(0, self.data_size - (self.ctx_len + 1)) - if 'MMapIndexedDataset' in str(type(self.data)): + if "MMapIndexedDataset" in str(type(self.data)): dix = self.data.get(idx=0, offset=i, length=self.ctx_len + 1).astype(int) - elif 'numpy' in str(type(self.data)): - dix = self.data[i:i+self.ctx_len+1] + elif "numpy" in str(type(self.data)): + dix = self.data[i : i + self.ctx_len + 1] else: - dix = [self.stoi[s] for s in self.data[i:i+self.ctx_len+1]] - + dix = [self.stoi[s] for s in self.data[i : i + self.ctx_len + 1]] + x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) return x, y -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - if 'list' in str(type(WORD_NAME)): +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + if "list" in str(type(WORD_NAME)): self.charMode = False if WORD_NAME[0] == WORD_NAME[1]: from transformers import PreTrainedTokenizerFast + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) else: from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) self.vocab_size = len(self.tokenizer) else: self.charMode = True - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: self.word_table = json.load(result_file) self.vocab_size = len(self.word_table) @@ -96,16 +100,18 @@ def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] def 
refine_context(self, context): - context = context.strip().split('\n') + context = context.strip().split("\n") for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): # out[self.UNKNOWN_CHAR] = -float('Inf') lastChar = int(x[-1]) @@ -113,7 +119,7 @@ def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_ probs = F.softmax(torch.tensor(out), dim=-1) if self.charMode: - if self.itos[lastChar] == '\n': + if self.itos[lastChar] == "\n": top_p = top_p_newline else: top_p = top_p_usual diff --git a/RWKV-v4/train.py b/RWKV-v4/train.py index f9bb4628e..430a7bf69 100644 --- a/RWKV-v4/train.py +++ b/RWKV-v4/train.py @@ -10,8 +10,11 @@ from src.binidx import MMapIndexedDataset np.set_printoptions(precision=4, suppress=True, linewidth=200) -logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) # if False: # True False ---> Set to False if you don't understand it # print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n") @@ -22,10 +25,12 @@ # Step 1: set training data & cfg ######################################################################################################## -EXPRESS_PILE_MODE = False # True: express mode for fine-tuning a pile model // False: usual training +EXPRESS_PILE_MODE = ( + False # True: express mode for fine-tuning a pile model // False: usual training +) -EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023' -EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-169M' +EXPRESS_PILE_MODEL_NAME = "RWKV-4-Pile-169M-20220807-8023" +EXPRESS_PILE_MODEL_TYPE = "RWKV-4-Pile-169M" # EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066' # EXPRESS_PILE_MODEL_TYPE = 'RWKV-4-Pile-430M' # EXPRESS_PILE_MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040' @@ -33,24 +38,24 @@ ######################################################################################################## -datafile = "../data/enwik8" # your data -datafile_encoding = 'utf-8' # 'utf-8' / 'utf-16le' / 'numpy' (for fine-tuning pile models) / 'binidx' (the Megatron-LM 'binidx' format) +datafile = "../data/enwik8" # your data +datafile_encoding = "utf-8" # 'utf-8' / 'utf-16le' / 'numpy' (for fine-tuning pile models) / 'binidx' (the Megatron-LM 'binidx' format) # datafile = 'my-gpt_seq_document' # datafile_encoding = 'binidx' if EXPRESS_PILE_MODE: - datafile = 'train.npy' # use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into .npy - datafile_encoding = 'numpy' + datafile = "train.npy" # use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into .npy + datafile_encoding = "numpy" # # set VOCAB_SIZE = 0 (auto-compute) if you are training a char-level LM from 
scratch # set VOCAB_SIZE = 50277 for fine-tuning pile models # set VOCAB_SIZE = your_vocab_size for 'binidx' data # -os.environ['VOCAB_SIZE'] = '0' +os.environ["VOCAB_SIZE"] = "0" if EXPRESS_PILE_MODE: - os.environ['VOCAB_SIZE'] = '50277' + os.environ["VOCAB_SIZE"] = "50277" # # Currently it's slow to initialize a new model. Hence I suggest this procedure for multi-GPU training: @@ -58,34 +63,40 @@ # 2) set RWKV_NUM_GPUS = '8' (or your #GPU), batch_size = single_gpu_batchsz * RWKV_NUM_GPUS, # EPOCH_BEGIN = 1, LOAD_MODEL = True, and it will load 'trained-1.pth' and continue the training from it # -os.environ['RWKV_NUM_GPUS'] = '1' # num of GPUs to use +os.environ["RWKV_NUM_GPUS"] = "1" # num of GPUs to use # # 'bf16' (fast & stable) # 'fp16' (fast & will overflow after training a large model for very long. can be solved in the future) # 'tf32' (decent speed & stable) # 'fp32' (!!!very slow!!! only for verification) -os.environ['RWKV_FLOAT_MODE'] = 'bf16' +os.environ["RWKV_FLOAT_MODE"] = "bf16" -os.environ['RWKV_DEEPSPEED'] = '1' # Use DeepSpeed? 0 = False, 1 = True +os.environ["RWKV_DEEPSPEED"] = "1" # Use DeepSpeed? 0 = False, 1 = True -if int(os.environ['RWKV_NUM_GPUS']) == 1: # Usually you don't need DeepSpeed for 1 GPU training. - os.environ['RWKV_DEEPSPEED'] = '0' # However, sometimes DeepSpeed saves VRAM even for 1 GPU training. So you shall try it. +if ( + int(os.environ["RWKV_NUM_GPUS"]) == 1 +): # Usually you don't need DeepSpeed for 1 GPU training. + os.environ[ + "RWKV_DEEPSPEED" + ] = "0" # However, sometimes DeepSpeed saves VRAM even for 1 GPU training. So you shall try it. -os.environ['USE_WANDB'] = '0' # wandb logging. 0 = False, 1 = True +os.environ["USE_WANDB"] = "0" # wandb logging. 0 = False, 1 = True ######################################################################################################## # Step 2: set model details ######################################################################################################## -EPOCH_BEGIN = 0 # begins with miniEpoch = EPOCH_BEGIN -LOAD_MODEL = False # shall we load the #EPOCH_BEGIN model and continue the training from it? +EPOCH_BEGIN = 0 # begins with miniEpoch = EPOCH_BEGIN +LOAD_MODEL = ( + False # shall we load the #EPOCH_BEGIN model and continue the training from it? +) n_layer = 6 n_embd = 512 -ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer +ctx_len = 1024 # increase T_MAX in src/model.py if your ctx_len is longer -model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' (sometimes better) +model_type = "RWKV" # 'RWKV' or 'RWKV-ffnPre' (sometimes better) # there is also a RWKV_HEAD_QK_DIM in model.py and model_run.py # set it to 256, then it's using my headQK trick (a tiny attention) to improve loss @@ -93,15 +104,15 @@ if EXPRESS_PILE_MODE: LOAD_MODEL = True - if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M': + if EXPRESS_PILE_MODEL_TYPE == "RWKV-4-Pile-169M": n_layer = 12 n_embd = 768 ctx_len = 1024 - elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-430M': + elif EXPRESS_PILE_MODEL_TYPE == "RWKV-4-Pile-430M": n_layer = 24 n_embd = 1024 ctx_len = 1024 - elif EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-1B5': + elif EXPRESS_PILE_MODEL_TYPE == "RWKV-4-Pile-1B5": n_layer = 24 n_embd = 2048 ctx_len = 1024 @@ -111,8 +122,8 @@ ######################################################################################################## # if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU. 
-batch_size = 12 * int(os.environ['RWKV_NUM_GPUS']) -assert (batch_size % int(os.environ['RWKV_NUM_GPUS']) == 0) +batch_size = 12 * int(os.environ["RWKV_NUM_GPUS"]) +assert batch_size % int(os.environ["RWKV_NUM_GPUS"]) == 0 # By default we are using exponential LR decay. # Here are my suggestions for training. @@ -121,7 +132,7 @@ # 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run. # 3) Set lr_init = 8e-4, lr_final = 1e-5, betas = (0.9, 0.999). # 4) Set EPOCH_BEGIN & LOAD_MODEL to load the partially-trained model. Continue the training. -# +# # For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4. lr_init = 8e-4 @@ -129,14 +140,16 @@ # the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens) n_epoch = 500 -epoch_length_fixed = (10000 // batch_size) * batch_size # feel free to increase it if you have lots of GPU +epoch_length_fixed = ( + 10000 // batch_size +) * batch_size # feel free to increase it if you have lots of GPU # epoch_save_frequency 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ... epoch_save_frequency = 10 -epoch_save_path = 'trained-' +epoch_save_path = "trained-" if EXPRESS_PILE_MODE: - if EXPRESS_PILE_MODEL_TYPE == 'RWKV-4-Pile-169M': + if EXPRESS_PILE_MODEL_TYPE == "RWKV-4-Pile-169M": lr_init = 2e-5 else: lr_init = 1e-5 @@ -145,18 +158,23 @@ ### misc stuffs ######################################################################################## -if LOAD_MODEL and EPOCH_BEGIN > 0: # we are not saving gradients, so let's have some warmup if we load a model +if ( + LOAD_MODEL and EPOCH_BEGIN > 0 +): # we are not saving gradients, so let's have some warmup if we load a model warmup_tokens = 50 * ctx_len * batch_size // NUM_GPUS else: warmup_tokens = 0 -betas = (0.9, 0.99) # set betas = (0.9, 0.999) if your model has been trained for a while +betas = ( + 0.9, + 0.99, +) # set betas = (0.9, 0.999) if your model has been trained for a while eps = 1e-8 -num_workers = 1 # DataLoader worker. I only tested num_workers = 1 +num_workers = 1 # DataLoader worker. I only tested num_workers = 1 -NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) -os.environ['RWKV_LOAD_MODEL'] = str(LOAD_MODEL) +NUM_GPUS = int(os.environ["RWKV_NUM_GPUS"]) +os.environ["RWKV_LOAD_MODEL"] = str(LOAD_MODEL) MODEL_NAME = epoch_save_path + str(EPOCH_BEGIN) if EXPRESS_PILE_MODE: @@ -164,7 +182,7 @@ MODEL_NAME = EXPRESS_PILE_MODEL_NAME torch.backends.cudnn.benchmark = True -if os.environ['RWKV_FLOAT_MODE'] == 'fp32': +if os.environ["RWKV_FLOAT_MODE"] == "fp32": torch.backends.cudnn.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False else: @@ -175,27 +193,63 @@ # Load data ######################################################################################################## -print(f'loading {datafile_encoding} data... ' + datafile) -if datafile_encoding == 'binidx': +print(f"loading {datafile_encoding} data... 
" + datafile) +if datafile_encoding == "binidx": train_dataset = Dataset(MMapIndexedDataset(datafile), ctx_len, epoch_length_fixed) -elif datafile_encoding == 'numpy': - train_dataset = Dataset(np.load(datafile).astype('int'), ctx_len, epoch_length_fixed) +elif datafile_encoding == "numpy": + train_dataset = Dataset( + np.load(datafile).astype("int"), ctx_len, epoch_length_fixed + ) else: - train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed) + train_dataset = Dataset( + open(datafile, "r", encoding=datafile_encoding).read(), + ctx_len, + epoch_length_fixed, + ) ######################################################################################################## # Train model ######################################################################################################## -if __name__ == '__main__': +if __name__ == "__main__": from src.trainer import Trainer, TrainerConfig - print('\nmodel', model_type, os.environ['RWKV_FLOAT_MODE'], 'epoch', n_epoch, 'batchsz', batch_size, 'betas', - betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, '\n') - - tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, - learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, - warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) + print( + "\nmodel", + model_type, + os.environ["RWKV_FLOAT_MODE"], + "epoch", + n_epoch, + "batchsz", + batch_size, + "betas", + betas, + "eps", + eps, + "ctx", + ctx_len, + "layer", + n_layer, + "embd", + n_embd, + "\n", + ) + + tconf = TrainerConfig( + model_type=model_type, + max_epochs=n_epoch, + batch_size=batch_size, + learning_rate=lr_init, + lr_decay=True, + lr_final=lr_final, + betas=betas, + eps=eps, + warmup_tokens=warmup_tokens, + final_tokens=n_epoch * len(train_dataset) * ctx_len, + num_workers=num_workers, + epoch_save_frequency=epoch_save_frequency, + epoch_save_path=epoch_save_path, + ) m_cfg = types.SimpleNamespace() m_cfg.model_type = model_type m_cfg.n_layer = n_layer @@ -204,57 +258,57 @@ m_cfg.LOAD_MODEL = LOAD_MODEL m_cfg.MODEL_NAME = MODEL_NAME - if os.environ['RWKV_DEEPSPEED'] == '0': - if os.environ['RWKV_FLOAT_MODE'] == 'fp16': - trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16) - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision='bf16') - elif '32' in os.environ['RWKV_FLOAT_MODE']: + if os.environ["RWKV_DEEPSPEED"] == "0": + if os.environ["RWKV_FLOAT_MODE"] == "fp16": + trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=16) + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision="bf16") + elif "32" in os.environ["RWKV_FLOAT_MODE"]: trainer = Trainer(devices=NUM_GPUS, accelerator="gpu", precision=32) else: from pytorch_lightning.strategies import DeepSpeedStrategy - + DEEPSPEED_CFG = { - "zero_allow_untested_optimizer":True, - "zero_optimization":{ - "stage":2, - "contiguous_gradients":True, - "overlap_comm":True, - "allgather_partitions":True, - "reduce_scatter":True, - "allgather_bucket_size":200000000, - "reduce_bucket_size":200000000, - "sub_group_size":1000000000000 + "zero_allow_untested_optimizer": True, + "zero_optimization": { + "stage": 2, + "contiguous_gradients": True, + "overlap_comm": True, + "allgather_partitions": True, + 
"reduce_scatter": True, + "allgather_bucket_size": 200000000, + "reduce_bucket_size": 200000000, + "sub_group_size": 1000000000000, }, - "activation_checkpointing":{ - "partition_activations":False, - "cpu_checkpointing":False, - "contiguous_memory_optimization":False, - "synchronize_checkpoint_boundary":False + "activation_checkpointing": { + "partition_activations": False, + "cpu_checkpointing": False, + "contiguous_memory_optimization": False, + "synchronize_checkpoint_boundary": False, }, - "aio":{ - "block_size":1048576, - "queue_depth":8, - "single_submit":False, - "overlap_events":True, - "thread_count":1 + "aio": { + "block_size": 1048576, + "queue_depth": 8, + "single_submit": False, + "overlap_events": True, + "thread_count": 1, }, "gradient_clipping": 1.0, "gradient_accumulation_steps": 1, } if NUM_GPUS == 1: - DEEPSPEED_CFG['zero_optimization'] = { - "stage":1, # saves some VRAM - "contiguous_gradients":False, - "overlap_comm":False, - "allgather_partitions":False, - "reduce_scatter":False, - "allgather_bucket_size":200000000, - "reduce_bucket_size":200000000, - "sub_group_size":1000000000000 + DEEPSPEED_CFG["zero_optimization"] = { + "stage": 1, # saves some VRAM + "contiguous_gradients": False, + "overlap_comm": False, + "allgather_partitions": False, + "reduce_scatter": False, + "allgather_bucket_size": 200000000, + "reduce_bucket_size": 200000000, + "sub_group_size": 1000000000000, } - if os.environ['RWKV_FLOAT_MODE'] == 'fp16': + if os.environ["RWKV_FLOAT_MODE"] == "fp16": DEEPSPEED_CFG["fp16"] = { "fp16": True, "enabled": True, @@ -262,19 +316,32 @@ "initial_scale_power": 12, "loss_scale_window": 1000, "hysteresis": 2, - "min_loss_scale": 1 + "min_loss_scale": 1, } - trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=16) - - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - DEEPSPEED_CFG["bf16"] = { - "enabled": True - } - trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision='bf16') - - elif '32' in os.environ['RWKV_FLOAT_MODE']: - trainer = Trainer(strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), devices=NUM_GPUS, accelerator="gpu", precision=32) + trainer = Trainer( + strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), + devices=NUM_GPUS, + accelerator="gpu", + precision=16, + ) + + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + DEEPSPEED_CFG["bf16"] = {"enabled": True} + trainer = Trainer( + strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), + devices=NUM_GPUS, + accelerator="gpu", + precision="bf16", + ) + + elif "32" in os.environ["RWKV_FLOAT_MODE"]: + trainer = Trainer( + strategy=DeepSpeedStrategy(config=DEEPSPEED_CFG), + devices=NUM_GPUS, + accelerator="gpu", + precision=32, + ) print(trainer._strategy.config) - + trainer.run(m_cfg, train_dataset, None, tconf) diff --git a/RWKV-v4/verify.py b/RWKV-v4/verify.py index 616b3877e..e84ede28a 100644 --- a/RWKV-v4/verify.py +++ b/RWKV-v4/verify.py @@ -5,86 +5,103 @@ # this is for verifying the results of different models and make sure they agree with each other import numpy as np + np.set_printoptions(precision=4, suppress=True, linewidth=200) import os + os.environ["CUDA_VISIBLE_DEVICES"] = "0" -os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. 
can be solved in the future) -os.environ['RWKV_RUN_DEVICE'] = 'cuda' -RUN_DEVICE = os.environ['RWKV_RUN_DEVICE'] +os.environ[ + "RWKV_FLOAT_MODE" +] = "bf16" # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) +os.environ["RWKV_RUN_DEVICE"] = "cuda" +RUN_DEVICE = os.environ["RWKV_RUN_DEVICE"] import torch from src.model_run import RWKV_RNN, RWKV_GPT from src.model import GPT, GPTConfig -TOKEN_MODE = 'pile' # char / pile +TOKEN_MODE = "pile" # char / pile -if TOKEN_MODE == 'char': - MODEL_NAME = 'trained-1' - WORD_NAME = 'vocab' # the .json vocab (generated by train.py) +if TOKEN_MODE == "char": + MODEL_NAME = "trained-1" + WORD_NAME = "vocab" # the .json vocab (generated by train.py) ctx_len = 1024 n_layer = 6 n_embd = 512 - UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity -elif TOKEN_MODE == 'pile': - WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json'] - MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023' + UNKNOWN_CHAR = " " # here we just set it to [space] for simplicity +elif TOKEN_MODE == "pile": + WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"] + MODEL_NAME = "RWKV-4-Pile-169M-20220807-8023" ctx_len = 1024 n_layer = 12 n_embd = 768 UNKNOWN_CHAR = None -model_type = 'RWKV' +model_type = "RWKV" from src.utils import TOKENIZER + tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) -if TOKEN_MODE == 'pile': +if TOKEN_MODE == "pile": tokenizer.vocab_size = 50277 ######################################################################################################## -model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda() - -if os.environ['RWKV_FLOAT_MODE'] == 'fp16': +model_train = GPT( + GPTConfig( + tokenizer.vocab_size, + ctx_len, + model_type=model_type, + n_layer=n_layer, + n_embd=n_embd, + ) +).cuda() + +if os.environ["RWKV_FLOAT_MODE"] == "fp16": model_train = model_train.half() -elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': +elif os.environ["RWKV_FLOAT_MODE"] == "bf16": model_train = model_train.bfloat16() -print('loading ' + MODEL_NAME) -m2 = torch.load(MODEL_NAME + '.pth', map_location=RUN_DEVICE) +print("loading " + MODEL_NAME) +m2 = torch.load(MODEL_NAME + ".pth", map_location=RUN_DEVICE) model_train.load_state_dict(m2) model_rnn = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) -model_gpt = RWKV_GPT(MODEL_NAME, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda() +model_gpt = RWKV_GPT( + MODEL_NAME, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len +).cuda() ######################################################################################################## # context = '\nIn a' -context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.' +context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." 
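
# --- editor's sketch, not part of this patch: the script below prints raw logits
# from the three implementations for eyeball comparison. Assuming the names defined
# above in this file (`model_gpt`, `model_rnn`) plus `ctx` (the tokenized context
# built just below), a hypothetical helper like this makes the agreement check
# explicit; the tolerance to expect depends on RWKV_FLOAT_MODE.
def _max_gpt_rnn_logit_diff():
    gpt_last = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0][-1].detach()
    model_rnn.clear()
    rnn_last = None
    for i in range(len(ctx)):
        rnn_last = model_rnn.run(ctx[: i + 1])  # RNN path returns last-token logits
    rnn_last = torch.as_tensor(rnn_last).to(gpt_last.device)
    return (gpt_last.float() - rnn_last.float()).abs().max().item()
# e.g. print(_max_gpt_rnn_logit_diff()) once `ctx` has been built below
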
-if TOKEN_MODE == 'char': +if TOKEN_MODE == "char": ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] -elif TOKEN_MODE == 'pile': +elif TOKEN_MODE == "pile": ctx = tokenizer.tokenizer.encode(context) -print(f'input len {len(ctx)} data {ctx}') +print(f"input len {len(ctx)} data {ctx}") ######################################################################################################## -print('\nRWKV-GPT output') +print("\nRWKV-GPT output") out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy() print(out) -print('\nRWKV-RNN output') +print("\nRWKV-RNN output") model_rnn.clear() src_len = len(ctx) for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] out = model_rnn.run(x) if i < 3 or i >= src_len - 3: print(torch.tensor(out).detach().cpu().numpy()) if i == 2: - print('...') + print("...") -print('\nRWKV-train output') -out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy() -print(out, '\n') +print("\nRWKV-train output") +out = ( + model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy() +) +print(out, "\n") diff --git a/RWKV-v4neo/chat.py b/RWKV-v4neo/chat.py index d214ba281..e563e2f1d 100644 --- a/RWKV-v4neo/chat.py +++ b/RWKV-v4neo/chat.py @@ -2,12 +2,13 @@ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## -print('Loading...') +print("Loading...") from src.model_run import RWKV_RNN import numpy as np import os, copy, types, gc, sys import torch from src.utils import TOKENIZER + try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] except: @@ -17,7 +18,7 @@ torch.backends.cuda.matmul.allow_tf32 = True np.set_printoptions(precision=4, suppress=True, linewidth=200) -CHAT_LANG = 'English' # English Chinese +CHAT_LANG = "English" # English Chinese WORD_NAME = [ "20B_tokenizer.json", @@ -28,14 +29,16 @@ args = types.SimpleNamespace() args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda' -args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) +args.FLOAT_MODE = ( + "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) +) args.vocab_size = 50277 args.head_qk = 0 args.pre_ffn = 0 args.grad_cp = 0 args.my_pos_emb = 0 -args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170' +args.MODEL_NAME = "/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170" args.n_layer = 40 args.n_embd = 5120 args.ctx_len = 1024 @@ -50,7 +53,7 @@ # args.n_embd = 2560 # args.ctx_len = 1024 -if CHAT_LANG == 'English': +if CHAT_LANG == "English": user = "User" bot = "Bot" interface = ":" @@ -58,7 +61,7 @@ # The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. # The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. - init_prompt = f''' + init_prompt = f""" The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. 
{user}{interface} french revolution what year @@ -81,8 +84,8 @@ {bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. -''' - HELP_MSG = '''Commands: +""" + HELP_MSG = """Commands: say something --> chat with bot. use \\n for new line. +alt --> alternate chat reply +reset --> reset chat @@ -94,9 +97,9 @@ Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation. -''' -elif CHAT_LANG == 'Chinese': - args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293' +""" +elif CHAT_LANG == "Chinese": + args.MODEL_NAME = "/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293" args.n_layer = 32 args.n_embd = 4096 args.ctx_len = 1024 @@ -105,7 +108,7 @@ bot = "A" interface = ":" - init_prompt = ''' + init_prompt = """ Q: 企鹅会飞吗? A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。 @@ -114,8 +117,8 @@ A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。 -''' - HELP_MSG = '''指令: +""" + HELP_MSG = """指令: 直接输入内容 --> 和机器人聊天,用\\n代表换行 +alt --> 让机器人换个回答 +reset --> 重置对话 @@ -126,14 +129,14 @@ +retry --> 换个 +gen / +qa 的回答 现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。 -''' +""" # Load Model os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE MODEL_NAME = args.MODEL_NAME -print(f'loading... {MODEL_NAME}') +print(f"loading... {MODEL_NAME}") model = RWKV_RNN(args) model_tokens = [] @@ -142,15 +145,18 @@ ######################################################################################################## -def run_rnn(tokens, newline_adj = 0): + +def run_rnn(tokens, newline_adj=0): global model_tokens, current_state for i in range(len(tokens)): model_tokens += [int(tokens[i])] if i == len(tokens) - 1: out, current_state = model.forward(model_tokens, current_state) else: - current_state = model.forward(model_tokens, current_state, preprocess_only = True) - + current_state = model.forward( + model_tokens, current_state, preprocess_only=True + ) + # print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]') out[0] = -999999999 # disable <|endoftext|> @@ -159,60 +165,67 @@ def run_rnn(tokens, newline_adj = 0): # out[15] += newline_adj / 2 # '.' 
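    # editor's note (an assumption, not part of this patch): `newline_adj` acts as a
    # logit bias on the newline token -- on_message() below passes a large negative
    # value to forbid early line breaks, then ramps it up past ~130 tokens so the
    # reply terminates. With the 20B tokenizer used here, '\n' is token 187 (run.py
    # asserts tokenizer.tokenizer.decode([187]) == "\n"), so the adjustment amounts
    # to something like:
    #     out[187] += newline_adj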
return out + all_state = {} + + def save_all_stat(srv, name, last_out): - n = f'{name}_{srv}' + n = f"{name}_{srv}" all_state[n] = {} - all_state[n]['out'] = last_out - all_state[n]['rnn'] = copy.deepcopy(current_state) - all_state[n]['token'] = copy.deepcopy(model_tokens) + all_state[n]["out"] = last_out + all_state[n]["rnn"] = copy.deepcopy(current_state) + all_state[n]["token"] = copy.deepcopy(model_tokens) + def load_all_stat(srv, name): global model_tokens, current_state - n = f'{name}_{srv}' - current_state = copy.deepcopy(all_state[n]['rnn']) - model_tokens = copy.deepcopy(all_state[n]['token']) - return all_state[n]['out'] + n = f"{name}_{srv}" + current_state = copy.deepcopy(all_state[n]["rnn"]) + model_tokens = copy.deepcopy(all_state[n]["token"]) + return all_state[n]["out"] + ######################################################################################################## # Run inference -print(f'\nRun prompt...') +print(f"\nRun prompt...") out = run_rnn(tokenizer.tokenizer.encode(init_prompt)) gc.collect() torch.cuda.empty_cache() -save_all_stat('', 'chat_init', out) +save_all_stat("", "chat_init", out) -srv_list = ['dummy_server'] +srv_list = ["dummy_server"] for s in srv_list: - save_all_stat(s, 'chat', out) + save_all_stat(s, "chat", out) + +print(f"### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n") -print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n') def reply_msg(msg): - print(f'{bot}{interface} {msg}\n') + print(f"{bot}{interface} {msg}\n") + def on_message(message): global model_tokens, current_state - srv = 'dummy_server' + srv = "dummy_server" - msg = message.replace('\\n','\n').strip() + msg = message.replace("\\n", "\n").strip() if len(msg) > 1000: - reply_msg('your message is too long (max 1000 tokens)') + reply_msg("your message is too long (max 1000 tokens)") return x_temp = 1.0 x_top_p = 0.85 - if ("-temp=" in msg): + if "-temp=" in msg: x_temp = float(msg.split("-temp=")[1].split(" ")[0]) - msg = msg.replace("-temp="+f'{x_temp:g}', "") + msg = msg.replace("-temp=" + f"{x_temp:g}", "") # print(f"temp: {x_temp}") - if ("-top_p=" in msg): + if "-top_p=" in msg: x_top_p = float(msg.split("-top_p=")[1].split(" ")[0]) - msg = msg.replace("-top_p="+f'{x_top_p:g}', "") + msg = msg.replace("-top_p=" + f"{x_top_p:g}", "") # print(f"top_p: {x_top_p}") if x_temp <= 0.2: x_temp = 0.2 @@ -220,31 +233,36 @@ def on_message(message): x_temp = 5 if x_top_p <= 0: x_top_p = 0 - - if msg == '+reset': - out = load_all_stat('', 'chat_init') - save_all_stat(srv, 'chat', out) + + if msg == "+reset": + out = load_all_stat("", "chat_init") + save_all_stat(srv, "chat", out) reply_msg("Chat reset.") return - elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry': + elif ( + msg[:5].lower() == "+gen " + or msg[:4].lower() == "+qa " + or msg.lower() == "+more" + or msg.lower() == "+retry" + ): - if msg[:5].lower() == '+gen ': - new = '\n' + msg[5:].strip() + if msg[:5].lower() == "+gen ": + new = "\n" + msg[5:].strip() # print(f'### prompt ###\n[{new}]') current_state = None out = run_rnn(tokenizer.tokenizer.encode(new)) - save_all_stat(srv, 'gen_0', out) + save_all_stat(srv, "gen_0", out) - elif msg[:4].lower() == '+qa ': - out = load_all_stat('', 'chat_init') + elif msg[:4].lower() == "+qa ": + out = load_all_stat("", "chat_init") real_msg = msg[4:].strip() new = f"{user}{interface} {real_msg}\n\n{bot}{interface}" # print(f'### qa ###\n[{new}]') - + out = 
run_rnn(tokenizer.tokenizer.encode(new)) - save_all_stat(srv, 'gen_0', out) + save_all_stat(srv, "gen_0", out) # new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:" # print(f'### prompt ###\n[{new}]') @@ -252,16 +270,16 @@ def on_message(message): # out = run_rnn(tokenizer.tokenizer.encode(new)) # save_all_stat(srv, 'gen_0', out) - elif msg.lower() == '+more': + elif msg.lower() == "+more": try: - out = load_all_stat(srv, 'gen_1') - save_all_stat(srv, 'gen_0', out) + out = load_all_stat(srv, "gen_1") + save_all_stat(srv, "gen_0", out) except: return - elif msg.lower() == '+retry': + elif msg.lower() == "+retry": try: - out = load_all_stat(srv, 'gen_0') + out = load_all_stat(srv, "gen_0") except: return @@ -276,37 +294,37 @@ def on_message(message): top_p_usual=x_top_p, top_p_newline=x_top_p, ) - if msg[:4].lower() == '+qa ': + if msg[:4].lower() == "+qa ": out = run_rnn([token], newline_adj=-1) else: out = run_rnn([token]) - + xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) - if '\ufffd' not in xxx: - print(xxx, end='', flush=True) + if "\ufffd" not in xxx: + print(xxx, end="", flush=True) out_last = begin + i + 1 - print('\n') + print("\n") # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # print(f'### send ###\n[{send_msg}]') # reply_msg(send_msg) - save_all_stat(srv, 'gen_1', out) + save_all_stat(srv, "gen_1", out) else: - if msg.lower() == '+alt': + if msg.lower() == "+alt": try: - out = load_all_stat(srv, 'chat_pre') + out = load_all_stat(srv, "chat_pre") except: return else: - out = load_all_stat(srv, 'chat') + out = load_all_stat(srv, "chat") new = f"{user}{interface} {msg}\n\n{bot}{interface}" # print(f'### add ###\n[{new}]') out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999) - save_all_stat(srv, 'chat_pre', out) + save_all_stat(srv, "chat_pre", out) begin = len(model_tokens) out_last = begin - print(f'{bot}{interface}', end='', flush=True) + print(f"{bot}{interface}", end="", flush=True) for i in range(999): if i <= 0: newline_adj = -999999999 @@ -315,7 +333,7 @@ def on_message(message): elif i <= 130: newline_adj = 0 else: - newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION + newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION token = tokenizer.sample_logits( out, model_tokens, @@ -327,15 +345,15 @@ def on_message(message): out = run_rnn([token], newline_adj=newline_adj) xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) - if '\ufffd' not in xxx: - print(xxx, end='', flush=True) + if "\ufffd" not in xxx: + print(xxx, end="", flush=True) out_last = begin + i + 1 - + send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]) - if '\n\n' in send_msg: + if "\n\n" in send_msg: send_msg = send_msg.strip() break - + # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! 
    # send_msg = send_msg[:-len(f'{user}{interface}')].strip()
@@ -349,13 +367,14 @@ def on_message(message):
    # print(f'### send ###\n[{send_msg}]')
    # reply_msg(send_msg)

-    save_all_stat(srv, 'chat', out)
+    save_all_stat(srv, "chat", out)
+

 print(HELP_MSG)
 while True:
-    msg = input(f'{user}{interface} ')
+    msg = input(f"{user}{interface} ")
     if len(msg.strip()) > 0:
         on_message(msg)
     else:
-        print('Erorr: please say something')
+        print("Error: please say something")
diff --git a/RWKV-v4neo/img_demoAE.py b/RWKV-v4neo/img_demoAE.py
index ab0d4edd6..43c0c3cf3 100644
--- a/RWKV-v4neo/img_demoAE.py
+++ b/RWKV-v4neo/img_demoAE.py
@@ -9,55 +9,58 @@
 from torch.nn import functional as F
 import torchvision as vision
 import torchvision.transforms as transforms
+
 np.set_printoptions(precision=4, suppress=True, linewidth=200)

-print(f'loading...')
+print(f"loading...")

 ########################################################################################################

-model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201'
-input_img = 'test/img_ae_test/test0.png'
+model_prefix = "test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201"
+input_img = "test/img_ae_test/test0.png"

 ########################################################################################################

+
 class ToBinary(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
-        return torch.floor(x + 0.5) # no need for noise when we have plenty of data
+        return torch.floor(x + 0.5)  # no need for noise when we have plenty of data

     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output.clone() # pass-through
+        return grad_output.clone()  # pass-through

+
 class R_ENCODER(nn.Module):
     def __init__(self, args):
         super().__init__()
         self.args = args
         dd = 8
-        self.Bxx = nn.BatchNorm2d(dd*64)
+        self.Bxx = nn.BatchNorm2d(dd * 64)

         self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
         self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
         self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)

-        self.B00 = nn.BatchNorm2d(dd*4)
-        self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
-        self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
-        self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
-        self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
+        self.B00 = nn.BatchNorm2d(dd * 4)
+        self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
+        self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
+        self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
+        self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)

-        self.B10 = nn.BatchNorm2d(dd*16)
-        self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
-        self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
-        self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
-        self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
+        self.B10 = nn.BatchNorm2d(dd * 16)
+        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
+        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
+        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
+        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)

-        self.B20 = nn.BatchNorm2d(dd*64)
-        self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
-        self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
-        self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
-        self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
+        self.B20 = nn.BatchNorm2d(dd * 64)
+        self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
+        self.C21 = nn.Conv2d(256, dd * 64, 
kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) - self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1) + self.COUT = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1) def forward(self, img): ACT = F.mish @@ -81,30 +84,31 @@ def forward(self, img): x = self.COUT(x + xx) return torch.sigmoid(x) + class R_DECODER(nn.Module): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1) - - self.B00 = nn.BatchNorm2d(dd*64) - self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - - self.B20 = nn.BatchNorm2d(dd*4) - self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) + self.CIN = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1) + + self.B00 = nn.BatchNorm2d(dd * 64) + self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + + self.B20 = nn.BatchNorm2d(dd * 4) + self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) @@ -128,30 +132,33 @@ def forward(self, code): x = x + self.Cx1(ACT(self.Cx0(x))) x = self.COUT(x) - + return torch.sigmoid(x) + ######################################################################################################## -print(f'building model...') +print(f"building model...") args = types.SimpleNamespace() args.my_img_bit = 13 encoder = R_ENCODER(args).eval().cuda() decoder = R_DECODER(args).eval().cuda() -zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long() +zpow = torch.tensor([2**i for i in range(0, 13)]).reshape(13, 1, 1).cuda().long() -encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth')) -decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth')) +encoder.load_state_dict(torch.load(f"{model_prefix}-E.pth")) +decoder.load_state_dict(torch.load(f"{model_prefix}-D.pth")) ######################################################################################################## -print(f'test image...') -img_transform = transforms.Compose([ - transforms.PILToTensor(), - 
transforms.ConvertImageDtype(torch.float), - transforms.Resize((224, 224)) -]) +print(f"test image...") +img_transform = transforms.Compose( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Resize((224, 224)), + ] +) with torch.no_grad(): img = img_transform(Image.open(input_img)).unsqueeze(0).cuda() @@ -159,7 +166,7 @@ def forward(self, code): z = ToBinary.apply(z) zz = torch.sum(z.squeeze().long() * zpow, dim=0) - print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n') - + print(f"Code shape = {zz.shape}\n{zz.cpu().numpy()}\n") + out = decoder(z) vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg") diff --git a/RWKV-v4neo/run.py b/RWKV-v4neo/run.py index f13e97f08..eb7109cb6 100644 --- a/RWKV-v4neo/run.py +++ b/RWKV-v4neo/run.py @@ -6,6 +6,7 @@ import math, os, sys, types, time, gc import torch from src.utils import TOKENIZER + try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] except: @@ -20,12 +21,14 @@ # Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible) ######################################################################################################## -args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast) -args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU) +args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast) +args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU) # if args.RUN_DEVICE == "cuda": # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output -os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! +os.environ[ + "RWKV_JIT_ON" +] = "1" # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! TOKEN_MODE = "pile" WORD_NAME = [ @@ -58,7 +61,7 @@ # n_embd = 2560 # ctx_len = 1024 -MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' +MODEL_NAME = "/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047" n_layer = 32 n_embd = 4096 ctx_len = 1024 @@ -129,12 +132,12 @@ ######################################################################################################## -print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...') +print(f"\nUsing {args.RUN_DEVICE.upper()}. 
Loading {MODEL_NAME}...") from src.model_run import RWKV_RNN model = RWKV_RNN(args) -print(f'\nOptimizing speed...') +print(f"\nOptimizing speed...") out, _ = model.forward([187], None) # print(out) gc.collect() @@ -142,10 +145,10 @@ # input(0) -print(f'\nLoading tokenizer {WORD_NAME}...') +print(f"\nLoading tokenizer {WORD_NAME}...") tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) if TOKEN_MODE == "pile": - assert tokenizer.tokenizer.decode([187]) == '\n' + assert tokenizer.tokenizer.decode([187]) == "\n" ######################################################################################################## @@ -165,6 +168,7 @@ time_slot = {} time_ref = time.time_ns() + def record_time(name): if name not in time_slot: time_slot[name] = 1e20 @@ -172,13 +176,14 @@ def record_time(name): if tt < time_slot[name]: time_slot[name] = tt + init_state = None init_out = None state = None out = None for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): - print(("-" * 50) + '\n' + context, end="") + print(("-" * 50) + "\n" + context, end="") time_ref = time.time_ns() ctx = src_ctx.copy() @@ -193,7 +198,7 @@ def record_time(name): gc.collect() torch.cuda.empty_cache() - record_time('preprocess') + record_time("preprocess") out_last = src_len for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): x = ctx[: i + 1] @@ -205,7 +210,14 @@ def record_time(name): else: out, state = model.forward(x, state) if DEBUG_DEBUG: - print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy())) + print( + "model", + np.array(x), + "==>", + np.array(out), + np.max(out.cpu().numpy()), + np.min(out.cpu().numpy()), + ) if TOKEN_MODE == "pile": out[0] = -999999999 # disable <|endoftext|> @@ -224,14 +236,15 @@ def record_time(name): print(char, end="", flush=True) else: char = tokenizer.tokenizer.decode(ctx[out_last:]) - if '\ufffd' not in char: # is valid utf8 string? + if "\ufffd" not in char: # is valid utf8 string? 
print(char, end="", flush=True) - out_last = i+1 + out_last = i + 1 - record_time('total') + record_time("total") # print(f'\n\n{time_slot}\n\n') print( - f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = '' + f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", + end="", ) -print(("-" * 50) + '\n') +print(("-" * 50) + "\n") diff --git a/RWKV-v4neo/src/binidx.py b/RWKV-v4neo/src/binidx.py index 369081ad4..8d5b40bfe 100644 --- a/RWKV-v4neo/src/binidx.py +++ b/RWKV-v4neo/src/binidx.py @@ -7,6 +7,7 @@ from functools import lru_cache from itertools import accumulate + def print_rank_0(*message): pass # """If distributed is initialized print only on rank 0.""" @@ -16,12 +17,14 @@ def print_rank_0(*message): # else: # print(*message, flush=True) + def _warmup_mmap_file(path): pass # with open(path, "rb") as stream: # while stream.read(100 * 1024 * 1024): # pass + dtypes = { 1: np.uint8, 2: np.int8, @@ -33,18 +36,22 @@ def _warmup_mmap_file(path): 8: np.uint16, } + def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) + def index_file_path(prefix_path): return prefix_path + ".idx" + def data_file_path(prefix_path): return prefix_path + ".bin" + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b"MMIDIDX\x00\x00" @@ -100,7 +107,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._file.close() return _Writer() - + def __init__(self, path, skip_warmup=False): with open(path, "rb") as stream: magic_test = stream.read(9) @@ -217,8 +224,7 @@ def __getitem__(self, idx): elif isinstance(idx, slice): start, stop, step = idx.indices(len(self)) if step != 1: - raise ValueError( - "Slices into indexed_dataset must be contiguous") + raise ValueError("Slices into indexed_dataset must be contiguous") ptr = self._index._pointers[start] sizes = self._index._sizes[idx] offsets = list(accumulate(sizes)) diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 526158642..1c020bae6 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -17,15 +17,24 @@ def __init__(self, args): if args.data_type == "binidx": self.vocab_size = args.vocab_size - rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)") + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) if args.my_pile_version == 1: self.data = MMapIndexedDataset(args.data_file) - self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size + self.data_size = ( + len(self.data._bin_buffer) // self.data._index._dtype_size + ) rank_zero_info(f"Data has {self.data_size} tokens.") else: - data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n') - data_list = [i.strip().split(' ') for i in data_list] + data_list = ( + open(args.data_file, "r", encoding="utf-8") + .read() + .strip() + .split("\n") + ) + data_list = [i.strip().split(" ") for i in data_list] self.data = [] self.data_size = int(data_list[-1][-1]) rank_zero_info(f"Data has {self.data_size} chunks.") @@ -37,29 +46,46 @@ def __init__(self, args): # rank_zero_info(self.data) if args.my_qa_mask > 0: - self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') - self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size + self.data_pile = MMapIndexedDataset( + 
"/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document" + ) + self.data_pile_size = ( + len(self.data_pile._bin_buffer) // self.data._index._dtype_size + ) if args.my_pile_stage > 0: # assert self.data_size == 332115325534 and self.vocab_size == 50277 self.samples_per_epoch = args.epoch_steps * args.real_bsz assert self.samples_per_epoch == 40320 - rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") + rank_zero_info( + f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########" + ) dataset_slot = self.data_size // args.ctx_len if args.my_pile_stage != 4: assert MaybeIsPrime(args.magic_prime) assert args.magic_prime % 3 == 2 - assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 + assert ( + args.magic_prime / dataset_slot > 0.99 + and args.magic_prime / dataset_slot <= 1 + ) elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info( + "Current vocab size =", self.vocab_size, "(make sure it's correct)" + ) self.data_size = len(self.data) rank_zero_info(f"Data has {self.data_size} tokens.") elif args.data_type == "uint16": - self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) + self.data = ( + np.fromfile(args.data_file, dtype=np.uint16) + .astype("int32") + .reshape(-1, args.my_sample_len) + ) self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info( + "Current vocab size =", self.vocab_size, "(make sure it's correct)" + ) self.data_size = self.data.shape[0] rank_zero_info(f"Data has {self.data_size} samples.") elif args.data_type == "wds_img": @@ -90,10 +116,14 @@ def __init__(self, args): for u in unique: xxObj[xx] = u xx += 1 - with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: + with open( + f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le" + ) as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) self.data_size = len(self.data) - rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.") + rank_zero_info( + f"Data has {self.data_size} tokens, {self.vocab_size} vocab size." 
+ ) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} @@ -108,36 +138,53 @@ def __getitem__(self, idx): # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") if args.data_type == "wds_img": + def init_wds(self, bias=0): def identity(x): - return x + return x + import webdataset as wds import torchvision.transforms as transforms + # img_transform = transforms.Compose( # [transforms.CenterCrop(256)] # ) - img_transform = transforms.Compose([ - transforms.CenterCrop(512), - transforms.Resize((args.my_img_size)) - ]) - self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity) + img_transform = transforms.Compose( + [transforms.CenterCrop(512), transforms.Resize((args.my_img_size))] + ) + self.data_raw = ( + wds.WebDataset(args.data_file, resampled=True) + .shuffle( + 10000, + initial=1000, + rng=random.Random(epoch * 100000 + rank + bias * 1e9), + ) + .decode("torchrgb") + .to_tuple("jpg", "json", "txt") + .map_tuple(img_transform, identity, identity) + ) for pp in self.data_raw.pipeline: - if 'Resampled' in str(pp): + if "Resampled" in str(pp): pp.deterministic = True + def worker_seed(): - return rank*100000+epoch+bias*1e9 + return rank * 100000 + epoch + bias * 1e9 + pp.worker_seed = worker_seed self.data = iter(self.data_raw) # print(f"WebDataset loaded for rank {rank} epoch {epoch}") + if self.data == None: init_wds(self) trial = 0 while trial < 10: try: - dd = next(self.data) # jpg, json, txt + dd = next(self.data) # jpg, json, txt break except: - print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]') + print( + f"[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]" + ) self.error_count += 1 init_wds(self, self.error_count) trial += 1 @@ -148,7 +195,7 @@ def worker_seed(): return dd[0], dd[2] else: if args.data_type == "uint16": - i = np.random.randint(0, self.data_size-1) + i = np.random.randint(0, self.data_size - 1) dix = self.data[i] x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) @@ -201,8 +248,12 @@ def worker_seed(): for j in range(len(data)): if i < data[j][0]: ii = i - i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1] - dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int) + i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1] + dix = ( + data[j][2] + .get(idx=0, offset=i, length=req_len) + .astype(int) + ) # print(ii, j, i) break elif args.data_type == "numpy": @@ -218,7 +269,12 @@ def worker_seed(): z_sum = 0 isGood = False for i in range(3, ctx_len): - if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187: + if ( + dix[i] == 27 + and dix[i - 1] == 34 + and dix[i - 2] == 187 + and dix[i - 3] == 187 + ): isGood = True if dix[i] == 0: isGood = False @@ -228,7 +284,9 @@ def worker_seed(): if z_sum == 0: z = [1] * ctx_len i = np.random.randint(0, self.data_pile_size - req_len) - dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int) + dix = self.data_pile.get( + idx=0, offset=i, length=req_len + ).astype(int) z = torch.tensor(z, dtype=torch.bfloat16) x = torch.tensor(dix[:-1], dtype=torch.long) diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index b79f96d26..0914c160e 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -4,6 +4,7 @@ import os, math, gc, importlib import torch + # 
torch._C._jit_set_profiling_executor(True) # torch._C._jit_set_profiling_mode(True) import torch.nn as nn @@ -11,16 +12,18 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy -if importlib.util.find_spec('deepspeed'): + +if importlib.util.find_spec("deepspeed"): import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam try: - print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"]) + print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"]) except: - os.environ["RWKV_MY_TESTING"] = '' + os.environ["RWKV_MY_TESTING"] = "" + def __nop(ob): return ob @@ -43,7 +46,23 @@ def __nop(ob): from torch.utils.cpp_extension import load if os.environ["RWKV_FLOAT_MODE"] == "bf16": - wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + wkv_cuda = load( + name=f"wkv_{T_MAX}_bf16", + sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], + verbose=True, + extra_cuda_cflags=[ + "-t 4", + "-std=c++17", + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + ], + ) + class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): @@ -56,10 +75,16 @@ def forward(ctx, B, T, C, w, u, k, v): u = u.contiguous() k = k.contiguous() v = v.contiguous() - y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + y = torch.empty( + (B, T, C), + device=w.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) return y + @staticmethod def backward(ctx, gy): B = ctx.B @@ -68,16 +93,51 @@ def backward(ctx, gy): assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v, y = ctx.saved_tensors - gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + gw = torch.empty( + (B, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gu = torch.empty( + (B, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gk = torch.empty( + (B, T, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gv = torch.empty( + (B, T, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) return (None, None, None, gw, gu, gk, gv) + else: - wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + wkv_cuda = load( + name=f"wkv_{T_MAX}", + 
sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + ], + ) + class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): @@ -96,7 +156,9 @@ def forward(ctx, B, T, C, w, u, k, v): u = u.float().contiguous() k = k.float().contiguous() v = v.float().contiguous() - y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format) + y = torch.empty( + (B, T, C), device=w.device, memory_format=torch.contiguous_format + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) if "32" in os.environ["RWKV_FLOAT_MODE"]: @@ -105,6 +167,7 @@ def forward(ctx, B, T, C, w, u, k, v): return y.half() elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return y.bfloat16() + @staticmethod def backward(ctx, gy): B = ctx.B @@ -113,14 +176,26 @@ def backward(ctx, gy): assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v, y = ctx.saved_tensors - gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) - gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) - gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) - gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) + gw = torch.empty( + (B, C), device=gy.device, memory_format=torch.contiguous_format + ) + gu = torch.empty( + (B, C), device=gy.device, memory_format=torch.contiguous_format + ) + gk = torch.empty( + (B, T, C), device=gy.device, memory_format=torch.contiguous_format + ) + gv = torch.empty( + (B, T, C), device=gy.device, memory_format=torch.contiguous_format + ) if "32" in os.environ["RWKV_FLOAT_MODE"]: - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv + ) else: - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv + ) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) if "32" in os.environ["RWKV_FLOAT_MODE"]: @@ -128,7 +203,15 @@ def backward(ctx, gy): elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) elif os.environ["RWKV_FLOAT_MODE"] == "bf16": - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + return ( + None, + None, + None, + gw.bfloat16(), + gu.bfloat16(), + gk.bfloat16(), + gv.bfloat16(), + ) def RUN_CUDA(B, T, C, w, u, k, v): @@ -154,21 +237,27 @@ def __init__(self, args, layer_id): ddd = torch.ones(1, 1, args.n_embd) for i in range(args.n_embd): ddd[0, 0, i] = i / args.n_embd - + # fancy time_decay decay_speed = torch.ones(args.dim_att) for h in range(args.dim_att): - decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) self.time_decay = nn.Parameter(decay_speed) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # fancy time_first zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5 - self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag) + self.time_first = nn.Parameter( + torch.ones(args.dim_att) * math.log(0.3) + zigzag + ) # 
fancy time_mix self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_v = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) @@ -177,8 +266,10 @@ def __init__(self, args, layer_id): self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) - if 'a' in os.environ["RWKV_MY_TESTING"]: - self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + if "a" in os.environ["RWKV_MY_TESTING"]: + self.register_buffer( + "att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) d_qkv = args.n_embd // 16 self.qq = nn.Linear(args.n_embd, d_qkv, bias=False) self.kk = nn.Linear(args.n_embd, d_qkv, bias=False) @@ -187,12 +278,17 @@ def __init__(self, args, layer_id): with torch.no_grad(): self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_vv = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) + + if "a" not in os.environ["RWKV_MY_TESTING"]: - if 'a' not in os.environ["RWKV_MY_TESTING"]: @MyFunction def jit_func(self, x): - xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -205,21 +301,26 @@ def jit_func(self, x): def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v = self.jit_func(x) - rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA( + B, T, self.args.dim_att, self.time_decay, self.time_first, k, v + ) return self.output(rwkv) - if 'a' in os.environ["RWKV_MY_TESTING"]: + if "a" in os.environ["RWKV_MY_TESTING"]: + @MyFunction def QKV(self, q, k, v): att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.att_mask == 0, float('-inf')) - att = F.softmax(att, dim = -1) + att = att.masked_fill(self.att_mask == 0, float("-inf")) + att = F.softmax(att, dim=-1) x = att @ v return x @MyFunction def jit_funcQKV(self, x): - xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -238,12 +339,16 @@ def jit_funcQKV(self, x): def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v, qq, kk, vv = self.jit_funcQKV(x) - rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA( + B, T, self.args.dim_att, self.time_decay, self.time_first, k, v + ) rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv)) return rwkv + ######################################################################################################## + class RWKV_ChannelMix(MyModule): def __init__(self, args, layer_id): 
super().__init__() @@ -258,7 +363,7 @@ def __init__(self, args, layer_id): ddd[0, 0, i] = i / args.n_embd self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - + self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) @@ -273,6 +378,7 @@ def forward(self, x): kv = self.value(k) return torch.sigmoid(self.receptance(xr)) * kv + class MishGLU(MyModule): def __init__(self, args, layer_id): super().__init__() @@ -302,6 +408,7 @@ def forward(self, x): b = self.bb(xb) return self.value(a * F.mish(b)) + ######################################################################################################## # The RWKV Model with our blocks ######################################################################################################## @@ -319,25 +426,31 @@ def __init__(self, args, layer_id): if self.layer_id == 0: self.ln0 = nn.LayerNorm(args.n_embd) if args.my_pos_emb > 0: - self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd))) - self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd))) + self.pos_emb_x = nn.Parameter( + torch.zeros((1, args.my_pos_emb, args.n_embd)) + ) + self.pos_emb_y = nn.Parameter( + torch.zeros((args.my_pos_emb, 1, args.n_embd)) + ) if self.layer_id == 0 and self.args.pre_ffn > 0: self.ffnPre = RWKV_ChannelMix(args, 0) else: self.att = RWKV_TimeMix(args, layer_id) - if 'g' in os.environ["RWKV_MY_TESTING"]: + if "g" in os.environ["RWKV_MY_TESTING"]: self.ffn = MishGLU(args, layer_id) else: self.ffn = RWKV_ChannelMix(args, layer_id) - + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: self.tiny_ln = nn.LayerNorm(args.n_embd) self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False) - self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + self.register_buffer( + "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) def forward(self, x, x_emb=None): args = self.args @@ -345,7 +458,7 @@ def forward(self, x, x_emb=None): if self.layer_id == 0: x = self.ln0(x) if args.my_pos_emb > 0: - pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :] x = x + pos_emb if self.layer_id == 0 and args.pre_ffn > 0: @@ -385,13 +498,13 @@ class RWKV(pl.LightningModule): def __init__(self, args): super().__init__() self.args = args - if not hasattr(args, 'dim_att'): + if not hasattr(args, "dim_att"): args.dim_att = args.n_embd - if not hasattr(args, 'dim_ffn'): + if not hasattr(args, "dim_ffn"): args.dim_ffn = args.n_embd * 4 - if not hasattr(args, 'tiny_att_layer'): + if not hasattr(args, "tiny_att_layer"): args.tiny_att_layer = -1 - if not hasattr(args, 'tiny_att_dim'): + if not hasattr(args, "tiny_att_dim"): args.tiny_att_dim = -1 self.emb = nn.Embedding(args.vocab_size, args.n_embd) @@ -404,7 +517,9 @@ def __init__(self, args): if args.head_qk > 0: self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False) self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False) - self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) def 
configure_optimizers(self): args = self.args @@ -436,24 +551,69 @@ def configure_optimizers(self): param_dict = {n: p for n, p in self.named_parameters()} if args.my_pile_stage == 2: optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 2e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 3e-3 / args.lr_init}, ] else: optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 2.0, + }, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 3.0, + }, ] else: optim_groups = [ - {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + { + "params": [p for n, p in self.named_parameters()], + "weight_decay": 0.0, + }, ] if self.deepspeed_offload: - return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) - return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + return DeepSpeedCPUAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adamw_mode=False, + weight_decay=0, + amsgrad=False, + ) + return FusedAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adam_w_mode=False, + weight_decay=0, + amsgrad=False, + ) # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) @property @@ -521,10 +681,14 @@ def training_step(self, batch, batch_idx): logits = self(idx) if sum_mask == mask.shape[0]: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1) + ) # print('rank', self.global_rank, 'loss', loss.item()) else: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none" + ) # loss_raw = loss loss = torch.sum(loss * mask) / sum_mask @@ -564,7 +728,14 @@ def generate_init_weight(self): gain = 1.0 scale = 1.0 - if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n: + if ( + "ln_" in n + or ".ln" in n + or "time_" in n + or "_mask" in n + or "pos_emb" in n + or ".mask." 
in n + ): m[n] = p else: if n == "emb.weight": @@ -572,7 +743,19 @@ def generate_init_weight(self): else: if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']: + for kk in [ + ".att.key.", + ".att.receptance.", + ".att.output.", + ".att.key.", + ".ffn.value.", + ".ffn.receptance.", + ".ffnPre.value.", + ".ffnPre.receptance.", + "head_q.", + ".oo.", + ".rr.", + ]: if kk in n: scale = 0 if n == "head.weight": @@ -582,7 +765,9 @@ def generate_init_weight(self): if "head_q." in n: scale = 0 - print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}") + print( + f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}" + ) if self.args.accelerator.upper() == "GPU": m[n] = torch.empty((shape[0], shape[1]), device="cuda") diff --git a/RWKV-v4neo/src/model_img.py b/RWKV-v4neo/src/model_img.py index 24337236b..3a9bceb4e 100644 --- a/RWKV-v4neo/src/model_img.py +++ b/RWKV-v4neo/src/model_img.py @@ -13,10 +13,14 @@ from pytorch_lightning.strategies import DeepSpeedStrategy import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam + # from pytorch_msssim import MS_SSIM + def __nop(ob): return ob + + MyModule = torch.jit.ScriptModule # MyFunction = __nop MyFunction = torch.jit.script_method @@ -24,6 +28,7 @@ def __nop(ob): import clip from transformers import CLIPModel + class L2pooling(nn.Module): def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0): super(L2pooling, self).__init__() @@ -149,55 +154,57 @@ def forward(self, x, y, require_grad=False, batch_average=False): class ToBinary(torch.autograd.Function): @staticmethod - def forward(ctx, x):#, noise_scale): + def forward(ctx, x): # , noise_scale): # if noise_scale > 0: # noise_min = 0.5 - noise_scale / 2 # noise_max = 0.5 + noise_scale / 2 # return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max)) # else: - return torch.floor(x + 0.5) # no need for noise when we have plenty of data + return torch.floor(x + 0.5) # no need for noise when we have plenty of data @staticmethod def backward(ctx, grad_output): - return grad_output.clone()#, None + return grad_output.clone() # , None + ######################################################################################################## + class R_ENCODER(MyModule): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.Bxx = nn.BatchNorm2d(dd*64) + self.Bxx = nn.BatchNorm2d(dd * 64) self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) - self.B00 = nn.BatchNorm2d(dd*4) - self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - - self.B20 = nn.BatchNorm2d(dd*64) - self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, 
padding=1) - self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) + self.B00 = nn.BatchNorm2d(dd * 4) + self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + + self.B20 = nn.BatchNorm2d(dd * 64) + self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) # self.B21 = nn.BatchNorm2d(dd*64) # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1) + self.COUT = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1) @MyFunction def forward(self, img): @@ -224,37 +231,39 @@ def forward(self, img): x = self.COUT(x + xx) return torch.sigmoid(x) + ######################################################################################################## + class R_DECODER(MyModule): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1) + self.CIN = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1) - self.B00 = nn.BatchNorm2d(dd*64) - self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) + self.B00 = nn.BatchNorm2d(dd * 64) + self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) # self.B01 = nn.BatchNorm2d(dd*64) # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) - self.B20 = nn.BatchNorm2d(dd*4) - self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) + self.B20 = 
nn.BatchNorm2d(dd * 4) + self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) @@ -281,47 +290,52 @@ def forward(self, code): x = x + self.Cx1(ACT(self.Cx0(x))) x = self.COUT(x) - + return torch.sigmoid(x) + ######################################################################################################## + def cosine_loss(x, y): x = F.normalize(x, dim=-1) y = F.normalize(y, dim=-1) - return 1 - torch.einsum('ij,ij->i',[x,y]) + return 1 - torch.einsum("ij,ij->i", [x, y]) + class RWKV_IMG(pl.LightningModule): def __init__(self, args): super().__init__() self.args = args - + self.encoder = R_ENCODER(args) self.decoder = R_DECODER(args) self.clip_model = None clip_name = args.my_img_clip - if clip_name == 'B32': - clip_name = 'ViT-B/32' - elif clip_name == 'B16': - clip_name = 'ViT-B/16' - elif clip_name == 'L14': - clip_name = 'ViT-L/14' - elif clip_name == 'OB32': + if clip_name == "B32": + clip_name = "ViT-B/32" + elif clip_name == "B16": + clip_name = "ViT-B/16" + elif clip_name == "L14": + clip_name = "ViT-L/14" + elif clip_name == "OB32": clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" self.clip_model = CLIPModel.from_pretrained(clip_name) self.clip_model.encode_image = self.clip_model.get_image_features if self.clip_model == None: - self.clip_model, _ = clip.load(clip_name, jit = True) + self.clip_model, _ = clip.load(clip_name, jit=True) self.register_buffer( - "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1) + "clip_mean", + torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1), ) self.register_buffer( - "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1) + "clip_std", + torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1), ) for n, p in self.named_parameters(): - if 'clip_model' in n: + if "clip_model" in n: p.requires_grad = False self.loss_dists = DISTS() @@ -365,7 +379,7 @@ def deepspeed_offload(self) -> bool: def forward(self, img): z = self.encoder(img) - z = ToBinary.apply(z)#, self.args.my_img_noise_scale) + z = ToBinary.apply(z) # , self.args.my_img_noise_scale) out = self.decoder(z) return out @@ -379,10 +393,12 @@ def training_step(self, batch, batch_idx): if not os.path.exists(img_dir): os.makedirs(img_dir) vision.utils.save_image( - img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0 + img[:4], + f"{img_dir}/{self.trainer.global_step}-src.jpg", # , padding=0 ) vision.utils.save_image( - out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0 + out[:4], + f"{img_dir}/{self.trainer.global_step}-out.jpg", # , padding=0 ) # loss_ssim = 1 - self.loss_ssim(out, img) @@ -394,7 +410,11 @@ def training_step(self, batch, batch_idx): if args.my_img_l1_scale > 0: loss_l1 = F.l1_loss(out, img) - return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale + return ( + loss_dists + + loss_clip * args.my_img_clip_scale + + loss_l1 * args.my_img_l1_scale + ) else: return loss_dists + loss_clip * args.my_img_clip_scale @@ -418,7 +438,7 @@ def generate_init_weight(self): scale = 1 p = self.state_dict()[n] shape = p.shape - ss = n.split('.') + ss = n.split(".") # if ss[0] in ['encoder', 'decoder']: # if ss[2] == 'bias':
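The ToBinary autograd Function reformatted above is a straight-through estimator: the encoder's sigmoid outputs are rounded to hard {0, 1} codes on the forward pass, while the backward pass copies the gradient through unchanged, as if the rounding were the identity. A minimal self-contained sketch of the pattern (the demo tensors are illustrative, not from the repo):

import torch

class ToBinary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # 0.5 threshold: equivalent to rounding for inputs in (0, 1)
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: gradient of the identity, not of floor()
        return grad_output.clone()

x = torch.randn(4, requires_grad=True)
bits = ToBinary.apply(torch.sigmoid(x))  # hard 0/1 bottleneck codes
bits.sum().backward()  # x.grad equals sigmoid'(x): the rounding is invisible to autograd

This is what lets RWKV_IMG train the encoder and decoder end-to-end through a discrete bit bottleneck.

diff --git 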
a/RWKV-v4neo/src/model_run.py b/RWKV-v4neo/src/model_run.py index 2516e508c..184a35cfa 100644 --- a/RWKV-v4neo/src/model_run.py +++ b/RWKV-v4neo/src/model_run.py @@ -10,8 +10,12 @@ from typing import List, Dict MyModule = nn.Module + + def __nop(ob): return ob + + MyFunction = __nop # # try torchdynamo @@ -24,14 +28,17 @@ def __nop(ob): MyFunction = torch.jit.script_method RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') +print( + f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n' +) -DEBUG_TIME = False # True False - show trained time-coeffs +DEBUG_TIME = False # True False - show trained time-coeffs -RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer +RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer ############################################################################################################ + class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() @@ -41,30 +48,32 @@ def __init__(self, args): self.RUN_DEVICE = args.RUN_DEVICE with torch.no_grad(): - w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') + w = torch.load(args.MODEL_NAME + ".pth", map_location="cpu") # refine weights and send to correct device keys = list(w.keys()) - if 'pos_emb_x' in keys: - w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:] + if "pos_emb_x" in keys: + w["pos_emb"] = (w["pos_emb_x"] + w["pos_emb_y"]).reshape( + args.ctx_len + 1, -1 + )[:-1, :] keys = list(w.keys()) print_need_newline = False for x in keys: block_id = 0 - if 'blocks.' in x: - block_id = int(x.split('.')[1]) - if 'att.output.weight' in x: + if "blocks." in x: + block_id = int(x.split(".")[1]) + if "att.output.weight" in x: w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) - if 'ffn.value.weight' in x: + if "ffn.value.weight" in x: w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) - - if '.time_' in x: + + if ".time_" in x: w[x] = w[x].squeeze() if DEBUG_TIME: print(x, w[x].numpy()) - if '.time_decay' in x: + if ".time_decay" in x: w[x] = w[x].float() w[x] = -torch.exp(w[x]) - elif '.time_first' in x: + elif ".time_first" in x: w[x] = w[x].float() else: if self.FLOAT_MODE == "fp32": @@ -75,23 +84,27 @@ def __init__(self, args): w[x] = w[x].half() w[x].requires_grad = False - if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': + if args.RUN_DEVICE == "cuda" and x != "emb.weight": w[x] = w[x].cuda() - if ('blocks.' not in x) or ('blocks.0.' in x): + if ("blocks." not in x) or ("blocks.0." 
in x): if print_need_newline: - print('\n', end = '') + print("\n", end="") print_need_newline = False - print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) + print( + x.ljust(40), + str(w[x].dtype).replace("torch.", "").ljust(10), + w[x].device, + ) else: print_need_newline = True - print('.', end = '', flush = True) + print(".", end="", flush=True) # store weights in self.w keys = list(w.keys()) self.w = types.SimpleNamespace() for x in keys: - xx = x.split('.') + xx = x.split(".") here = self.w for i in range(len(xx)): if xx[i].isdigit(): @@ -103,7 +116,7 @@ def __init__(self, args): if i == len(xx) - 1: setattr(here, xx[i], w[x]) elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): + if xx[i + 1].isdigit(): setattr(here, xx[i], {}) else: setattr(here, xx[i], types.SimpleNamespace()) @@ -119,19 +132,23 @@ def LN(self, x, w): # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp @MyFunction - def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): + def FF(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw): if self.FLOAT_MODE == "bf16": - xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) - state[5*i+0] = x.float() + xk = x * time_mix_k + state[5 * i + 0].type(torch.bfloat16) * ( + 1 - time_mix_k + ) + xr = x * time_mix_r + state[5 * i + 0].type(torch.bfloat16) * ( + 1 - time_mix_r + ) + state[5 * i + 0] = x.float() elif self.FLOAT_MODE == "fp16": - xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r) - state[5*i+0] = x.float() + xk = x * time_mix_k + state[5 * i + 0].half() * (1 - time_mix_k) + xr = x * time_mix_r + state[5 * i + 0].half() * (1 - time_mix_r) + state[5 * i + 0] = x.float() else: - xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r) - state[5*i+0] = x + xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) + xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) + state[5 * i + 0] = x r = torch.sigmoid(rw @ xr) k = torch.square(torch.relu(kw @ xk)) @@ -140,36 +157,56 @@ def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): return r * kv @MyFunction - def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): + def SA( + self, + x, + state, + i: int, + time_mix_k, + time_mix_v, + time_mix_r, + time_first, + time_decay, + kw, + vw, + rw, + ow, + ): if self.FLOAT_MODE == "bf16": - xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) - state[5*i+1] = x.float() + xk = x * time_mix_k + state[5 * i + 1].type(torch.bfloat16) * ( + 1 - time_mix_k + ) + xv = x * time_mix_v + state[5 * i + 1].type(torch.bfloat16) * ( + 1 - time_mix_v + ) + xr = x * time_mix_r + state[5 * i + 1].type(torch.bfloat16) * ( + 1 - time_mix_r + ) + state[5 * i + 1] = x.float() elif self.FLOAT_MODE == "fp16": - xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r) - state[5*i+1] = x.float() + xk = x * time_mix_k + state[5 * i + 1].half() * (1 - time_mix_k) + xv = x * time_mix_v + state[5 * i + 1].half() * (1 - time_mix_v) + xr = x * time_mix_r + state[5 * i + 
1].half() * (1 - time_mix_r) + state[5 * i + 1] = x.float() else: - xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r) - state[5*i+1] = x + xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + state[5 * i + 1] = x r = torch.sigmoid(rw @ xr) k = kw @ xk v = vw @ xv - if '16' in self.FLOAT_MODE: + if "16" in self.FLOAT_MODE: kk = k.float() vv = v.float() else: kk = k vv = v - aa = state[5*i+2] - bb = state[5*i+3] - pp = state[5*i+4] + aa = state[5 * i + 2] + bb = state[5 * i + 3] + pp = state[5 * i + 4] ww = time_first + kk p = torch.maximum(pp, ww) e1 = torch.exp(pp - p) @@ -180,52 +217,72 @@ def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, ti p = torch.maximum(ww, kk) e1 = torch.exp(ww - p) e2 = torch.exp(kk - p) - state[5*i+2] = e1 * aa + e2 * vv - state[5*i+3] = e1 * bb + e2 - state[5*i+4] = p + state[5 * i + 2] = e1 * aa + e2 * vv + state[5 * i + 3] = e1 * bb + e2 + state[5 * i + 4] = p if self.FLOAT_MODE == "bf16": wkv = (a / b).type(torch.bfloat16) elif self.FLOAT_MODE == "fp16": wkv = (a / b).half() else: wkv = a / b - + return ow @ (r * wkv) - def forward(self, ctx, state, preprocess_only = False): + def forward(self, ctx, state, preprocess_only=False): with torch.no_grad(): w = self.w args = self.args x = w.emb.weight[ctx[-1]] - if self.RUN_DEVICE == 'cuda': + if self.RUN_DEVICE == "cuda": x = x.cuda() try: - pos_emb = w.pos_emb[len(ctx)-1] + pos_emb = w.pos_emb[len(ctx) - 1] x = x + pos_emb except: - pass + pass if state == None: - state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE) + state = torch.zeros( + args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE + ) for i in range(args.n_layer): - state[5*i+4] -= 1e30 + state[5 * i + 4] -= 1e30 for i in range(args.n_layer): if i == 0: x = self.LN(x, w.blocks[i].ln0) - + ww = w.blocks[i].att - x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i, - ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, - ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) - + x = x + self.SA( + self.LN(x, w.blocks[i].ln1), + state, + i, + ww.time_mix_k, + ww.time_mix_v, + ww.time_mix_r, + ww.time_first, + ww.time_decay, + ww.key.weight, + ww.value.weight, + ww.receptance.weight, + ww.output.weight, + ) + ww = w.blocks[i].ffn - x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i, - ww.time_mix_k, ww.time_mix_r, - ww.key.weight, ww.value.weight, ww.receptance.weight) - - if (i+1) % RWKV_RESCALE_LAYER == 0: + x = x + self.FF( + self.LN(x, w.blocks[i].ln2), + state, + i, + ww.time_mix_k, + ww.time_mix_r, + ww.key.weight, + ww.value.weight, + ww.receptance.weight, + ) + + if (i + 1) % RWKV_RESCALE_LAYER == 0: x = x / 2 if preprocess_only: diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index d5cf45212..a05ff3dc4 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -4,15 +4,17 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + def my_save(dd, ff): - if '14b-run1' not in ff: + if "14b-run1" not in ff: torch.save(dd, ff) else: - fn = ff.split('/')[-1] - fff = '/dev/shm/' + fn + fn = ff.split("/")[-1] + fff = "/dev/shm/" + fn torch.save(dd, fff) subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) + class 
train_callback(pl.Callback): def __init__(self, args): super().__init__() @@ -37,7 +39,9 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if args.lr_final == 0 or args.lr_init == 0: # linear decay lr = args.lr_init + (args.lr_final - args.lr_init) * progress else: # exp decay - lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) + lr = args.lr_init * math.exp( + math.log(args.lr_final / args.lr_init) * pow(progress, 1) + ) if trainer.global_step < w_step: lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) @@ -59,7 +63,9 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): trainer.my_loss_sum = 0 trainer.my_loss_count = 0 trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") - trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") + trainer.my_log.write( + f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n" + ) try: print(f"\n{trainer.strategy.config}\n") trainer.my_log.write(f"{trainer.strategy.config}\n") @@ -69,6 +75,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if len(args.wandb) > 0: print("Login to wandb...") import wandb + wandb.init( project=args.wandb, name=args.run_name + " " + args.my_timestamp, @@ -101,19 +108,25 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # self.log("s", real_step, prog_bar=True, on_step=True) if len(args.wandb) > 0: - lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9} + lll = { + "loss": trainer.my_loss, + "lr": trainer.my_lr, + "Gtokens": real_step * token_per_step / 1e9, + } if kt_s > 0: lll["kt/s"] = kt_s trainer.my_wandb.log(lll, step=int(real_step)) if args.magic_prime > 0: expand_factor = 2 if args.my_qa_mask > 0 else 1 - if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1: + if ( + int(real_step) + == int(args.magic_prime * expand_factor // args.real_bsz) - 1 + ): to_save_dict = pl_module.state_dict() my_save( to_save_dict, f"{args.proj_dir}/rwkv-final.pth", ) - def on_train_epoch_start(self, trainer, pl_module): args = self.args @@ -127,12 +140,14 @@ def on_train_epoch_start(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module): args = self.args if trainer.is_global_zero: # logging & save state_dict - if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: - if args.data_type == 'wds_img': + if ( + args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0 + ) or trainer.current_epoch == args.epoch_count - 1: + if args.data_type == "wds_img": raw_dict = pl_module.state_dict() to_save_dict = {} for k in raw_dict: - if k.startswith('encoder.') or k.startswith('decoder.'): + if k.startswith("encoder.") or k.startswith("decoder."): to_save_dict[k] = raw_dict[k] else: to_save_dict = pl_module.state_dict() @@ -142,8 +157,10 @@ def on_train_epoch_end(self, trainer, pl_module): f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", ) except Exception as e: - print('Error\n\n', e, '\n\n') - trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") + print("Error\n\n", e, "\n\n") + trainer.my_log.write( + f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} 
{trainer.current_epoch}\n" + ) trainer.my_log.flush() trainer.my_loss_sum = 0 @@ -165,22 +182,22 @@ def generate_init_weight(model, init_weight_name): mm[k] = src.reshape(mm[k].shape) except: tmp = mm[k].squeeze().clone() - print(k, src.shape, '-->', mm[k].shape) + print(k, src.shape, "-->", mm[k].shape) ss = src.shape[0] dd = tmp.shape[0] for i in range(dd): pos = i / dd * ss if pos >= ss - 1: - tmp[i] = src[ss-1] + tmp[i] = src[ss - 1] else: p0 = int(math.floor(pos)) ii = pos - p0 - tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii) + tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii) mm[k] = tmp.reshape(mm[k].shape) sss = src.squeeze().float().cpu().numpy() - print(sss[:10], '...', sss[-10:]) + print(sss[:10], "...", sss[-10:]) mmm = mm[k].squeeze().float().cpu().numpy() - print(mmm[:10], '...', mmm[-10:]) + print(mmm[:10], "...", mmm[-10:]) print(f"Save to {init_weight_name}...") torch.save(mm, init_weight_name) diff --git a/RWKV-v4neo/src/utils.py b/RWKV-v4neo/src/utils.py index ea25990b4..87da098db 100644 --- a/RWKV-v4neo/src/utils.py +++ b/RWKV-v4neo/src/utils.py @@ -6,6 +6,7 @@ time_slot = {} time_ref = time.time_ns() + def record_time(name): if name not in time_slot: time_slot[name] = 1e20 @@ -13,20 +14,23 @@ def record_time(name): if tt < time_slot[name]: time_slot[name] = tt -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - if 'list' in str(type(WORD_NAME)): + +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + if "list" in str(type(WORD_NAME)): self.charMode = False if WORD_NAME[0] == WORD_NAME[1]: from transformers import PreTrainedTokenizerFast + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) else: from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) self.vocab_size = len(self.tokenizer) else: self.charMode = True - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: self.word_table = json.load(result_file) self.vocab_size = len(self.word_table) @@ -37,23 +41,25 @@ def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] def refine_context(self, context): - context = context.strip().split('\n') + context = context.strip().split("\n") for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): # out[self.UNKNOWN_CHAR] = -float('Inf') lastChar = int(x[-1]) probs = F.softmax(out, dim=-1) if self.charMode: - if self.itos[lastChar] == '\n': + if self.itos[lastChar] == "\n": top_p = top_p_newline else: top_p = top_p_usual @@ -81,6 +87,7 @@ def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_ out = torch.multinomial(probs, num_samples=1)[0] return out + def MaybeIsPrime(number): if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): return True @@ -121,7 +128,9 @@ def MillerRabinPrimalityTest(number): if 
(randomNumberWithPower != 1) and (randomNumberWithPower != number - 1): iterationNumber = 1 - while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1): + while (iterationNumber <= timesTwoDividNumber - 1) and ( + randomNumberWithPower != number - 1 + ): randomNumberWithPower = pow(randomNumberWithPower, 2, number) iterationNumber = iterationNumber + 1 if randomNumberWithPower != (number - 1): diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 03eda09ad..82ad298bd 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -50,53 +50,87 @@ parser = ArgumentParser() parser.add_argument("--load_model", default="", type=str) # full path, with .pth - parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb + parser.add_argument( + "--wandb", default="", type=str + ) # wandb project name. if "" then don't use wandb parser.add_argument("--proj_dir", default="out", type=str) parser.add_argument("--random_seed", default="-1", type=int) parser.add_argument("--data_file", default="", type=str) parser.add_argument("--data_type", default="utf-8", type=str) - parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + parser.add_argument( + "--vocab_size", default=0, type=int + ) # vocab_size = 0 means auto (for char-level LM and .txt data) parser.add_argument("--ctx_len", default=1024, type=int) - parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps - parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final - parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x - parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" - - parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument( + "--epoch_steps", default=1000, type=int + ) # a mini "epoch" has [epoch_steps] steps + parser.add_argument( + "--epoch_count", default=500, type=int + ) # train for this many "epochs". 
will continue afterwards with lr = lr_final + parser.add_argument( + "--epoch_begin", default=0, type=int + ) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument( + "--epoch_save", default=5, type=int + ) # save the model every [epoch_save] "epochs" + + parser.add_argument( + "--micro_bsz", default=12, type=int + ) # micro batch size (batch size per GPU) parser.add_argument("--n_layer", default=6, type=int) parser.add_argument("--n_embd", default=512, type=int) parser.add_argument("--dim_att", default=0, type=int) parser.add_argument("--dim_ffn", default=0, type=int) - parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument( + "--pre_ffn", default=0, type=int + ) # replace first att layer by ffn (sometimes better) parser.add_argument("--head_qk", default=0, type=int) # my headQK trick parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim - parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + parser.add_argument( + "--tiny_att_layer", default=-999, type=int + ) # tiny attention @ which layer - parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument( + "--lr_init", default=6e-4, type=float + ) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument("--lr_final", default=1e-5, type=float) - parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument( + "--warmup_steps", default=0, type=int + ) # try 50 if you load a model parser.add_argument("--beta1", default=0.9, type=float) - parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument( + "--beta2", default=0.99, type=float + ) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) - parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument( + "--grad_cp", default=0, type=int + ) # gradient checkpt: saves VRAM, but slower - parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version + parser.add_argument( + "--my_pile_version", default=1, type=int + ) # my special pile version parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode - parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument( + "--my_pile_shift", default=-1, type=int + ) # my special pile mode - text shift parser.add_argument("--my_pile_edecay", default=0, type=int) - parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) - parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough + parser.add_argument( + "--layerwise_lr", default=1, type=int + ) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument( + "--ds_bucket_mb", default=200, type=int + ) # deepspeed bucket size in MB. 
200 seems enough # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) parser.add_argument("--my_img_version", default=0, type=str) parser.add_argument("--my_img_size", default=0, type=int) parser.add_argument("--my_img_bit", default=0, type=int) - parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip", default="x", type=str) parser.add_argument("--my_img_clip_scale", default=1, type=float) parser.add_argument("--my_img_l1_scale", default=0, type=float) - parser.add_argument("--my_img_encoder", default='x', type=str) + parser.add_argument("--my_img_encoder", default="x", type=str) # parser.add_argument("--my_img_noise_scale", default=0, type=float) parser.add_argument("--my_sample_len", default=0, type=int) parser.add_argument("--my_ffn_shift", default=1, type=int) @@ -105,7 +139,7 @@ parser.add_argument("--load_partial", default=0, type=int) parser.add_argument("--magic_prime", default=0, type=int) parser.add_argument("--my_qa_mask", default=0, type=int) - parser.add_argument("--my_testing", default='', type=str) + parser.add_argument("--my_testing", default="", type=str) parser = Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -116,18 +150,26 @@ import numpy as np import torch from torch.utils.data import DataLoader + if "deepspeed" in args.strategy: import deepspeed import pytorch_lightning as pl from pytorch_lightning import seed_everything if args.random_seed >= 0: - print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + print( + f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" + * 3 + ) seed_everything(args.random_seed) np.set_printoptions(precision=4, suppress=True, linewidth=200) - warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") - warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + warnings.filterwarnings( + "ignore", ".*Consider increasing the value of the `num_workers` argument*" + ) + warnings.filterwarnings( + "ignore", ".*The progress bar already tracks a metric with the*" + ) # os.environ["WDS_SHOW_SEED"] = "1" args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") @@ -152,7 +194,9 @@ args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" args.proj_dir = f"{args.proj_dir}-{args.run_name}" else: - args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + args.run_name = ( + f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + ) if not os.path.exists(args.proj_dir): os.makedirs(args.proj_dir) @@ -240,18 +284,32 @@ ) rank_zero_info(str(vars(args)) + "\n") - assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + assert args.data_type in [ + "utf-8", + "utf-16le", + "numpy", + "binidx", + "dummy", + "wds_img", + "uint16", + ] if args.lr_final == 0 or args.lr_init == 0: - rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + rank_zero_info( + "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n" + ) assert args.precision in ["fp32", "tf32", "fp16", "bf16"] os.environ["RWKV_FLOAT_MODE"] = args.precision if args.precision == "fp32": for i in range(10): - rank_zero_info("\n\nNote: you are using fp32 (very slow). 
Try bf16 / tf32 for faster training.\n\n") + rank_zero_info( + "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n" + ) if args.precision == "fp16": - rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") + rank_zero_info( + "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n" + ) os.environ["RWKV_JIT_ON"] = "1" if "deepspeed_stage_3" in args.strategy: @@ -281,14 +339,18 @@ train_data = MyDataset(args) args.vocab_size = train_data.vocab_size - if args.data_type == 'wds_img': + if args.data_type == "wds_img": from src.model_img import RWKV_IMG + model = RWKV_IMG(args) else: from src.model import RWKV + model = RWKV(args) - if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? + if ( + len(args.load_model) == 0 or args.my_pile_stage == 1 + ): # shall we build the initial weights? init_weight_name = f"{args.proj_dir}/rwkv-init.pth" generate_init_weight(model, init_weight_name) # save initial weights args.load_model = init_weight_name @@ -330,10 +392,22 @@ print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) # must set shuffle=False, persistent_workers=False (because worker is in another thread) - data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + data_loader = DataLoader( + train_data, + shuffle=False, + pin_memory=True, + batch_size=args.micro_bsz, + num_workers=1, + persistent_workers=False, + drop_last=True, + ) trainer.fit(model, data_loader) diff --git a/RWKV-v4neo/verify.py b/RWKV-v4neo/verify.py index 4f56e392f..695e651f2 100644 --- a/RWKV-v4neo/verify.py +++ b/RWKV-v4neo/verify.py @@ -7,6 +7,7 @@ import os, sys, types import numpy as np import torch + np.set_printoptions(precision=4, suppress=True, linewidth=200) try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] @@ -16,23 +17,24 @@ torch.backends.cudnn.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False -os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32 -os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA -RUN_DEVICE = os.environ['RWKV_RUN_DEVICE'] +os.environ["RWKV_FLOAT_MODE"] = "bf16" # bf16 or fp32 +os.environ["RWKV_RUN_DEVICE"] = "cuda" # currently model_train requires CUDA +RUN_DEVICE = os.environ["RWKV_RUN_DEVICE"] -TOKEN_MODE = 'pile' +TOKEN_MODE = "pile" -if TOKEN_MODE == 'pile': - WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json'] - MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783' +if TOKEN_MODE == "pile": + WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"] + MODEL_NAME = "/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783" n_layer = 32 n_embd = 2560 ctx_len = 1024 UNKNOWN_CHAR = None from src.utils import TOKENIZER + tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) -if TOKEN_MODE == 'pile': +if TOKEN_MODE == "pile": tokenizer.vocab_size = 50277 
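For context on the check that follows: RWKV_RNN (model_run.py above) carries five state vectors per layer (the comment in the diff labels them ffn_xx, att_xx, att_aa, att_bb, att_pp), and the last three implement the WKV recurrence, with pp acting as a running exponent so that torch.exp never overflows. Stripped of the bf16/fp16 branches, the per-token update extracted from SA is:

import torch

def wkv_step(k, v, aa, bb, pp, time_first, time_decay):
    # Output: a weighted average of past values, computed with max-shifted
    # exponentials so every exp() argument is <= 0.
    ww = time_first + k
    p = torch.maximum(pp, ww)
    e1 = torch.exp(pp - p)
    e2 = torch.exp(ww - p)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)
    # State update: decay the running numerator/denominator (time_decay was
    # rewritten as -exp(...) at load time, so it is strictly negative) and
    # fold in the current token.
    ww = pp + time_decay
    p = torch.maximum(ww, k)
    e1 = torch.exp(ww - p)
    e2 = torch.exp(k - p)
    return wkv, e1 * aa + e2 * v, e1 * bb + e2, p

verify.py below feeds the same context token by token through this recurrence and checks that the outputs match the parallel GPT-style forward of src.model.RWKV.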
######################################################################################################## @@ -54,23 +56,23 @@ args.my_pos_emb = 0 model_train = RWKV(args).to(RUN_DEVICE) -if os.environ['RWKV_FLOAT_MODE'] == 'fp16': +if os.environ["RWKV_FLOAT_MODE"] == "fp16": model_train = model_train.half() -elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': +elif os.environ["RWKV_FLOAT_MODE"] == "bf16": model_train = model_train.bfloat16() -print('loading ' + MODEL_NAME) -m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu') +print("loading " + MODEL_NAME) +m2 = torch.load(MODEL_NAME + ".pth", map_location="cpu") model_train.load_state_dict(m2) -if os.environ['RWKV_FLOAT_MODE'] == 'fp16': +if os.environ["RWKV_FLOAT_MODE"] == "fp16": model_train = model_train.half() -elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': +elif os.environ["RWKV_FLOAT_MODE"] == "bf16": model_train = model_train.bfloat16() args.MODEL_NAME = MODEL_NAME args.RUN_DEVICE = RUN_DEVICE -args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE'] +args.FLOAT_MODE = os.environ["RWKV_FLOAT_MODE"] model_rnn = RWKV_RNN(args) ######################################################################################################## @@ -78,27 +80,33 @@ print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}") # context = '\nIn a' -context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.' +context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." -if TOKEN_MODE == 'pile': +if TOKEN_MODE == "pile": ctx = tokenizer.tokenizer.encode(context) -print(f'input len {len(ctx)} data {ctx}') +print(f"input len {len(ctx)} data {ctx}") ######################################################################################################## with torch.no_grad(): - print('\nRWKV-train output') - out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy() - print(out, '\n') - - print('\nRWKV-RNN output') + print("\nRWKV-train output") + out = ( + model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0] + .detach() + .cpu() + .float() + .numpy() + ) + print(out, "\n") + + print("\nRWKV-RNN output") state = None out = None src_len = len(ctx) for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] out, state = model_rnn.forward(x, state) if i < 3 or i >= src_len - 3: print(out.detach().cpu().numpy()) if i == 2: - print('...') + print("...")
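One detail of trainer.py's schedule above that is easy to miss: when both lr_init and lr_final are nonzero, the "exp decay" branch is a geometric interpolation between the two rates. Written as a standalone helper (default values taken from train.py's argument defaults):

import math

def lr_at(progress, lr_init=6e-4, lr_final=1e-5):
    # identical to lr_init * (lr_final / lr_init) ** progress
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)

print(lr_at(0.0))  # 6e-4
print(lr_at(0.5))  # sqrt(6e-4 * 1e-5), the geometric mean, about 7.75e-5
print(lr_at(1.0))  # 1e-5

On top of this, the first warmup_steps steps scale the result by 0.2 + 0.8 * step / w_step, ramping from 20% of the scheduled rate up to 100%.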
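utils.py's TOKENIZER.sample_logits (partially shown above) ends in torch.multinomial over truncated probabilities, with a separate top_p threshold for the token after a newline in char mode. The truncation itself falls between the hunks shown, so the following is a sketch of the standard nucleus-sampling recipe it implements, not a verbatim excerpt; the cutoff rule here is an assumption:

import torch
import torch.nn.functional as F

def sample_top_p(logits, temperature=1.0, top_p=0.85):
    probs = F.softmax(logits.float(), dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # smallest probability still inside the top-p nucleus (assumed cutoff rule)
    cutoff = sorted_probs[int((cum > top_p).float().argmax())]
    probs[probs < cutoff] = 0  # drop the long tail
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)  # sharpen or flatten before sampling
    return int(torch.multinomial(probs, num_samples=1)[0])

torch.multinomial renormalizes its input internally, so zeroing out the tail is sufficient.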