From 0487c27cd60dc4a603cc40d715ba975b6827bae4 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 27 Nov 2024 06:07:32 +0000 Subject: [PATCH 1/6] layer_norm --- include/ops/layer_norm/layer_norm.h | 29 ++ operatorspy/tests/layer_norm.py | 156 ++++++++ src/ops/layer_norm/bang/layer_norm_bang.cc | 49 +++ src/ops/layer_norm/bang/layer_norm_bang.h | 34 ++ src/ops/layer_norm/bang/layer_norm_bang.mlu | 390 ++++++++++++++++++++ src/ops/layer_norm/cpu/layer_norm_cpu.cc | 125 +++++++ src/ops/layer_norm/cpu/layer_norm_cpu.h | 27 ++ src/ops/layer_norm/cuda/layer_norm.cc | 53 +++ src/ops/layer_norm/cuda/layer_norm.cu | 178 +++++++++ src/ops/layer_norm/cuda/layer_norm.cuh | 32 ++ src/ops/layer_norm/operator.cc | 86 +++++ src/ops/rms_norm/bang/rms_norm_cnnl.cc | 56 --- src/ops/rms_norm/bang/rms_norm_cnnl.h | 15 - src/ops/rms_norm/operator.cc | 1 - 14 files changed, 1159 insertions(+), 72 deletions(-) create mode 100644 include/ops/layer_norm/layer_norm.h create mode 100644 operatorspy/tests/layer_norm.py create mode 100644 src/ops/layer_norm/bang/layer_norm_bang.cc create mode 100644 src/ops/layer_norm/bang/layer_norm_bang.h create mode 100644 src/ops/layer_norm/bang/layer_norm_bang.mlu create mode 100644 src/ops/layer_norm/cpu/layer_norm_cpu.cc create mode 100644 src/ops/layer_norm/cpu/layer_norm_cpu.h create mode 100644 src/ops/layer_norm/cuda/layer_norm.cc create mode 100644 src/ops/layer_norm/cuda/layer_norm.cu create mode 100644 src/ops/layer_norm/cuda/layer_norm.cuh create mode 100644 src/ops/layer_norm/operator.cc delete mode 100644 src/ops/rms_norm/bang/rms_norm_cnnl.cc delete mode 100644 src/ops/rms_norm/bang/rms_norm_cnnl.h diff --git a/include/ops/layer_norm/layer_norm.h b/include/ops/layer_norm/layer_norm.h new file mode 100644 index 00000000..2b4bf0ee --- /dev/null +++ b/include/ops/layer_norm/layer_norm.h @@ -0,0 +1,29 @@ +#ifndef LAYER_NORM_H +#define LAYER_NORM_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct LayerNormDescriptor { + Device device; +} LayerNormDescriptor; + +typedef LayerNormDescriptor *infiniopLayerNormDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateLayerNormDescriptor( + infiniopHandle_t handle, + infiniopLayerNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + + + +__C __export infiniopStatus_t infiniopLayerNorm(infiniopLayerNormDescriptor_t desc, + void const *x, void const *w, void const *b, void *y, void *stream); + +__C __export infiniopStatus_t infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc); + +#endif diff --git a/operatorspy/tests/layer_norm.py b/operatorspy/tests/layer_norm.py new file mode 100644 index 00000000..5c5253d3 --- /dev/null +++ b/operatorspy/tests/layer_norm.py @@ -0,0 +1,156 @@ +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float +import ctypes +import sys +import os + + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch +import torch.nn as nn + +class LayerNormDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopLayerNormDescriptor_t = POINTER(LayerNormDescriptor) + + +def LayerNormFunction(input, scale, bias, 
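+                       # Builds the PyTorch reference: nn.LayerNorm over the trailing dims given by
+                       # scale.shape (elementwise affine), with the provided weight/bias copied in,
+                       # evaluated at the same eps as the library call under test.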
eps): + normlize_shape = scale.shape + layer_norm = nn.LayerNorm(normlize_shape, elementwise_affine=True, eps = eps) + layer_norm.weight.data = scale + layer_norm.bias.data = bias + return layer_norm.forward(input) + + +def test(lib, handle, torch_device, x_shape, axis, x_dtype=torch.float16): + print( + f"Testing Layernorm on {torch_device} with test_shape:{x_shape}, axis:{axis} ,dtype:{x_dtype}" + ) + eps = 1e-5 + ndim = len(x_shape) + normlize_shape = [] + for i in range(axis, ndim): + normlize_shape.append(x_shape[i]) + + x = torch.rand(x_shape, dtype=x_dtype).to(torch_device) + scale = torch.rand(normlize_shape, dtype=x_dtype).to(torch_device) + bias = torch.rand(normlize_shape, dtype=x_dtype).to(torch_device) + y = torch.rand(x_shape, dtype=x_dtype).to(torch_device) + ans = LayerNormFunction(x, scale, bias, eps) + x_tensor = to_tensor(x, lib) + w_tensor = to_tensor(scale, lib) + b_tensor = to_tensor(bias, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopLayerNormDescriptor_t() + check_error( + lib.infiniopCreateLayerNormDescriptor( + handle, ctypes.byref(descriptor), x_tensor.descriptor, w_tensor.descriptor, b_tensor.descriptor, y_tensor.descriptor, eps + ) + ) + + check_error( + lib.infiniopLayerNorm( + descriptor, + x_tensor.data, + w_tensor.data, + b_tensor.data, + y_tensor.data, + None, + ) + ) + err = y.reshape(-1,1) - ans.reshape(-1,1) + print(max(abs(err))) + assert torch.allclose(y, ans, atol=0, rtol=1e-2) + check_error(lib.infiniopDestroyLayerNormDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "cpu", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "cuda", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "mlu", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # x_shape, axis + + ((32, 20, 512), 0, torch.float16), + ((32, 20, 512), 1, torch.float16), + ((32, 20, 512), 2, torch.float16), + + ((32, 20, 512), 0, torch.float32), + ((32, 20, 512), 1, torch.float32), + ((32, 20, 512), 2, torch.float32), + + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateLayerNormDescriptor.restype = c_int32 + lib.infiniopCreateLayerNormDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopLayerNormDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_float, + ] + + lib.infiniopLayerNorm.restype = c_int32 + lib.infiniopLayerNorm.argtypes = [ + infiniopLayerNormDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyLayerNormDescriptor.restype = c_int32 + lib.infiniopDestroyLayerNormDescriptor.argtypes = [ + infiniopLayerNormDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("Test passed!") diff --git a/src/ops/layer_norm/bang/layer_norm_bang.cc b/src/ops/layer_norm/bang/layer_norm_bang.cc 
new file mode 100644 index 00000000..b0fc8d78 --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_bang.cc @@ -0,0 +1,49 @@ +#include "layer_norm_bang.h" +#include "../../utils.h" +infiniopStatus_t bangCreateLayerNormDescriptor(BangHandle_t handle, LayerNormBangDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int size = 1; + int behindsize = 1; + for(int i = 0; i < ndim; i++){ + size *= static_cast(x_desc->shape[i]); + if(i >= ndim - wDim){ + behindsize *= static_cast(x_desc->shape[i]); + } + } + *desc_ptr = new LayerNormBangDescriptor{ + handle->device, + handle->device_id, + x_desc->dt, + size, + behindsize, + epsilon}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t bangDestroyLayerNormDescriptor(LayerNormBangDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/layer_norm/bang/layer_norm_bang.h b/src/ops/layer_norm/bang/layer_norm_bang.h new file mode 100644 index 00000000..a630b39d --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_bang.h @@ -0,0 +1,34 @@ +#ifndef __BANG_LAYER_NORM_H__ +#define __BANG_LAYER_NORM_H__ + +#include "../../../devices/bang/bang_handle.h" +#include "../../utils.h" +#include "operators.h" + +struct LayerNormBangDescriptor { + Device device; + int device_id; + DT dtype; + int size; + int behindsize; + float epsilon; +}; + +typedef struct LayerNormBangDescriptor *LayerNormBangDescriptor_t; + +infiniopStatus_t bangCreateLayerNormDescriptor(BangHandle_t handle, + LayerNormBangDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + + +infiniopStatus_t bangLayerNorm(LayerNormBangDescriptor_t desc, + void const *x, void const *w, void const *b, void *y, + void *stream); + +infiniopStatus_t bangDestroyLayerNormDescriptor(LayerNormBangDescriptor_t desc); + +#endif// __BANG_LAYER_NORM_H__ diff --git a/src/ops/layer_norm/bang/layer_norm_bang.mlu b/src/ops/layer_norm/bang/layer_norm_bang.mlu new file mode 100644 index 00000000..69f16f9b --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_bang.mlu @@ -0,0 +1,390 @@ +#include "bang.h" +#include "cnrt.h" +#include "layer_norm_bang.h" +#include "../../../devices/bang/common_bang.h" + +const int SRC_MAX_SIZE = 1024 * 16; +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void layer_norm(T const *input, T const *scale, T const *bias, T *output, T *tmpGdram, float eps, int size, int behindsize, int bSize){ + int frontsize = size / behindsize; + const int wSize = 128 / sizeof(T); + + const int maxNum = SRC_MAX_SIZE / sizeof(T); + + + T *src = (T *)nram_buffer;//[maxNum] + T *destSum = src + 3 * maxNum;//[3 * maxNum] + T *destSumFinal = destSum + maxNum;//[wSize] + T *s_src = destSumFinal + wSize;//[3 * maxNum] + T *b_src = s_src + 3 * maxNum;//[3 * maxNum] + //bSize是大于等于behindsize的最小2次幂 + + if (behindsize >= taskDim 
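+    /* Case 1: one normalized row (behindsize elements) exceeds the combined NRAM of all tasks
+       (taskDim * maxNum), so every task cooperates on a single row at a time; per-task partial
+       sums are staged in tmpGdram and re-reduced after __sync_all(). bSize (computed on the host)
+       is the smallest power of two >= behindsize, clamped to at least wSize; it sizes the
+       per-row tree reduction in the third branch below. */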
* maxNum){ + int segNum = maxNum / wSize; + int taskSize = taskDim * maxNum; + int remain = behindsize % taskSize; + int repeat = (behindsize - remain) / taskSize; + + int remainT = remain % taskDim; + int stepEasy = (remain - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + int indStart = repeat * taskSize + (taskId < remainT ? taskId * stepHard : (remainT * stepHard + (taskId - remainT) * stepEasy)); + for(int i = 0; i < frontsize; i++){ + int tid = i * behindsize; + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if(j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if(step){ + __bang_write_zero(src, maxNum); + __memcpy(src, input + tid + indStart, step * sizeof(T), GDRAM2NRAM); + __bang_add(destSum, destSum, src, maxNum); + } + __bang_mul_scalar(destSum, destSum, 1.0 / behindsize, maxNum); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0]存储的是当前task对应数据的规约和 + tmpGdram[taskId] = destSumFinal[0]; + __sync_all(); + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __memcpy(destSum, tmpGdram, taskDim * sizeof(T), GDRAM2NRAM); + __bang_reduce_sum(destSumFinal, destSum, wSize); + T mu = destSumFinal[0]; + //下面计算方差 + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if (j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_sub_scalar(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, mu, maxNum); + __bang_mul(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, maxNum); + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if (step){ + __bang_write_value(src, maxNum, mu);//保证后面减去均值为0 + __memcpy(src, input + tid + indStart, step * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul(src, src, src, maxNum); + __bang_add(destSum, destSum, src, maxNum); + } + __bang_mul_scalar(destSum, destSum, 1.0 / behindsize, maxNum); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0]存储的是当前task对应数据的规约和 + + tmpGdram[taskId] = destSumFinal[0]; + __sync_all(); + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __memcpy(destSum, tmpGdram, taskDim * sizeof(T), GDRAM2NRAM); + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = destSumFinal[0] + static_cast(eps); + sigma2 = 1.0 / pow(sigma2, 0.5); + //下面开始做变换 + for(int j = 0; j < repeat + 2; j++){ + if(j < repeat){ + __memcpy_async(src + j % 3 * maxNum, input + tid + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(s_src + j % 3 * maxNum, scale + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(b_src + j % 3 * maxNum, bias 
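+                /* Three-stage software pipeline over triple-buffered NRAM tiles (hence the "% 3"
+                   indexing and the loop bound repeat + 2): iteration j prefetches tile j, normalizes
+                   tile j-1 as (x - mu) * sigma2 * scale + bias, where sigma2 holds the reciprocal
+                   standard deviation 1/sqrt(var + eps), and writes tile j-2 back to GDRAM. */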
+ j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0 && j < repeat + 1){ + __bang_sub_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, mu, maxNum); + __bang_mul_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, sigma2, maxNum); + __bang_mul(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, s_src + (j - 1) % 3 * maxNum, maxNum); + __bang_add(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, b_src + (j - 1) % 3 * maxNum, maxNum); + } + if(j > 1){ + __memcpy_async(output + tid + (j - 2) * taskSize + taskId * maxNum, src + (j - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if (step){ + __memcpy(src, input + tid + indStart, step * sizeof(T), GDRAM2NRAM); + __memcpy(s_src, scale + indStart, step * sizeof(T), GDRAM2NRAM); + __memcpy(b_src, bias + indStart, step * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul_scalar(src, src, sigma2, maxNum); + __bang_mul(src, src, s_src, maxNum); + __bang_add(src, src, b_src, maxNum); + __memcpy(output + tid + indStart, src, step * sizeof(T), NRAM2GDRAM); + } + } + } + else if(behindsize >= maxNum && behindsize < taskDim * maxNum){ + int segNum = maxNum / wSize; + int remainT = behindsize % maxNum; + int repeat = (behindsize - remainT) / maxNum; + + int remain = frontsize % taskDim; + int stepEasy = (frontsize - remain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy); + int indStart = (taskId < remain ? taskId * stepHard : (remain * stepHard + (taskId - remain) * stepEasy)); + for(int i = indStart; i < indStart + step; i++){ + int tid = i * behindsize; + //下面开始计算均值 + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if (j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if (remainT){ + __bang_write_zero(src, maxNum); + __memcpy(src, input + tid + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __bang_add(destSum, destSum, src, maxNum); + } + + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + //下面开始计算方差,destSumFinal[0]存储的就是均值 + T mu = destSumFinal[0] / behindsize; + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if(j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_sub_scalar(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, mu, maxNum); + __bang_mul(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, maxNum); + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if (remainT){ + __bang_write_value(src, maxNum, mu);//保证后面减去均值为0 + __memcpy(src, input + tid + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul(src, src, src, maxNum); + __bang_add(destSum, destSum, src, maxNum); + } + + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, 
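+                        /* Pairwise tree reduction: fold the segNum wSize-wide segments of destSum in
+                           half on each pass until only the first segment remains, then a single
+                           __bang_reduce_sum collapses those wSize lanes into the scalar row sum. */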
wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = destSumFinal[0] / behindsize + static_cast(eps); + sigma2 = 1.0 / pow(sigma2, 0.5); + //下面开始做变换 + for(int j = 0; j < repeat + 2; j++){ + if(j < repeat){ + __memcpy_async(src + j % 3 * maxNum, input + tid + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(s_src + j % 3 * maxNum, scale + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(b_src + j % 3 * maxNum, bias + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0 && j < repeat + 1){ + __bang_sub_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, mu, maxNum); + __bang_mul_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, sigma2, maxNum); + __bang_mul(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, s_src + (j - 1) % 3 * maxNum, maxNum); + __bang_add(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, b_src + (j - 1) % 3 * maxNum, maxNum); + } + if(j > 1){ + __memcpy_async(output + tid + (j - 2) * maxNum, src + (j - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if(remainT){ + __memcpy(src, input + tid + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __memcpy(s_src, scale + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __memcpy(b_src, bias + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul_scalar(src, src, sigma2, maxNum); + __bang_mul(src, src, s_src, maxNum); + __bang_add(src, src, b_src, maxNum); + __memcpy(output + tid + repeat * maxNum, src, remainT * sizeof(T), NRAM2GDRAM); + } + } + } + else{ + int multiple = maxNum / behindsize;//一个core一次可以处理multiple个behindsize + int taskSize = taskDim * multiple; + int remainT = frontsize % taskSize; + int repeat = (frontsize - remainT) / taskSize; + int remain = remainT % taskDim; + int stepEasy = (remainT - remain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy); + int indStart = (taskId < remain ? 
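+        /* Case 3 tail split: each NRAM tile holds multiple = maxNum / behindsize whole rows, so no
+           cross-task reduction is needed; the rows left over after the repeat full passes
+           (frontsize % taskSize of them) are spread across tasks -- the first `remain` tasks take
+           stepHard = stepEasy + 1 rows, the rest take stepEasy, and indStart is each task's start. */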
taskId * stepHard : (remain * stepHard + (taskId - remain) * stepEasy)); + int segNum = bSize / wSize; + __memcpy(s_src, scale, behindsize * sizeof(T), GDRAM2NRAM); + __memcpy(b_src, bias, behindsize * sizeof(T), GDRAM2NRAM); + int tid; + for(int i = 0; i < repeat + 2; i++){ + if(i < repeat){ + tid = i * taskSize * behindsize; + __memcpy_async(src + i % 3 * maxNum, input + tid + taskId * multiple * behindsize, multiple * behindsize * sizeof(T), GDRAM2NRAM); + } + if(i > 0 && i < repeat + 1){ + for(int m = 0; m < multiple; m++){ + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __bang_add(destSum, destSum, src + (i - 1) % 3 * maxNum + m *behindsize, behindsize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0] / behindsize = mu + T mu = destSumFinal[0] / behindsize; + __bang_write_zero(destSum, maxNum); + __bang_sub_scalar(destSum, src + (i - 1) % 3 * maxNum + m * behindsize, mu, behindsize); + + __bang_mul(destSum, destSum, destSum, bSize); + __bang_write_zero(destSumFinal, wSize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = 1.0 / (pow(destSumFinal[0] / behindsize + static_cast(eps), 0.5)); + //下面开始做变换 + __bang_sub_scalar(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, mu, behindsize); + __bang_mul_scalar(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, sigma2, behindsize); + __bang_mul(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, s_src, behindsize); + __bang_add(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, b_src, behindsize); + } + } + if(i > 1){ + tid = (i - 2) * taskSize * behindsize; + __memcpy_async(output + tid + taskId * multiple * behindsize, src + (i - 2) % 3 * maxNum, multiple * behindsize * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if(step){ + int tid = (repeat * taskSize + indStart) * behindsize; + __memcpy(src, input + tid, step * behindsize * sizeof(T), GDRAM2NRAM); + for(int m = 0; m < step; m++){ + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __bang_add(destSum, destSum, src + m *behindsize, behindsize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0] / behindsize = mu + T mu = destSumFinal[0] / behindsize; + __bang_write_zero(destSum, maxNum); + __bang_sub_scalar(destSum, src + m * behindsize, mu, behindsize); + + __bang_mul(destSum, destSum, destSum, bSize); + __bang_write_zero(destSumFinal, wSize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = 1.0 / (pow(destSumFinal[0] / behindsize + static_cast(eps), 0.5)); + //下面开始做变换 + __bang_sub_scalar(src + m * behindsize, src + m * 
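+                /* Tail rows: each of the `step` leftover rows is normalized in place -- subtract the
+                   row mean, multiply by sigma2 = 1/sqrt(var + eps), apply the elementwise scale and
+                   bias -- and the whole block is copied back to GDRAM after the loop. */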
behindsize, mu, behindsize); + __bang_mul_scalar(src + m * behindsize, src + m * behindsize, sigma2, behindsize); + __bang_mul(src + m * behindsize, src + m * behindsize, s_src, behindsize); + __bang_add(src + m * behindsize, src + m * behindsize, b_src, behindsize); + + } + __memcpy(output + tid, src, step * behindsize * sizeof(T), NRAM2GDRAM); + } + } +} +template +void layer_normUnion(cnrtQueue_t queue, void const *input, void const *scale, void const *bias, void *output, float eps, int size, int behindsize){ + int wSize = 128 / sizeof(T); + int bSize; + float mi = log2(behindsize); + if (floor(mi) == mi) + { + bSize = behindsize; + } + else + { + bSize = static_cast(pow(2, floor(mi) + 1)); + } + if (bSize < wSize) + { + bSize = wSize; + } + auto source = reinterpret_cast(input); + auto weight = reinterpret_cast(scale); + auto _bias = reinterpret_cast(bias); + auto destination = reinterpret_cast(output); + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + int taskNum = k_dim.x * k_dim.y * k_dim.z; + + k_type = CNRT_FUNC_TYPE_UNION1; + T *tmpGdram; + CNRT_CHECK(cnrtMalloc((void **)&tmpGdram, taskNum * sizeof(T))); + layer_norm<<>>(source, weight, _bias, destination, tmpGdram, eps, size, behindsize, bSize); + cnrtFree(tmpGdram); + cnrtQueueSync(queue); +} +void layer_norm_bang(LayerNormBangDescriptor_t desc, void const *x, void const *w, void const *b, void *y, + void *stream){ + auto queue = reinterpret_cast(stream); + auto eps = desc->epsilon;//float + int size = desc->size; + int behindsize = desc->behindsize; + if (dtype_eq(desc->dtype, F16)){ + layer_normUnion(queue, x, w, b, y, eps, size, behindsize); + } + else if (dtype_eq(desc->dtype, F32)){ + layer_normUnion(queue, x, w, b, y, eps, size, behindsize); + } +} +infiniopStatus_t bangLayerNorm(LayerNormBangDescriptor_t desc, + void const *x, + void const *w, + void const *b, + void *y, + void *stream) { + if (cnrtSetDevice(desc->device_id) != cnrtSuccess) { + return STATUS_BAD_DEVICE; + } + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + layer_norm_bang(desc, x, w, b, y, stream); + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/layer_norm/cpu/layer_norm_cpu.cc b/src/ops/layer_norm/cpu/layer_norm_cpu.cc new file mode 100644 index 00000000..614e5164 --- /dev/null +++ b/src/ops/layer_norm/cpu/layer_norm_cpu.cc @@ -0,0 +1,125 @@ +#include "layer_norm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" +#include + +infiniopStatus_t cpuCreateLayerNormDescriptor(infiniopHandle_t handle, LayerNormCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int size = 1; + int behindsize = 1; + for(int i = 0; i < ndim; i++){ + size *= static_cast(x_desc->shape[i]); + if(i >= ndim - wDim){ + behindsize *= static_cast(x_desc->shape[i]); + } + } + + *desc_ptr = new LayerNormCpuDescriptor{ + handle->device, + 
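+        /* size = total number of elements in x; behindsize = product of the trailing dims covered
+           by w/b, i.e. the length of one row that gets its own mean and variance. */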
x_desc->dt, + size, + behindsize, + epsilon}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyLayerNormDescriptor(LayerNormCpuDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} + +void layer_norm_cpu(LayerNormCpuDescriptor_t desc, void const *x, void const *w, void const *b, void *y) { + int size = desc->size; + int behindsize = desc->behindsize; + int frontsize = size / behindsize; + float eps = desc->epsilon; + if (dtype_eq(desc->dtype, F32)) + { + auto source = reinterpret_cast(x); + auto weight = reinterpret_cast(w); + auto _bias = reinterpret_cast(b); + auto destination = reinterpret_cast(y); + for (int i = 0; i < frontsize; i++) + { + int tid = i * behindsize; + float mu = 0.0f; + for (int id = 0; id < behindsize; id++) + { + mu += source[tid + id]; + } + mu /= behindsize; + float sigma2Partial = 0.0f; + for (int id = 0; id < behindsize; id++) + { + sigma2Partial += (source[tid + id] - mu) * (source[tid + id] - mu); + } + float sigma2 = 1.0f / sqrt(sigma2Partial / behindsize + eps); + for (int id = 0; id < behindsize; id++) + { + destination[tid + id] = (source[tid + id] - mu) * weight[id] * sigma2 + _bias[id]; + } + } + } + else if (dtype_eq(desc->dtype, F16)) + { + auto source = reinterpret_cast(x); + auto weight = reinterpret_cast(w); + auto _bias = reinterpret_cast(b); + auto destination = reinterpret_cast(y); + for (int i = 0; i < frontsize; i++) + { + int tid = i * behindsize; + float mu = 0.0f; + for (int id = 0; id < behindsize; id++) + { + mu += f16_to_f32(source[tid + id]); + } + mu /= behindsize; + float sigma2Partial = 0.0f; + for (int id = 0; id < behindsize; id++) + { + sigma2Partial += (f16_to_f32(source[tid + id]) - mu) * (f16_to_f32(source[tid + id]) - mu); + } + float sigma2 = 1.0f / sqrt(sigma2Partial / behindsize + eps); + for (int id = 0; id < behindsize; id++) + { + float tmp = (f16_to_f32(source[tid + id]) - mu) * f16_to_f32(weight[id]) * sigma2 + f16_to_f32(_bias[id]); + destination[tid + id] = f32_to_f16(tmp); + } + } + } +} + +infiniopStatus_t cpuLayerNorm(LayerNormCpuDescriptor_t desc, + void const *x, void const *w, void const *b, void *y, + void *stream) { + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + layer_norm_cpu(desc, x, w, b, y); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/layer_norm/cpu/layer_norm_cpu.h b/src/ops/layer_norm/cpu/layer_norm_cpu.h new file mode 100644 index 00000000..dd034f56 --- /dev/null +++ b/src/ops/layer_norm/cpu/layer_norm_cpu.h @@ -0,0 +1,27 @@ +#ifndef __CPU_LAYER_NORM_H__ +#define __CPU_LAYER_NORM_H__ + +#include "operators.h" + +struct LayerNormCpuDescriptor { + Device device; + DT dtype; + int size; + int behindsize; + float epsilon; +}; + +typedef struct LayerNormCpuDescriptor *LayerNormCpuDescriptor_t; + +infiniopStatus_t cpuCreateLayerNormDescriptor(infiniopHandle_t handle, LayerNormCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); +infiniopStatus_t cpuLayerNorm(LayerNormCpuDescriptor_t desc, + void const *x, void const *w, void const *b, void *y, + void *stream); +infiniopStatus_t cpuDestroyLayerNormDescriptor(LayerNormCpuDescriptor_t desc); + +#endif// __CPU_LAYER_NORM_H__ diff --git a/src/ops/layer_norm/cuda/layer_norm.cc b/src/ops/layer_norm/cuda/layer_norm.cc new file mode 100644 index 00000000..134f8fcb --- /dev/null +++ b/src/ops/layer_norm/cuda/layer_norm.cc @@ -0,0 +1,53 @@ 
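+// CUDA backend: descriptor creation/destruction only; the shape/dtype checks mirror the CPU and
+// BANG paths (w and b must share a shape equal to the trailing dims of x, dtype F16 or F32).
+// The kernels and the cudaLayerNorm dispatch are expected to live in layer_norm.cu.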
+#include "layer_norm.cuh" +#include "../../utils.h" +#include "../../../devices/cuda/common_cuda.h" + +infiniopStatus_t cudaCreateLayerNormDescriptor(CudaHandle_t handle, + LayerNormCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int size = 1; + int behindsize = 1; + for(int i = 0; i < ndim; i++){ + size *= static_cast(x_desc->shape[i]); + if(i >= ndim - wDim){ + behindsize *= static_cast(x_desc->shape[i]); + } + } + *desc_ptr = new LayerNormCudaDescriptor{ + handle->device, + handle->device_id, + x_desc->dt, + size, + behindsize, + epsilon}; + + return STATUS_SUCCESS; +} + + +infiniopStatus_t cudaDestroyLayerNormDescriptor(LayerNormCudaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/layer_norm/cuda/layer_norm.cu b/src/ops/layer_norm/cuda/layer_norm.cu new file mode 100644 index 00000000..11e21338 --- /dev/null +++ b/src/ops/layer_norm/cuda/layer_norm.cu @@ -0,0 +1,178 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "layer_norm.cuh" +#include + +template +__launch_bounds__(BLOCK_DIM) + __global__ void blockLayernormKernel(T const *input, T const *scale, T const *bias, T *output, float eps, int behindsize) +{ + // 假设input= [A, B, C, D], axis = 2, frontsize = AB = blockDim.x, behindsize = CD + // 全局索引index = i(BCD) + j (CD) + k(D) + s + // blockIdx.x = i(B) + j;默认behindsize >= BLOCK_DIM + // scale,bias长度为behindsize,形状为[C,D] + int tid = blockIdx.x * behindsize; + float muPartial = 0.0f; + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM) + { + muPartial += static_cast(input[tid + id]); // half很多操作不支持,运算过程使用float数据 + } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float mu; + float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); + if (threadIdx.x == 0) + { + mu = muBlock * __fdividef(1.0F, behindsize); + } // threadIdx.x = 0对应的是全局sum + __syncthreads(); + float sigma2Partial = 0.0f; + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM) + { + sigma2Partial += (static_cast(input[tid + id]) - mu) * (static_cast(input[tid + id]) - mu); + } + __shared__ float sigma2; + float sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); + if (threadIdx.x == 0) + { + float sigmaTmp = sqrt(sigma2Block * __fdividef(1.0F, behindsize) + eps); + sigma2 = __fdividef(1.0F, sigmaTmp); + } + __syncthreads(); + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM) + { + output[tid + id] = static_cast(static_cast(scale[id]) * (static_cast(input[tid + id]) - mu) * sigma2 + static_cast(bias[id])); + } +} +template +struct SumOp +{ + __device__ __forceinline__ T operator()(const T &a, const T &b) const + { + return a + b; + } +}; + +template
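+/* Note on the CUDA kernel above: blockLayernormKernel assigns one thread block per normalized row
+   (blockIdx.x indexes the frontsize rows), accumulates mean and variance in float even for half
+   inputs, and uses cub::BlockReduce for the row-wide sums; SumOp is a plain addition functor,
+   presumably for an alternative reduction path defined after this point. */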