diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py
index c2f7c3f38..86a211def 100644
--- a/RWKV-v5/src/model.py
+++ b/RWKV-v5/src/model.py
@@ -855,12 +855,12 @@ def forward(self, x, v_first):
         H = self.n_head
         xx = self.time_shift(x) - x
 
-        xr = x + xx * self.x_r
-        xw = x + xx * self.x_w
-        xk = x + xx * self.x_k
-        xv = x + xx * self.x_v
-        xa = x + xx * self.x_a
-        xg = x + xx * self.x_g
+        xr = torch.addcmul(x, xx, self.x_r)
+        xw = torch.addcmul(x, xx, self.x_w)
+        xk = torch.addcmul(x, xx, self.x_k)
+        xv = torch.addcmul(x, xx, self.x_v)
+        xa = torch.addcmul(x, xx, self.x_a)
+        xg = torch.addcmul(x, xx, self.x_g)
 
         r = self.receptance(xr)
         w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
@@ -869,13 +869,13 @@ def forward(self, x, v_first):
         if self.layer_id == 0:
             v_first = v # store the v of the first layer
         else:
-            v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
+            v = torch.lerp(v, v_first, torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)) # add value residual
         a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
         g = torch.sigmoid(xg @ self.g1) @ self.g2
 
         kk = k * self.k_k
         kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
-        k = k * (1 + (a-1) * self.k_a)
+        k = k.addcmul(k * (a - 1), self.k_a)
 
         x = RUN_CUDA_RWKV7g(r, w, k, v, -kk, kk*a)
         x = self.ln_x(x.view(B * T, C)).view(B, T, C)
@@ -970,7 +970,7 @@ def __init__(self, args, layer_id):
 
     def forward(self, x):
         xx = self.time_shift(x) - x
-        k = x + xx * self.x_k
+        k = torch.addcmul(x, xx, self.x_k)
         k = torch.relu(self.key(k)) ** 2
         return self.value(k)
 
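Each fused call above is algebraically identical to the expression it replaces: `torch.addcmul(x, xx, w)` computes `x + xx * w`, `torch.lerp(v, v_first, t)` computes `v + t * (v_first - v)`, and `k.addcmul(k * (a - 1), self.k_a)` computes `k + k * (a - 1) * self.k_a`, which equals `k * (1 + (a - 1) * self.k_a)`. Below is a minimal standalone sketch verifying these equivalences; the shapes, seed, and variable names are illustrative stand-ins, not the model's real sizes or parameters:

```python
import torch

torch.manual_seed(0)

# Illustrative (B, T, C) activations and per-channel parameters,
# standing in for x / xx and self.x_r, self.k_a, etc.
x, xx = torch.randn(2, 3, 8), torch.randn(2, 3, 8)
x_r = torch.randn(8)

# x + xx * w  ->  torch.addcmul(x, xx, w): fused multiply-add
assert torch.allclose(x + xx * x_r, torch.addcmul(x, xx, x_r), atol=1e-6)

# v + (v_first - v) * t  ->  torch.lerp(v, v_first, t): linear interpolation
v, v_first = torch.randn(2, 3, 8), torch.randn(2, 3, 8)
t = torch.sigmoid(torch.randn(2, 3, 8))
assert torch.allclose(v + (v_first - v) * t, torch.lerp(v, v_first, t), atol=1e-6)

# k * (1 + (a - 1) * k_a)  ->  k.addcmul(k * (a - 1), k_a)
k, a, k_a = torch.randn(2, 3, 8), torch.sigmoid(torch.randn(2, 3, 8)), torch.randn(8)
assert torch.allclose(k * (1 + (a - 1) * k_a), k.addcmul(k * (a - 1), k_a), atol=1e-6)

print("all fused forms match the originals")
```

The fused ops compute the add and multiply in one kernel instead of materializing the intermediate elementwise product, which should trim a little allocation and memory traffic in the forward pass without changing the math.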