Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions hwy/contrib/dot/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,17 @@ struct Dot {
(kAssumptions & kMultipleOfVector) != 0;
constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

// Won't be able to do a full vector load without padding => scalar loop.
// Won't be able to do a full vector load without padding. Use a scalar
// loop under Clang. GCC has very suboptimal codegen for scalar BF16->float
// conversions, so use vector ops with LoadN instead.
// TODO: https://github.com/google/highway/pull/2703
if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
HWY_UNLIKELY(num_elements < NF)) {
#if HWY_COMPILER_GCC_ACTUAL
const VF a = LoadN(df, pa, num_elements);
const VF b = PromoteTo(df, LoadN(dbfh, pb, num_elements));
return ReduceSum(df, Mul(a, b));
#else
// Only 2x unroll to avoid excessive code size.
float sum0 = 0.0f;
float sum1 = 0.0f;
Expand All @@ -189,6 +197,7 @@ struct Dot {
sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
}
return sum0 + sum1;
#endif
}

// Compiler doesn't make independent sum* accumulators, so unroll manually.
Expand Down Expand Up @@ -279,9 +288,19 @@ struct Dot {
(kAssumptions & kMultipleOfVector) != 0;
constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

// Won't be able to do a full vector load without padding => scalar loop.
// Won't be able to do a full vector load without padding. Use a scalar
// loop under Clang. GCC has very suboptimal codegen for scalar BF16->float
// conversions, so use vector ops with LoadN instead.
// TODO: https://github.com/google/highway/pull/2703
if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
HWY_UNLIKELY(num_elements < N)) {
#if HWY_COMPILER_GCC_ACTUAL
const auto a = LoadN(d, pa, num_elements);
const auto b = LoadN(d, pb, num_elements);
V sum1 = Zero(df32);
V sum0 = ReorderWidenMulAccumulate(df32, a, b, Zero(df32), sum1);
return ReduceSum(df32, Add(sum0, sum1));
#else
float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for..
float sum1 = 0.0f; // this unlikely(?) case.
for (; i + 2 <= num_elements; i += 2) {
Expand All @@ -292,6 +311,7 @@ struct Dot {
sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
}
return sum0 + sum1;
#endif
}

// See comment in the other Compute() overload. Unroll 2x, but we need
Expand Down
Loading