diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 8f65904b8fe51..bc33b99d96ea6 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1257,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); + if(acl_dst == nullptr) { + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src); + } else { + GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); + } } void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); + if(acl_dst == nullptr) { + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src); + } else { + GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); + } } void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, @@ -2221,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, ggml_cann_release_resources(ctx, acl_index, acl_value); } +/** + * @brief Initializes and caches sine/cosine positional encoding values + * (used in RoPE, Rotary Position Embedding) for attention layers. + * + * This function computes and caches the sin/cos values of + * θ = position * theta_scale for RoPE encoding. The cache is shared + * across attention layers, and only the first attention layer will + * trigger initialization. The cache includes repeated sin/cos values + * with different repeat methods depending on the @param is_neox flag. + * + * Steps performed by this function: + * 1. Identify whether the target tensor belongs to Q/K in attention + * and restrict computation to the first layer only. + * 2. Initialize the theta scale array (arange → power → freq scaling). + * 3. Allocate sin/cos caches if the max prompt length increases. + * 4. Compute θ = position * theta_scale. + * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. + * 6. Expand sin/cos values by repeat or repeat_interleave depending + * on whether @param is_neox is enabled. + * 7. Store the computed values into persistent buffers + * (ctx.rope_sin_ptr / ctx.rope_cos_ptr). + * + * @param ctx The CANN backend context, holding memory pool, + * stream, and persistent buffers for rope init/cache. + * @param dst The destination ggml_tensor whose computation + * depends on the cached RoPE values (usually Qcur/Kcur). + * @param theta_scale Scalar exponent base for computing theta scale values. + * @param freq_scale Frequency scaling factor, applied to theta scale. + * @param attn_factor Attention scaling factor, applied to sin/cos. + * @param is_neox Whether to use Neox-style repeat strategy + * (dim expansion vs repeat_interleave). + */ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, - aclTensor* acl_cos_repeat_tensor, - aclTensor* acl_sin_repeat_tensor, float theta_scale, float freq_scale, float attn_factor, bool is_neox) { // int sin/cos cache, cache has different repeat method depond on // @param.is_neox + bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0); + bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0); + + // used for accuracy testing + bool is_attention = is_q || is_k; + + // just compute in first layer in attention + bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0); + if(is_attention && !is_fisrt_layer) { + return; + } ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position @@ -2253,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; } - bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0); - bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0); - - // used for accuracy testing - bool is_attention = is_q || is_k; - - if(ctx.init_ptr == nullptr || !is_attention) { + // init theta scale, just one time + if(ctx.rope_init_ptr == nullptr || !is_attention) { // theta_scale arange, [0,1,...,ne00/2 - 1] - if(ctx.init_ptr != nullptr){ - ACL_CHECK(aclrtFree(ctx.init_ptr)); + if(ctx.rope_init_ptr != nullptr){ + ACL_CHECK(aclrtFree(ctx.rope_init_ptr)); } - ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); aclTensor* acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float start = 0; float step = 1; @@ -2297,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale); } - if(ctx.sin_ptr == nullptr) { - int64_t theta_length = theta_scale_length * ctx.max_prompt_length; - ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - } + // init sin_repeat && cos_repeat, one token just init in 0 layer if(position_length > ctx.max_prompt_length) { ctx.max_prompt_length = position_length; - int64_t theta_length = theta_scale_length * ctx.max_prompt_length; - ACL_CHECK(aclrtFree(ctx.sin_ptr)); - ACL_CHECK(aclrtFree(ctx.cos_ptr)); - ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2; + if(ctx.rope_sin_ptr != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_sin_ptr)); + ACL_CHECK(aclrtFree(ctx.rope_cos_ptr)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); } - bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0); - - if(is_fisrt_layer || !is_attention) { - - aclTensor* acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t), + aclTensor* acl_theta_scale_tensor = + ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - // position - aclTensor* acl_position_tensor = ggml_cann_create_tensor( - src1->data, ggml_cann_type_mapping(src1->type), - ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); - - // power * position - int64_t theta_length = theta_scale_length * position_length; - ggml_cann_pool_alloc theta_allocator(ctx.pool(), - theta_length * sizeof(float_t)); - void* theta_buffer = theta_allocator.get(); - - aclTensor* acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t), - theta_ne, theta_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, - acl_theta_tensor); - - // sin/cos - aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); - - aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); - - // release - ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, - acl_theta_tensor, acl_sin_tensor, acl_cos_tensor); - } - + // position + aclTensor* acl_position_tensor = ggml_cann_create_tensor( + src1->data, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); + + // power * position + int64_t theta_length = theta_scale_length * position_length; + ggml_cann_pool_alloc theta_allocator(ctx.pool(), + theta_length * sizeof(float_t)); + void* theta_buffer = theta_allocator.get(); + + aclTensor* acl_theta_tensor = + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t), + theta_ne, theta_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, + acl_theta_tensor); + + // sin/cos + ggml_cann_pool_alloc sin_allocator(ctx.pool(), + theta_length * sizeof(float_t)); + void* sin_buffer = sin_allocator.get(); aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); + sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, + GGML_MAX_DIMS, ACL_FORMAT_ND); + aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); + + ggml_cann_pool_alloc cos_allocator(ctx.pool(), + theta_length * sizeof(float_t)); + void* cos_buffer = cos_allocator.get(); aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); + cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, + GGML_MAX_DIMS, ACL_FORMAT_ND); + aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); // attn_factor if (attn_factor != 1) { @@ -2365,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true); } + int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; + size_t sin_reshape_nb[GGML_MAX_DIMS]; + sin_reshape_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; + } + aclTensor* acl_sin_repeat_tensor = + ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor* acl_cos_repeat_tensor = + ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + // repeat if (is_neox) { int64_t repeatsArray[] = {1, 1, 1, 2}; @@ -2380,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, num_repeats, output_size); } - // release - ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor); + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, + acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor, + acl_cos_repeat_tensor); } #ifdef __cplusplus @@ -2435,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - // init cos/sin cache - ggml_cann_pool_alloc sin_allocator( - ctx.pool(), ne00 * ne02 * sizeof(float_t)); - ggml_cann_pool_alloc cos_allocator( - ctx.pool(), ne00 * ne02 * sizeof(float_t)); - void* sin_buffer = sin_allocator.get(); - void* cos_buffer = cos_allocator.get(); + // init ctx.rope_cos/rope_sin cache + aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox); int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; @@ -2450,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor, - theta_scale, freq_scale, attn_factor, is_neox); aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 5858bd3f6a197..33794062f565d 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -368,10 +368,6 @@ struct ggml_backend_cann_context { std::string name; /**< Name of the device. */ std::string description; /**< Description of the device. */ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ - void* init_ptr = nullptr; - void* sin_ptr = nullptr; - void* cos_ptr = nullptr; - int64_t max_prompt_length = 65536; #ifdef USE_ACL_GRAPH /// Cached CANN ACL graph used for executing the current ggml computation graph. std::unique_ptr cann_graph; @@ -379,6 +375,12 @@ struct ggml_backend_cann_context { cann_task_queue task_queue; bool async_mode; bool support_set_rows; + // Rope Cache + void* rope_init_ptr = nullptr; + void* rope_sin_ptr = nullptr; + void* rope_cos_ptr = nullptr; + int64_t max_prompt_length = 0; + // Constant Pool void* f32_zero_cache = nullptr; void* f32_one_cache = nullptr; int64_t f32_zero_cache_element = 0; @@ -422,14 +424,20 @@ struct ggml_backend_cann_context { ACL_CHECK(aclrtDestroyStream(streams[i])); } } - if(init_ptr != nullptr) { - ACL_CHECK(aclrtFree(init_ptr)); + if(rope_init_ptr != nullptr) { + ACL_CHECK(aclrtFree(rope_init_ptr)); } - if(sin_ptr != nullptr) { - ACL_CHECK(aclrtFree(sin_ptr)); + if(rope_sin_ptr != nullptr) { + ACL_CHECK(aclrtFree(rope_sin_ptr)); } - if(cos_ptr != nullptr) { - ACL_CHECK(aclrtFree(cos_ptr)); + if(rope_cos_ptr != nullptr) { + ACL_CHECK(aclrtFree(rope_cos_ptr)); + } + if(f32_zero_cache != nullptr) { + ACL_CHECK(aclrtFree(f32_zero_cache)); + } + if(f32_one_cache != nullptr) { + ACL_CHECK(aclrtFree(f32_one_cache)); } }