diff --git a/RWKV-v7/rwkv_v8_rc00_demo.py b/RWKV-v7/rwkv_v8_rc00_demo.py
index 186426eb4..94fb2b660 100644
--- a/RWKV-v7/rwkv_v8_rc00_demo.py
+++ b/RWKV-v7/rwkv_v8_rc00_demo.py
@@ -106,6 +106,7 @@ def __init__(self, args):
         z['blocks.0.att.v0'] = z['blocks.0.att.a0'] # actually ignored
         z['blocks.0.att.v1'] = z['blocks.0.att.a1'] # actually ignored
         z['blocks.0.att.v2'] = z['blocks.0.att.a2'] # actually ignored
+        self.deepembs = {} # keys of the ffn enn.weight tensors that have already been offloaded to disk
 
     def forward(self, idx, state, full_output=False):
         if state == None:
@@ -178,7 +179,20 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=
 
                 xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
 
-                xx, state[i*3+2] = RWKV_x080_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'enn.weight'][idx])
+                # Disk space offloading: on first use, dump this layer's enn.weight to a raw binary
+                # file and replace the in-memory tensor with a file-backed CPU tensor.
+                enn_key = ffn+'enn.weight'
+                if enn_key not in self.deepembs:
+                    enn_file = enn_key+'_storage'+'.pt' # raw element dump, NOT a torch.save() checkpoint
+                    emb = z[enn_key].detach().cpu().contiguous() # .numpy() only works on CPU tensors
+                    emb.numpy().tofile(enn_file) # NOTE(review): numpy has no bfloat16 -- confirm the model dtype
+                    # torch.from_file returns a flat 1-D tensor of `size` elements; restore the original
+                    # shape so that indexing with idx gathers whole rows instead of single elements.
+                    z[enn_key] = torch.from_file(enn_file, size=emb.numel(), dtype=emb.dtype, device='cpu').view(emb.shape)
+                    self.deepembs[enn_key] = True
+
+                # Gather the needed rows on CPU, then move only those rows to the device of the activations.
+                xx, state[i*3+2] = RWKV_x080_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'], z[enn_key][idx].to(xx.device))
                 x = x + xx
 
             if not full_output: x = x[-1,:]