diff --git a/RWKV-v7/ascend/CMakeLists.txt b/RWKV-v7/ascend/CMakeLists.txt
new file mode 100644
index 000000000..9132c2182
--- /dev/null
+++ b/RWKV-v7/ascend/CMakeLists.txt
@@ -0,0 +1,76 @@
+cmake_minimum_required(VERSION 3.16.0)
+project(Ascend_C)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# user-defined configuration
+set(SOC_VERSION "Ascend910B3" CACHE STRING "system on chip type")
+set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "ASCEND CANN package installation directory")
+set(RUN_MODE "npu" CACHE STRING "run mode: npu")
+set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type Release/Debug (default Debug)" FORCE)
+set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRING "path for install()" FORCE)
+
+if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+else()
+    message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.")
+endif()
+
+include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
+
+# ascendc_library is used to add the kernel file and generate the ascendc library
+ascendc_library(kernels STATIC
+    wkv7s.cc
+)
+
+add_library(pybind11_lib SHARED wkv7s_op.cpp)
+target_link_libraries(pybind11_lib PRIVATE
+    kernels
+    torch_npu
+)
+execute_process(COMMAND python3 -c "import os; import torch; print(os.path.dirname(torch.__file__))"
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE TORCH_PATH
+)
+message("TORCH_PATH is ${TORCH_PATH}")
+set(ENV{ASCEND_HOME_PATH} ${ASCEND_CANN_PACKAGE_PATH})
+execute_process(COMMAND python3 -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))"
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE TORCH_NPU_PATH
+)
+message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
+target_link_directories(pybind11_lib PRIVATE
+    ${TORCH_PATH}/lib
+    ${TORCH_NPU_PATH}/lib
+)
+target_link_options(pybind11_lib PRIVATE
+    -Wl,-rpath,${TORCH_PATH}/lib
+    -Wl,-rpath,${TORCH_NPU_PATH}/lib
+)
+target_include_directories(pybind11_lib PRIVATE
+    ${TORCH_NPU_PATH}/include
+    ${TORCH_PATH}/include
+    ${TORCH_PATH}/include/torch/csrc/api/include
+)
+execute_process(COMMAND python3 -m pybind11 --includes
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE PYBIND11_INC
+)
+string(REPLACE " " ";" PYBIND11_INC ${PYBIND11_INC})
+target_compile_options(pybind11_lib PRIVATE
+    ${PYBIND11_INC}
+    -D_GLIBCXX_USE_CXX11_ABI=0
+)
+
+execute_process(COMMAND python3-config --extension-suffix
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE PYBIND11_SUFFIX
+)
+set_target_properties(pybind11_lib PROPERTIES
+    OUTPUT_NAME wkv7s${PYBIND11_SUFFIX}
+    PREFIX "" SUFFIX ""
+)
diff --git a/RWKV-v7/ascend/README.md b/RWKV-v7/ascend/README.md
new file mode 100644
index 000000000..b7aa44b18
--- /dev/null
+++ b/RWKV-v7/ascend/README.md
@@ -0,0 +1,18 @@
+# 1. Build the operator
+```
+mkdir build
+cd build
+cmake ..
+make
+```
+
+# 2. Test the operator
+```
+python test_rwkv.py
+```
+
+# 3. Run the model
+```
+cd ..
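+# note: rwkv_v7_demo_fast_npu.py imports the compiled wkv7s module from ascend/build
+# and loads the checkpoint path set in args.MODEL_NAME (see https://huggingface.co/BlinkDL/rwkv-7-world)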
+python rwkv_v7_demo_fast_npu.py +``` \ No newline at end of file diff --git a/RWKV-v7/ascend/test_rwkv.py b/RWKV-v7/ascend/test_rwkv.py new file mode 100644 index 000000000..c79e71246 --- /dev/null +++ b/RWKV-v7/ascend/test_rwkv.py @@ -0,0 +1,101 @@ +import torch +import torch_npu +import sys +sys.path.append("./build") +import wkv7s + +HEAD_SIZE = 64 +DTYPE = torch.float16 + +# load(name="wkv7s", sources=["wkv7s_op.cpp", f"wkv7s.cu"], is_python_module=False, +# verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) +class WKV_7(torch.autograd.Function): + @staticmethod + def forward(ctx, state, r, w, k, v, a, b): + with torch.no_grad(): + B, T, C = r.size() + H = C // HEAD_SIZE + N = HEAD_SIZE + assert HEAD_SIZE == C // H + assert r.dtype == DTYPE + assert w.dtype == DTYPE + assert k.dtype == DTYPE + assert v.dtype == DTYPE + assert a.dtype == DTYPE + assert b.dtype == DTYPE + assert r.is_contiguous() + assert w.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert a.is_contiguous() + assert b.is_contiguous() + y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format) + wkv7s.forward(B, T, C, H, state, r, w, k, v, a, b, y) + return y + +def RWKV7_OP_KERNEL(state, r, w, k, v, a, b): + return WKV_7.apply(state, r, w, k, v, a, b) + + +def RWKV7_OP_TORCH(state, r, w, k, v, a, b): + B, T, C = r.size() + H = C // HEAD_SIZE + N = HEAD_SIZE + r = r.view(B, T, H, N).float() + k = k.view(B, T, H, N).float() + v = v.view(B, T, H, N).float() + a = a.view(B, T, H, N).float() + b = b.view(B, T, H, N).float() + w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) + out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) + + for t in range(T): + kk = k[:, t, :].view(B, H, 1, N) + rr = r[:, t, :].view(B, H, N, 1) + vv = v[:, t, :].view(B, H, N, 1) + aa = a[:, t, :].view(B, H, N, 1) + bb = b[:, t, :].view(B, H, 1, N) + state = state * w[: , t, :, None, :] + state @ aa @ bb + vv @ kk + out[:, t, :] = (state @ rr).view(B, H, N) + + return out.view(B, T, C).to(dtype=DTYPE), state + + +if __name__ == "__main__": + device = "npu" + B = 1 + T = 1 + C = 1024 + + torch.manual_seed(42) + torch.set_printoptions(precision=4, sci_mode=False) + + r = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous() + w = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous() + k = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous() + v = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous() + a = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous() + b = torch.randn(B, T, C, dtype=DTYPE, device=device).contiguous() + state = torch.randn(B, C // HEAD_SIZE, HEAD_SIZE, HEAD_SIZE, dtype=torch.float, device=device) + + with torch.no_grad(): + y_torch, state_torch = RWKV7_OP_TORCH(state.clone(), r, w, k, v, a, b) + torch.npu.synchronize() + state_kernel = state.clone() + y_kernel = RWKV7_OP_KERNEL(state_kernel, r, w, k, v, a, b) + + print(r[0][0][:64]) + print(state[0][0][0][:64]) + + print(state_torch[0][0][0][:64]) + print(state_kernel[0][0][0][:64]) + + print(y_torch[0][0][:64]) + print(y_kernel[0][0][:64]) + + + # === 比较结果 === + abs_diff = (y_kernel - y_torch).abs().float() + max_diff = abs_diff.max().item() + print("Max absolute difference:", max_diff) + print("All close (atol=1e-3):", torch.allclose(y_kernel, y_torch, atol=1e-3, rtol=1e-3)) \ No newline at end of file diff --git a/RWKV-v7/ascend/wkv7s.cc 
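Before the kernel source below, a quick reference point: the per-head recurrence that `RWKV7_OP_TORCH` implements, and that the Ascend kernel is compared against, can be sanity-checked on CPU with a minimal single-head sketch. This is not part of the patch; the head size, sequence length, and 0.1 input scaling are illustrative assumptions.

```python
import torch

N, T = 64, 8                                    # head size and sequence length (illustrative)
state = torch.zeros(N, N)                       # per-head state S
r, w, k, v, a, b = (torch.randn(T, N) * 0.1 for _ in range(6))
w = torch.exp(-torch.exp(w))                    # decay in (0, 1), as in RWKV7_OP_TORCH
out = torch.zeros(T, N)

for t in range(T):
    # S <- S * diag(w_t)  +  S a_t b_t^T  +  v_t k_t^T
    state = state * w[t].view(1, N) \
            + state @ a[t].view(N, 1) @ b[t].view(1, N) \
            + v[t].view(N, 1) @ k[t].view(1, N)
    out[t] = (state @ r[t].view(N, 1)).view(N)  # y_t = S_t r_t

print(out.shape)  # torch.Size([8, 64])
```

This is the same update as the batched reference, just with the batch and head dimensions stripped away, which is the form the scalar loops in the kernel below work in.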
b/RWKV-v7/ascend/wkv7s.cc new file mode 100644 index 000000000..593e3eb9e --- /dev/null +++ b/RWKV-v7/ascend/wkv7s.cc @@ -0,0 +1,215 @@ +#include "kernel_operator.h" +#include + +using namespace AscendC; + +const int BUFFER_NUM = 2; + +template +class RWKV +{ +public: + __aicore__ inline RWKV() {} + __aicore__ inline void init(const int B, const int T, const int C, const int H, + GM_ADDR state, + GM_ADDR r, GM_ADDR w, GM_ADDR k, + GM_ADDR v, GM_ADDR a, GM_ADDR b, + GM_ADDR y); + __aicore__ inline void process(); + +private: + __aicore__ inline void copyIn(size_t t); + __aicore__ inline void compute(size_t t); + __aicore__ inline void copyOut(size_t t); + +private: + // gm input + GlobalTensor _r, _w, _k, _v, _a, _b; + // gm output + GlobalTensor _y; + // gm state + GlobalTensor _state; + + // que input + TQue _que_r, _que_w, _que_k, _que_v, _que_a, _que_b; + // que output + TQue _que_y; + + // que state + TBuf _buf_state; + + TPipe _pipe; + int _B, _T, _C, _H, _N; + int _e; // batch_idx + int _h; // head_idx; +}; + +template +__aicore__ inline void RWKV::init(const int B, const int T, const int C, const int H, + GM_ADDR state, + GM_ADDR r, GM_ADDR w, GM_ADDR k, + GM_ADDR v, GM_ADDR a, GM_ADDR b, + GM_ADDR y) +{ + _B = B; + _T = T; + _C = C; + _H = H; + _N = _C / _H; + + _e = GetBlockIdx() / _H; // batch_idx + _h = GetBlockIdx() % _H; // head_idx + + int rwkv_bias = _e * _T * _C + _h * _N; + int state_bias = _e * _N * _N * _H + _h * _N * _N; + + // set gm + _state.SetGlobalBuffer((__gm__ float *)state + state_bias); + _r.SetGlobalBuffer((__gm__ F *)r + rwkv_bias); + _w.SetGlobalBuffer((__gm__ F *)w + rwkv_bias); + _k.SetGlobalBuffer((__gm__ F *)k + rwkv_bias); + _v.SetGlobalBuffer((__gm__ F *)v + rwkv_bias); + _a.SetGlobalBuffer((__gm__ F *)a + rwkv_bias); + _b.SetGlobalBuffer((__gm__ F *)b + rwkv_bias); + _y.SetGlobalBuffer((__gm__ F *)y + rwkv_bias); + + // init que + _pipe.InitBuffer(_que_r, BUFFER_NUM, _N * sizeof(F)); + _pipe.InitBuffer(_que_w, BUFFER_NUM, _N * sizeof(F)); + _pipe.InitBuffer(_que_k, BUFFER_NUM, _N * sizeof(F)); + _pipe.InitBuffer(_que_v, BUFFER_NUM, _N * sizeof(F)); + _pipe.InitBuffer(_que_a, BUFFER_NUM, _N * sizeof(F)); + _pipe.InitBuffer(_que_b, BUFFER_NUM, _N * sizeof(F)); + _pipe.InitBuffer(_que_y, BUFFER_NUM, _N * sizeof(F)); + + // init buf + _pipe.InitBuffer(_buf_state, _N * _N * sizeof(float)); +} + +template +__aicore__ inline void RWKV::copyIn(size_t t) +{ + LocalTensor r = _que_r.AllocTensor(); + LocalTensor w = _que_w.AllocTensor(); + LocalTensor k = _que_k.AllocTensor(); + LocalTensor v = _que_v.AllocTensor(); + LocalTensor a = _que_a.AllocTensor(); + LocalTensor b = _que_b.AllocTensor(); + + DataCopy(r, _r[t * _C], _N); + DataCopy(w, _w[t * _C], _N); + DataCopy(k, _k[t * _C], _N); + DataCopy(v, _v[t * _C], _N); + DataCopy(a, _a[t * _C], _N); + DataCopy(b, _b[t * _C], _N); + + _que_r.EnQue(r); + _que_w.EnQue(w); + _que_k.EnQue(k); + _que_v.EnQue(v); + _que_a.EnQue(a); + _que_b.EnQue(b); +} + +template +__aicore__ inline void RWKV::compute(size_t t) +{ + // get input + LocalTensor r = _que_r.DeQue(); + LocalTensor w = _que_w.DeQue(); + LocalTensor k = _que_k.DeQue(); + LocalTensor v = _que_v.DeQue(); + LocalTensor a = _que_a.DeQue(); + LocalTensor b = _que_b.DeQue(); + + // get state + LocalTensor state = _buf_state.Get(); + + // get output + LocalTensor y = _que_y.AllocTensor(); + + // compute w + Exp(w, w, _N); + Muls(w, w, F(-1.0), _N); + Exp(w, w, _N); + + // compute + for (int i = 0; i < _N; i++) + { + float sa = 0; + // PipeBarrier(); +#pragma unroll 
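+                // The scalar loops here implement, for each state row i:
+                //   sa      = sum_j a[j] * S[i][j]                (project previous state onto a)
+                //   S[i][j] = S[i][j]*w[j] + v[i]*k[j] + sa*b[j]  (decay, write new v*k, add correction)
+                //   y[i]    = sum_j S[i][j] * r[j]                (read out with r)
+                // where w[] already holds the decay exp(-exp(w_in)) computed just above,
+                // matching the float32 reference RWKV7_OP_TORCH in test_rwkv.py.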
+ for (int j = 0; j < _N; j++) + { + sa += float(a(j)) * state(i * _N + j); + } + + float vv = float(v(i)); + float yy = 0; +#pragma unroll + for (int j = 0; j < _N; j++) + { + // __ubuf__ float *s = &state(i * _N + j); + // *s = *s * float(w(j)) + float(k(j)) * vv + sa * float(b(j)); + // yy += *s * float(r(j)); + float s = state(i * _N + j); + s = s * float(w(j)) + float(k(j)) * vv + sa * float(b(j)); + yy += s * float(r(j)); + state.SetValue(i * _N + j, s); + } + // PipeBarrier(); + y(i) = F(yy); + } + + _que_r.FreeTensor(r); + _que_w.FreeTensor(w); + _que_k.FreeTensor(k); + _que_v.FreeTensor(v); + _que_a.FreeTensor(a); + _que_b.FreeTensor(b); + _que_y.EnQue(y); +} + +template +__aicore__ inline void RWKV::copyOut(size_t t) +{ + LocalTensor y = _que_y.DeQue(); + DataCopy(_y[t * _C], y, _N); + _que_y.FreeTensor(y); +} + +template +__aicore__ inline void RWKV::process() +{ + LocalTensor state_local = _buf_state.Get(); + DataCopy(state_local, _state, _N * _N); + for (int t = 0; t < _T; t++) + { + copyIn(t); + compute(t); + copyOut(t); + } + DataCopy(_state, state_local, _N * _N); +} + +extern "C" __global__ __aicore__ void kernel_forward(int B, int T, int C, int H, + GM_ADDR state, + GM_ADDR r, + GM_ADDR w, + GM_ADDR k, + GM_ADDR v, + GM_ADDR a, + GM_ADDR b, + GM_ADDR y) +{ + RWKV rwkv; + rwkv.init(B, T, C, H, state, r, w, k, v, a, b, y); + rwkv.process(); +} + +void ascend_forward(int B, int T, int C, int H, void *state, void *r, void *w, void *k, void *v, void *a, void *b, void *y, void* stream) +{ + assert(C % H == 0); + assert(B == 1); // only for B=1 + kernel_forward<<>>(B, T, C, H, state, r, w, k, v, a, b, y); +} diff --git a/RWKV-v7/ascend/wkv7s_op.cpp b/RWKV-v7/ascend/wkv7s_op.cpp new file mode 100644 index 000000000..9de164590 --- /dev/null +++ b/RWKV-v7/ascend/wkv7s_op.cpp @@ -0,0 +1,23 @@ +#include +#include +#include "ATen/ATen.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" + +// typedef at::BFloat16 bf16; + +void ascend_forward(int B, int T, int C, int H, void *state, void *r, void *w, void *k, void *v, void *a, void *b, void *y, void* stream); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, + torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, + torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) +{ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); + ascend_forward(B, T, C, H, + state.data_ptr(), r.data_ptr(), w.data_ptr(), k.data_ptr(), + v.data_ptr(), a.data_ptr(), b.data_ptr(), y.data_ptr(), acl_stream); +} + +PYBIND11_MODULE(wkv7s, m) +{ + m.def("forward", &forward, "Forward with original order"); +} \ No newline at end of file diff --git a/RWKV-v7/rwkv_v7_demo_fast_npu.py b/RWKV-v7/rwkv_v7_demo_fast_npu.py new file mode 100644 index 000000000..a362c42a2 --- /dev/null +++ b/RWKV-v7/rwkv_v7_demo_fast_npu.py @@ -0,0 +1,481 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## +# +# This version is GPT-mode + RNN-mode, and a bit more difficult to understand +# +######################################################################################################## +import sys +sys.path.append("/workspace/RWKV-LM/RWKV-v7/ascend/build") +import wkv7s + +import numpy as np +np.set_printoptions(precision=4, suppress=True, linewidth=200) +import types, torch, copy, time +import 
torch_npu +from typing import List, Set +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True +# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True +torch._C._jit_set_autocast_mode(False) + +import torch.nn as nn +from torch.nn import functional as F + +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method +MyStatic = torch.jit.script +# MyModule = nn.Module +# def __nop(ob): return ob +# MyFunction = __nop +# MyStatic = __nop + +######################################################################################################## + +args = types.SimpleNamespace() + +# model download: https://huggingface.co/BlinkDL/rwkv-7-world + +args.MODEL_NAME = "/workspace/models/RWKV-x070-World-0.1B-v2.8-20241210-ctx4096" + +args.n_layer = 12 +args.n_embd = 768 +args.vocab_size = 65536 +args.head_size = 64 + +prompt = "The Eiffel tower is in the city of" +NUM_TRIALS = 3 +LENGTH_PER_TRIAL = 100 +TEMPERATURE = 1.0 +TOP_P = 0.0 + +######################################################################################################## +# +# The RWKV-7 "Goose" Language Model - https://github.com/BlinkDL/RWKV-LM +# +######################################################################################################## + +DTYPE = torch.half + +# from torch.utils.cpp_extension import load +HEAD_SIZE = args.head_size + +class WKV_7(torch.autograd.Function): + @staticmethod + def forward(ctx, state, r, w, k, v, a, b): + with torch.no_grad(): + T, C = r.size() + H = C // HEAD_SIZE + N = HEAD_SIZE + assert HEAD_SIZE == C // H + assert r.dtype == DTYPE + assert w.dtype == DTYPE + assert k.dtype == DTYPE + assert v.dtype == DTYPE + assert a.dtype == DTYPE + assert b.dtype == DTYPE + assert r.is_contiguous() + assert w.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert a.is_contiguous() + assert b.is_contiguous() + y = torch.empty((T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format) + wkv7s.forward(1, T, C, H, state, r, w, k, v, a, b, y) + return y + +def RWKV7_OP_KERNEL(state, r, w, k, v, a, b): + return WKV_7.apply(state, r, w, k, v, a, b) + +######################################################################################################## + +class RWKV_x070(MyModule): + def __init__(self, args): + super().__init__() + self.args = args + self.n_embd = args.n_embd + self.n_layer = args.n_layer + self.eval() + + self.z = torch.load(args.MODEL_NAME + '.pth', map_location='npu') + z = self.z + self.n_head, self.head_size = z['blocks.0.att.r_k'].shape + + keys = list(z.keys()) + for k in keys: + if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k: + z[k] = z[k].t() + z[k] = z[k].squeeze().to(dtype=DTYPE) + if k.endswith('att.r_k'): z[k] = z[k].flatten() + assert self.head_size == args.head_size + + z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias']) + z['blocks.0.att.v0'] = z['blocks.0.att.a0'] # actually ignored + z['blocks.0.att.v1'] = z['blocks.0.att.a1'] # actually ignored + z['blocks.0.att.v2'] = z['blocks.0.att.a2'] # actually ignored + + def forward(self, idx, state, full_output=False): + if state == None: + state = [None for _ in range(args.n_layer * 3)] + for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev + 
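+                # per layer the state holds: the time-mix token-shift vector (n_embd,),
+                # the fp32 WKV state of shape (n_embd // head_size, head_size, head_size), which the wkv7s kernel updates in place,
+                # and the channel-mix token-shift vector (n_embd,)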
state[i*3+0] = torch.zeros(args.n_embd, dtype=DTYPE, requires_grad=False, device="npu") + state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="npu") + state[i*3+2] = torch.zeros(args.n_embd, dtype=DTYPE, requires_grad=False, device="npu") + + if type(idx) is list: + if len(idx) > 1: + return self.forward_seq(idx, state, full_output) + else: + return self.forward_one(idx[0], state) + else: + return self.forward_one(idx, state) + + @MyFunction + def forward_one(self, idx:int, state:List[torch.Tensor]): + with torch.no_grad(): + z = self.z + x = z['emb.weight'][idx] + + v_first = torch.empty_like(x) + for i in range(self.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias']) + + xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_one(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1], + z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'], + z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'], z[att+'v0'], z[att+'v1'], z[att+'v2'], + z[att+'g1'], z[att+'g2'], z[att+'k_k'], z[att+'k_a'], z[att+'r_k'], + z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'], + z[att+'ln_x.weight'], z[att+'ln_x.bias']) + x = x + xx + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias']) + + xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight']) + x = x + xx + + x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias']) + x = x @ z['head.weight'] + return x, state + + # @MyFunction + def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=False): + with torch.no_grad(): + z = self.z + x = z['emb.weight'][idx] + + v_first = torch.empty_like(x) + for i in range(self.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' 
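+                # per-layer pre-norm residual block: ln1 -> time-mix (RWKV_x070_TMix_seq, which calls the
+                # wkv7s kernel) -> residual add, then ln2 -> channel-mix -> residual add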
+ + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias']) + + xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_seq(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1], + z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'], + z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'], z[att+'v0'], z[att+'v1'], z[att+'v2'], + z[att+'g1'], z[att+'g2'], z[att+'k_k'], z[att+'k_a'], z[att+'r_k'], + z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'], + z[att+'ln_x.weight'], z[att+'ln_x.bias']) + x = x + xx + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias']) + + xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight']) + x = x + xx + + if not full_output: x = x[-1,:] + x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias']) + x = x @ z['head.weight'] + return x, state + +######################################################################################################## + +@MyStatic +def RWKV_x070_TMix_one(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b): + xx = x_prev - x + xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g + + r = xr @ R_ + w = torch.tanh(xw @ w1) @ w2 + k = xk @ K_ + v = xv @ V_ + a = torch.sigmoid(a0 + (xa @ a1) @ a2) + g = torch.sigmoid(xg @ g1) @ g2 + + kk = torch.nn.functional.normalize((k * k_k).view(H,N), dim=-1, p=2.0).view(H*N) + k = k * (1 + (a-1) * k_a) + if layer_id == 0: v_first = v + else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2) + w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5) + + vk = v.view(H,N,1) @ k.view(H,1,N) + ab = (-kk).view(H,N,1) @ (kk*a).view(H,1,N) + state = state * w.view(H,1,N) + state @ ab.float() + vk.float() + xx = (state.to(dtype=x.dtype) @ r.view(H,N,1)) + + xx = torch.nn.functional.group_norm(xx.view(1,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) + xx = xx + ((r * k * r_k).view(H,N).sum(dim=-1, keepdim=True) * v.view(H,N)).view(H*N) + return (xx * g) @ O_, x, state, v_first + +# @MyStatic +def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b): + T = x.shape[0] + xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x + xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g + + r = xr @ R_ + w = torch.tanh(xw @ w1) @ w2 + k = xk @ K_ + v = xv @ V_ + a = torch.sigmoid(a0 + (xa @ a1) @ a2) + g = torch.sigmoid(xg @ g1) @ g2 + + kk = torch.nn.functional.normalize((k * k_k).view(T,H,N), dim=-1, p=2.0).view(T,H*N) + k = k * (1 + (a-1) * k_a) + if layer_id == 0: v_first = v + else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2) + + ######## cuda-free method + # w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5) + # for t in range(T): + # r_, w_, k_, v_, kk_, a_ = r[t], w[t], k[t], v[t], kk[t], a[t] + # vk = v_.view(H,N,1) @ k_.view(H,1,N) + # ab = (-kk_).view(H,N,1) @ (kk_*a_).view(H,1,N) + # state = state * w_.view(H,1,N) + state @ ab.float() + vk.float() + # xx[t] = (state.to(dtype=x.dtype) @ r_.view(H,N,1)).view(H*N) + + w = 
-torch.nn.functional.softplus(-(w0 + w)) - 0.5 + torch.npu.synchronize() + xx = RWKV7_OP_KERNEL(state, r, w, k, v, -kk, kk*a) + torch.npu.synchronize() + # print(xx[0][:64]) + + xx = torch.nn.functional.group_norm(xx.view(T,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(T,H*N) + xx = xx + ((r * k * r_k).view(T,H,N).sum(dim=-1, keepdim=True) * v.view(T,H,N)).view(T,H*N) + return (xx * g) @ O_, x[-1,:], state, v_first + +######################################################################################################## + +@MyStatic +def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_): + xx = x_prev - x + k = x + xx * x_k + k = torch.relu(k @ K_) ** 2 + return k @ V_, x + +@MyStatic +def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_): + xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x + k = x + xx * x_k + k = torch.relu(k @ K_) ** 2 + return k @ V_, x[-1,:] + +######################################################################################################## +# +# The testing code +# +######################################################################################################## + +@MyStatic +def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0): + probs = F.softmax(logits.float(), dim=-1) + sorted_probs, sorted_ids = torch.sort(probs, descending=True) + + if top_k > 0: + probs[sorted_ids[top_k:]] = 0 + + if top_p < 1: + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + cutoff_index = torch.searchsorted(cumulative_probs, top_p) + cutoff = sorted_probs[cutoff_index] + probs[probs < cutoff] = 0 + + if top_p > 0: + idx = torch.where(probs == cutoff)[0] + if len(idx) > 0: + probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx) + # assert abs(torch.sum(probs).item() - top_p) < 1e-6 + + if temperature != 1.0: + probs = probs ** (1.0 / temperature) + + return torch.multinomial(probs, num_samples=1).item() + +######################################################################################################## +# RWKV Tokenizer (slow version) +######################################################################################################## + +class RWKV_TOKENIZER(): + table: List[List[List[bytes]]] + good: List[Set[int]] + wlen: List[int] + def __init__(self, file_name): + self.idx2token = {} + sorted = [] # must be already sorted + lines = open(file_name, "r", encoding="utf-8").readlines() + for l in lines: + idx = int(l[:l.index(' ')]) + x = eval(l[l.index(' '):l.rindex(' ')]) + x = x.encode("utf-8") if isinstance(x, str) else x + assert isinstance(x, bytes) + assert len(x) == int(l[l.rindex(' '):]) + sorted += [x] + self.idx2token[idx] = x + + self.token2idx = {} + for k, v in self.idx2token.items(): + self.token2idx[v] = int(k) + + # precompute some tables for fast matching + self.table = [[[] for j in range(256)] for i in range(256)] + self.good = [set() for i in range(256)] + self.wlen = [0 for i in range(256)] + + for i in reversed(range(len(sorted))): # reverse order - match longer tokens first + s = sorted[i] + if len(s) >= 2: + s0 = int(s[0]) + s1 = int(s[1]) + self.table[s0][s1] += [s] + self.wlen[s0] = max(self.wlen[s0], len(s)) + self.good[s0].add(s1) + + def encodeBytes(self, src: bytes) -> List[int]: + src_len: int = len(src) + tokens: List[int] = [] + i: int = 0 + while i < src_len: + s: bytes = src[i : i + 1] + + if i < src_len - 1: + s1: int = int(src[i + 1]) + s0: int = int(src[i]) + if s1 in self.good[s0]: + sss: bytes = src[i : i + self.wlen[s0]] + try: + s = next(filter(sss.startswith, 
self.table[s0][s1])) + except: + pass + tokens.append(self.token2idx[s]) + i += len(s) + + return tokens + + def decodeBytes(self, tokens): + return b''.join(map(lambda i: self.idx2token[i], tokens)) + + def encode(self, src: str): + return self.encodeBytes(src.encode("utf-8")) + + def decode(self, tokens): + return self.decodeBytes(tokens).decode('utf-8') + + def printTokens(self, tokens): + for i in tokens: + s = self.idx2token[i] + try: + s = s.decode('utf-8') + except: + pass + print(f'{repr(s)}{i}', end=' ') + # print(repr(s), i) + print() + +tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt") + +######################################################################################################## + +print(f'\nUsing NPU {str(DTYPE).replace("torch.","")}. Loading {args.MODEL_NAME} ...') +model = RWKV_x070(args) + +init_out, init_state = model.forward(tokenizer.encode(prompt), None) + +probs = F.softmax(init_out.float(), dim=-1) # compute softmax in float (more accurate) + +print(f'\n{prompt}') + +_, indices = torch.topk(probs, 10) # print top-10 possibilities +print(_) +for i in range(len(indices)): + token_id = indices[i].item() + token = tokenizer.decode([token_id]) + token_prob = probs[token_id].item() + print(token, f'[probability {token_prob:.2%}]') + +######################################################################################################## + +for TRIAL in range(NUM_TRIALS): + print(f'\n\n--[ Trial {TRIAL} ]-----------------', prompt, end="") + all_tokens = [] + out_last = 0 + out, state = init_out.clone(), copy.deepcopy(init_state) + + min_time = 1e10 + min_time_all = 1e10 + + t000 = time.perf_counter() + + for i in range(LENGTH_PER_TRIAL): + t00 = time.perf_counter() + token = sample_logits(out, TEMPERATURE, TOP_P) + all_tokens += [token] + try: + tmp = tokenizer.decode(all_tokens[out_last:]) + if '\ufffd' not in tmp: # only print when we have a valid utf-8 string + print(tmp, end="", flush=True) + out_last = i + 1 + except: + pass + t0 = time.perf_counter() + + out, state = model.forward(token, state) + + torch.npu.synchronize() + t1 = time.perf_counter() + min_time = min(min_time, t1 - t0) + min_time_all = min(min_time_all, t1 - t00) + + print(f'\n[ {round(1/min_time_all,2)} (real) / {round(1/min_time,2)} (ignore sampling & tokenizer) token/s = {round(time.perf_counter()-t000,3)}s ]', end='') + +print('\n') + +######################################################################################################## + +import json, math +with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f: + todo = [json.loads(line) for line in f] + todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo] + +print('\nCheck LAMBADA...') +xsum = 0 +xcnt = 0 +xacc = 0 +for d in todo: + src = [0] + tokenizer.encode(d[0]) + dst = tokenizer.encode(d[1]) + + logits = 0 + correct = True + + out, _ = model.forward(src+dst, None, full_output=True) + + for i in range(len(dst)): + ooo = out[len(src)-1+i].float() + probs = F.softmax(ooo, dim=-1) + logits += math.log(probs[dst[i]]) + if torch.argmax(probs).item() != dst[i]: + correct = False + + xcnt += 1 + xsum += logits + xacc += 1 if correct else 0 + if xcnt % 100 == 0 or xcnt == len(todo): + print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2)) +
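As a closing note on the evaluation loop above, here is a minimal self-contained sketch (with made-up log-probabilities) of how the LAMBADA numbers are aggregated: `xsum` accumulates each example's summed target-token log-probability, perplexity is `exp(-xsum / xcnt)` over examples, and accuracy counts examples whose target tokens were all the argmax prediction.

```python
import math

# (per-token log-probs of the target word, were all its tokens the argmax?) - dummy values
examples = [([-0.1, -0.2], True), ([-2.3, -0.7], False)]

xsum = sum(sum(lp) for lp, _ in examples)   # like `xsum += logits` above
xcnt = len(examples)                        # number of examples
xacc = sum(1 for _, ok in examples if ok)   # exact-match count
print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc / xcnt * 100, 2))
```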