diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 837d762df..236843eb4 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -172,7 +172,12 @@ def __init__(self, args, layer_id): self.time_first = nn.Parameter(torch.ones(self.n_head) * math.log(0.3)) - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + shiftexponent = layer_id + if(shiftexponent >= 12): + shiftexponent = 0 + shiftamount = 2**shiftexponent + self.time_shift = nn.ZeroPad2d((0, 0, shiftamount, -shiftamount)) + self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) self.key = nn.Linear(args.n_embd, args.dim_att, bias=False) self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)