From 65cf242dc9d62e4df2c2e91da7655ad2a7aba23a Mon Sep 17 00:00:00 2001 From: Harrison Date: Fri, 18 Aug 2023 20:42:06 +1000 Subject: [PATCH] add resx --- RWKV-v4neo/src/model.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index 611a5e875..14307c452 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -632,18 +632,30 @@ def forward(self, idx): if args.dropout > 0: x = self.drop0(x) - if args.tiny_att_dim > 0: - for block in self.blocks: - if args.grad_cp == 1: - x = deepspeed.checkpointing.checkpoint(block, x, x_emb) - else: - x = block(x, x_emb) - else: - for block in self.blocks: - if args.grad_cp == 1: - x = deepspeed.checkpointing.checkpoint(block, x) + + + # resx is an exponential residual connection between layers. + # It cuts early training speed in half + resx = "x" in os.environ["RWKV_MY_TESTING"] + resxstack = None + + usex_emb = args.tiny_att_dim > 0 + + for block in self.blocks: + + if resx: + if resxstack is None: + resxstack = x else: - x = block(x) + resxstack = resxstack*2 + x + + x = resxstack + x + + if args.grad_cp == 1: + x = deepspeed.checkpointing.checkpoint(block, *([x, x_emb] if usex_emb else [x])) + else: + x = block(*([x, x_emb] if usex_emb else [x])) + x = self.ln_out(x)