Merged
Changes from 54 commits
Commits
55 commits
04c99b1
rename
chraac Jul 4, 2025
28d527e
Refactor vector operations in vec_op_impl and vec_dot_product_impl fo…
chraac Jul 4, 2025
ddf95af
wip
chraac Jul 4, 2025
f0d51d2
Enhance vector copy functions for improved performance and clarity in…
chraac Jul 4, 2025
814a8d4
wip
chraac Jul 4, 2025
41f3f64
wip
chraac Jul 4, 2025
ceb2fe2
wip
chraac Jul 4, 2025
889cb69
Optimize vector dot product implementations for enhanced performance …
chraac Jul 4, 2025
cec3fd8
Enhance flash attention implementation and type traits for improved v…
chraac Jul 4, 2025
311be57
remove align
chraac Jul 5, 2025
3eb8efc
wip
chraac Jul 5, 2025
04f1c2c
Enhance vector dot product implementation for improved performance by…
chraac Jul 5, 2025
661c916
Revert "Enhance vector dot product implementation for improved perfor…
chraac Jul 5, 2025
afb8ea5
Enhance flash attention implementation with type checks for tensor da…
chraac Jul 5, 2025
0e626a8
wip
chraac Jul 5, 2025
854bc23
opt mask calc
chraac Jul 5, 2025
709d752
Revert "opt mask calc"
chraac Jul 5, 2025
fb1614e
wip
chraac Jul 5, 2025
05decd9
opt mul mat caching logic to add dst cache
chraac Jul 6, 2025
9e3f759
Revert "opt mul mat caching logic to add dst cache"
chraac Jul 7, 2025
9643f21
wip
chraac Jul 7, 2025
7430cd3
Refactor matrix multiplication implementation to include vector conve…
chraac Jul 7, 2025
420f1f6
wip
chraac Jul 7, 2025
464ad02
wip
chraac Jul 7, 2025
8b763a9
wip
chraac Jul 8, 2025
25c6b3d
create vec_ops.inl for more aggressive compiler inline
chraac Jul 8, 2025
a86df9e
wip
chraac Jul 8, 2025
ec953fa
refactor vector dot product implementations for improved readability …
chraac Jul 8, 2025
f16492d
refactor vector conversion functions to use HVX_Vector_Dual for impro…
chraac Jul 8, 2025
06627fb
wip
chraac Jul 8, 2025
b3e3e7e
wip
chraac Jul 9, 2025
5090f8e
wip
chraac Jul 9, 2025
84e56c8
implement row size caching logic and enhance type traits for F32 support
chraac Jul 9, 2025
56ad5f8
refactor matrix multiplication functions to improve caching logic and…
chraac Jul 9, 2025
973ce41
add vector zeroing functions for F32 and F16 types to optimize memory…
chraac Jul 9, 2025
549c4fd
Revert "add vector zeroing functions for F32 and F16 types to optimiz…
chraac Jul 9, 2025
0652b72
wip
chraac Jul 10, 2025
40d0632
refactor alignment checks in dot product function to handle null poin…
chraac Jul 10, 2025
009e058
wip
chraac Jul 10, 2025
94eea19
refactor load_block_generic and related functions for improved alignm…
chraac Jul 11, 2025
3f6a487
wip
chraac Jul 11, 2025
dcf1580
refactor flash attention implementation and introduce type-erased dot…
chraac Jul 12, 2025
97a5678
refactor dot product implementations for improved loop handling and c…
chraac Jul 14, 2025
00cdd3f
refactor thread_pool constructor to pre-allocate VTCM cache for each …
chraac Jul 16, 2025
93fbaad
Revert "refactor thread_pool constructor to pre-allocate VTCM cache f…
chraac Jul 16, 2025
e0f795b
wip
chraac Jul 17, 2025
68cf1ca
opt interfaces for tensor cleanup
chraac Jul 17, 2025
b5d316c
refactor mul_mat_impl to use aligned size for src0 row calculation
chraac Jul 18, 2025
ef4550f
refactor: update dequantized_row_size logic and add size alignment ch…
chraac Jul 18, 2025
f7ab2db
wip
chraac Jul 18, 2025
bcb0e79
wip
chraac Jul 18, 2025
9298936
refactor: replace raw pointer initialization with invalid handle cons…
chraac Jul 18, 2025
a7ea145
wip
chraac Jul 18, 2025
74f28f5
Merge branch 'dev-refactoring' into dev-perf-opt-part5
chraac Jul 18, 2025
aa7d330
Merge branch 'dev-refactoring' into dev-perf-opt-part5
chraac Jul 21, 2025
59 changes: 36 additions & 23 deletions ggml/src/ggml-qnn/npu/device/device.cpp
@@ -4,7 +4,6 @@
#include <hexagon_types.h>

#include <memory>
#include <new>

#include "graph.hpp"
#include "hexagon_npu.h"
@@ -69,20 +68,28 @@ struct npu_device_context {
}
};

inline hexagon::tensor * tensor_from_handle(npu_device_graph_handle_t h) {
inline hexagon::tensor * tensor_from_handle(npu_device_tensor_handle_t h) {
if (h == npu_device_INVALID_DEVICE_TENSOR_HANDLE) {
return nullptr;
}

return reinterpret_cast<hexagon::tensor *>(h);
}

inline npu_device_graph_handle_t tensor_to_handle(hexagon::tensor * tensor) {
return reinterpret_cast<npu_device_graph_handle_t>(tensor);
inline npu_device_tensor_handle_t tensor_to_handle(hexagon::tensor * tensor) {
return reinterpret_cast<npu_device_tensor_handle_t>(tensor);
}

inline hexagon::graph * graph_from_handle(npu_device_tensor_handle_t h) {
inline hexagon::graph * graph_from_handle(npu_device_graph_handle_t h) {
if (h == npu_device_INVALID_DEVICE_GRAPH_HANDLE) {
return nullptr;
}

return reinterpret_cast<hexagon::graph *>(h);
}

inline npu_device_tensor_handle_t graph_to_handle(hexagon::graph * graph) {
return reinterpret_cast<npu_device_tensor_handle_t>(graph);
inline npu_device_graph_handle_t graph_to_handle(hexagon::graph * graph) {
return reinterpret_cast<npu_device_graph_handle_t>(graph);
}

inline npu_device_context * device_context_from_handle(remote_handle64 h) {
@@ -93,12 +100,7 @@ inline npu_device_context * device_context_from_handle(remote_handle64 h) {

int npu_device_open(const char * uri, remote_handle64 * h) {
// TODO: should we have a device context here?
auto * context = new (std::nothrow) npu_device_context();
if (!context) {
DEVICE_LOG_ERROR("Failed to allocate memory for the npu_device_context");
return AEE_ENOMEMORY;
}

auto * context = new npu_device_context();
if (!context->init()) {
DEVICE_LOG_ERROR("Failed to initialize npu_device_context");
delete context;
@@ -144,12 +146,7 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op
AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info,
npu_device_tensor_handle_t * tensor_handle) {
NPU_UNUSED(_h);
auto * tensor = new (std::nothrow) hexagon::tensor(*info);
if (!tensor) {
DEVICE_LOG_ERROR("Failed to allocate memory for the tensor");
return AEE_ENOMEMORY;
}

auto * tensor = new hexagon::tensor(*info);
*tensor_handle = tensor_to_handle(tensor);
return AEE_SUCCESS;
}
@@ -177,13 +174,29 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t
return AEE_SUCCESS;
}

AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) {
AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles,
int tensor_handlesLen) {
NPU_UNUSED(_h);
auto * graph = new (std::nothrow) hexagon::graph();
if (!graph) {
return AEE_ENOMEMORY;
if (!tensor_handles || tensor_handlesLen < 0) {
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
return AEE_EINVARGS;
}

for (int i = 0; i < tensor_handlesLen; ++i) {
auto * tensor = tensor_from_handle(tensor_handles[i]);
if (tensor) {
delete tensor;
} else {
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid tensor handle at index %d", i);
}
}

return AEE_SUCCESS;
}

AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) {
NPU_UNUSED(_h);
auto * graph = new hexagon::graph();
*graph_handle = graph_to_handle(graph);
return AEE_SUCCESS;
}
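Note on the new batched free entry point: the sketch below is a hypothetical host-side caller, not part of the patch. `device_handle`, `src_handle` and `dst_handle` are placeholder names, and the stub is assumed to mirror the device-side signature shown above.

    // Hypothetical usage sketch (assumes the generated stub matches the device-side signature).
    // Per the implementation above, AEE_EINVARGS is returned only for a null handle array or a
    // negative length; an invalid handle at a given index is logged on the device and skipped.
    const npu_device_tensor_handle_t handles[] = { src_handle, dst_handle };  // placeholders
    if (npu_device_tensors_free(device_handle, handles, 2) != AEE_SUCCESS) {
        // fall back to releasing the tensors one at a time via npu_device_tensor_free
    }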
76 changes: 46 additions & 30 deletions ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
@@ -13,10 +13,19 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
}

// From: ggml/src/ggml-cpu/ops.cpp
template <bool _IsKvF16>
void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");

constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;

if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
return;
}

float scale = out->get_op_param<float>(0);
const float max_bias = out->get_op_param<float>(1);
const float logit_softcap = out->get_op_param<float>(2);
@@ -37,9 +46,11 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this
const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot;
if (!q_to_vec_dot || !kq_vec_dot) {
const auto & k_type_traits = hexagon::get_type_traits(kKvDataType);
const auto q_to_vec_dot = k_type_traits.from_float;
constexpr const auto kq_vec_dot = _IsKvF16 ? hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16> :
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>;
if (!q_to_vec_dot) {
DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
return;
}
@@ -50,12 +61,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const auto DK = k->get_ne(0);
const auto DV = v->get_ne(0);
const auto row_bytes_q = q->get_ne(0) * hexagon::get_type_traits(q->get_type()).type_size;
const auto row_bytes_k = DK * hexagon::get_type_traits(k->get_type()).type_size;
const auto row_bytes_k = DK * k_type_traits.type_size;
const auto row_bytes_v = DV * hexagon::get_type_traits(v->get_type()).type_size;

constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
const auto aligned_dk = (DK + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
const auto aligned_dv = (DV + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
constexpr const size_t kFloatsPerVectorPair = hexagon::kBytesPerVector * 2 / sizeof(float);
const auto aligned_dk = (DK + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair;
const auto aligned_dv = (DV + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair;
size_t total_cache_size = sizeof(float) * (aligned_dk + 2 * aligned_dv);
auto * cache_ptr = params->get_vtcm_cache(total_cache_size);
if (!cache_ptr) {
@@ -64,11 +75,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
}

// loop over n_batch and n_head
const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
const bool is_v_f16 =
v->get_type() == NPU_DATA_TYPE_F16; // check if V is in FP16 format, otherwise it is in FP32 format
uint8_t * dst_ptr = out->get_write_buffer();
constexpr bool is_v_f16 = _IsKvF16; // check if V is in FP16 format, otherwise it is in FP32 format
const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
hexagon::get_type_name(out->get_type()));
@@ -80,6 +90,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const uint8_t * k_ptr = k->get_read_buffer();
const uint8_t * v_ptr = v->get_read_buffer();
const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
for (auto ir = start_end_row.first; ir < start_end_row.second; ++ir) {
// q indices
const auto iq3 = ir / rows_per_batch;
@@ -90,15 +104,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const float slope =
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;

float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
hexagon::l2fetch_row(q_data, row_bytes_q);

if (is_v_f16) {
if constexpr (is_v_f16) {
memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t));
} else {
memset(VKQ32, 0, DV * sizeof(float));
@@ -117,16 +129,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
if (iq1 < q->get_ne(1) - 1) {
hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q);
}

q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);

// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
const auto * k_plane_ptr = k_ptr + ik2 * k->get_nb(2) + ik3 * k->get_nb(3);
const auto * v_plane_ptr = v_ptr + iv2 * v->get_nb(2) + iv3 * v->get_nb(3);
for (int64_t ic = 0; ic < k->get_ne(1); ++ic) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop);
float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f;
@@ -137,7 +146,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
float s = 0.f;
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 1, kq_dot);
const auto * k_data = k_ptr + (ic * k->get_nb(1) + ik2 * k->get_nb(2) + ik3 * k->get_nb(3));
const auto * k_data = k_plane_ptr + ic * k->get_nb(1);
if (ic < k->get_ne(1) - 1) {
hexagon::l2fetch_row(k_data + k->get_nb(1), row_bytes_k);
}
@@ -156,12 +165,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)

const auto * v_data = v_ptr + (ic * v->get_nb(1) + iv2 * v->get_nb(2) + iv3 * v->get_nb(3));
const auto * v_data = v_plane_ptr + ic * v->get_nb(1);
if (ic < v->get_ne(1)) {
hexagon::l2fetch_row(v_data, row_bytes_v);
}

if (is_v_f16) {
if constexpr (is_v_f16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
@@ -201,7 +210,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
S = S * ms + vs; // scale and increment sum with partial sum
}

if (is_v_f16) {
if constexpr (is_v_f16) {
// TODO: use a more efficient conversion
for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = f16_to_f32(VKQ16[d]);
@@ -218,7 +227,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const int i3 = iq3;

// permute(0, 2, 1, 3)
memcpy(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1), VKQ32, out->get_nb(1));
hexagon::vec_cpy_f32(
reinterpret_cast<const float *>(VKQ32),
reinterpret_cast<float *>(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1)),
out->get_ne(0));
}

out->release_write_buffer(); // mark the output tensor as modified
@@ -244,7 +256,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
return false;
}

flash_attn_impl(out, q, k, v, mask, params);
if (k->get_type() == NPU_DATA_TYPE_F16) {
flash_attn_impl<true>(out, q, k, v, mask, params);
} else {
flash_attn_impl<false>(out, q, k, v, mask, params);
}
return true;
}

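For readers unfamiliar with the streaming softmax used in the ic-loop above (ref: https://arxiv.org/pdf/2112.05682.pdf), a scalar sketch of the per-score update follows. The names mirror the code (M, S, VKQ), but this is an illustration of the math only, not the HVX/F16 path, and it omits the final 1/S normalization applied before each output row is written.

    #include <cmath>

    // Scalar sketch of one online-softmax step: fold a new score s and its value row v_row
    // into the running maximum M, running sum S and weighted accumulator VKQ[0..DV-1].
    void online_softmax_step(float s, float & M, float & S, float * VKQ, const float * v_row, int DV) {
        const float M_new = s > M ? s : M;
        const float ms = expf(M - M_new);  // rescales the old accumulator (1.0f when M is unchanged)
        const float vs = expf(s - M_new);  // weight of the new value row (1.0f when s is the new max)
        for (int d = 0; d < DV; ++d) {
            VKQ[d] = VKQ[d] * ms + v_row[d] * vs;
        }
        S = S * ms + vs;
        M = M_new;
    }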
69 changes: 8 additions & 61 deletions ggml/src/ggml-qnn/npu/device/op_impl.cpp
@@ -12,64 +12,10 @@

namespace {

template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector), typename _TyData>
inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);

HVX_Vector * iptr0 = ((HVX_Vector *) src0);
HVX_Vector * const iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector);
HVX_Vector * iptr1 = ((HVX_Vector *) src1);
HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
HVX_Vector prev0 = *iptr0++;
HVX_Vector prev1 = *iptr1++;

while (iptr0 < iptr0_end) {
HVX_Vector curr0 = *iptr0++;
HVX_Vector curr1 = *iptr1++;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
*optr++ = _OpIntrinsic(s0, s1);
prev0 = curr0;
prev1 = curr1;
}

const size_t leftover = count % kElementsPerVector;
if ((iptr0_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(iptr0);
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(iptr1);
HVX_Vector curr0 = should_fetch_src0 ? *iptr0 : prev0;
HVX_Vector curr1 = should_fetch_src1 ? *iptr1 : prev1;
iptr0 += should_fetch_src0 ? 1 : 0;
iptr1 += should_fetch_src1 ? 1 : 0;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
*optr++ = _OpIntrinsic(s0, s1);
prev0 = curr0;
prev1 = curr1;
}

const size_t leftover_bytes = leftover * sizeof(_TyData);
if (leftover > 0) {
// handle the leftover elements
HVX_Vector curr0 =
(leftover_bytes + hexagon::unaligned_bytes(iptr0) > hexagon::kBytesPerVector) ? *iptr0 : prev0;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

HVX_Vector curr1 =
(leftover_bytes + hexagon::unaligned_bytes(iptr1) > hexagon::kBytesPerVector) ? *iptr1 : prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

hexagon::q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
}
}

template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector)>
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) {
vec_op_impl<_OpIntrinsic, float>(src0, src1, count, dst);
using namespace hexagon::vec;
vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst);
}

inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) {
@@ -84,10 +30,11 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) {
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b));
}

template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector)>
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count,
npu_device_fp16_t * dst) {
vec_op_impl<_OpIntrinsic, npu_device_fp16_t>(src0, src1, count, dst);
using namespace hexagon::vec;
vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst);
}

inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
@@ -252,10 +199,10 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
prev = curr;
}

const size_t leftover_bytes = leftover * sizeof(float);
if (leftover > 0) {
// handle the leftover elements
HVX_Vector curr =
const size_t leftover_bytes = leftover * sizeof(float);
HVX_Vector curr =
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum,
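The removed vec_op_impl body is superseded by vec_trans_op_impl from the new vec_ops.inl (see the "create vec_ops.inl for more aggressive compiler inline" commit). Its contract is an element-wise binary transform; a scalar reference sketch is below. The template shape is illustrative only — the real version operates on whole HVX vectors and handles the unaligned head and tail with Q6_V_valign_VVR and a partial store, as the deleted code did.

    #include <cstddef>

    // Scalar reference for what vec_trans_op_impl<Op, T>(src0, src1, count, dst) computes.
    template <typename T, T (*OpScalar)(T, T)>
    void vec_binary_ref(const T * src0, const T * src1, size_t count, T * dst) {
        for (size_t i = 0; i < count; ++i) {
            dst[i] = OpScalar(src0[i], src1[i]);  // e.g. an add or multiply, like vadd_f32_f32 / vmul_f32_f32
        }
    }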