From dc762edfa02a259e1ab16734f938e50841b8a070 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 20 Feb 2026 12:36:00 +0100 Subject: [PATCH 01/15] Simplifications for emitted methods. --- mechanisms/allen/NaV.mod | 34 +++++++----------- modcc/msparse.hpp | 52 +++++++++++---------------- modcc/symdiff.cpp | 77 +++++++++++++++++++++++++++++++++++++++- modcc/symge.cpp | 8 ++--- modcc/symge.hpp | 18 +++------- 5 files changed, 116 insertions(+), 73 deletions(-) diff --git a/mechanisms/allen/NaV.mod b/mechanisms/allen/NaV.mod index e40cbee8b1..7c6df9c847 100644 --- a/mechanisms/allen/NaV.mod +++ b/mechanisms/allen/NaV.mod @@ -55,7 +55,9 @@ STATE { BREAKPOINT { SOLVE activation METHOD sparse - ina = gbar*O*(v - ena) + LOCAL g + g = gbar * O + ina = g*(v - ena) } INITIAL { @@ -64,20 +66,14 @@ INITIAL { } KINETIC activation { - LOCAL f01, f02, f03, f04, f0O, f11, f12, f13, f14, fi1, fi2, fi3, fi4, fi5, fin, b01, b02, b03, b04, b0O, b11, b12, b13, b14, bi1, bi2, bi3, bi4, bi5, bin, ibtf + LOCAL f04, f0O, f14, fi1, fi2, fi3, fi4, fi5, fin, b01, b0O, b11, bi1, bi2, bi3, bi4, bi5, bin, ibtf ibtf = 1/btfac f04 = qt*alpha*exp(v/x1) - f03 = 2*f04 - f02 = 3*f04 - f01 = 4*f04 f0O = qt*gamma f14 = alfac*f04 - f13 = 2*f14 - f12 = 3*f14 - f11 = 4*f14 fi1 = qt*Con fi2 = fi1*alfac @@ -87,15 +83,9 @@ KINETIC activation { fin = qt*Oon b01 = qt*beta*exp(v/x2) - b02 = 2*b01 - b03 = 3*b01 - b04 = 4*b01 b0O = qt*delta b11 = b01*ibtf - b12 = 2*b11 - b13 = 3*b11 - b14 = 4*b11 bi1 = qt*Coff bi2 = bi1*ibtf @@ -104,16 +94,16 @@ KINETIC activation { bi5 = bi4*ibtf bin = qt*Ooff - ~ C1 <-> C2 (f01, b01) - ~ C2 <-> C3 (f02, b02) - ~ C3 <-> C4 (f03, b03) - ~ C4 <-> C5 (f04, b04) + ~ C1 <-> C2 (4*f04, 1*b01) + ~ C2 <-> C3 (3*f04, 2*b01) + ~ C3 <-> C4 (2*f04, 3*b01) + ~ C4 <-> C5 (1*f04, 4*b01) ~ C5 <-> O (f0O, b0O) ~ O <-> I6 (fin, bin) - ~ I1 <-> I2 (f11, b11) - ~ I2 <-> I3 (f12, b12) - ~ I3 <-> I4 (f13, b13) - ~ I4 <-> I5 (f14, b14) + ~ I1 <-> I2 (4*f14, 1*b11) + ~ I2 <-> I3 (3*f14, 2*b11) + ~ I3 <-> I4 (2*f14, 3*b11) + ~ I4 <-> I5 (1*f14, 4*b11) ~ I5 <-> I6 (f0O, b0O) ~ C1 <-> I1 (fi1, bi1) ~ C2 <-> I2 (fi2, bi2) diff --git a/modcc/msparse.hpp b/modcc/msparse.hpp index 34e6866826..6f3babad45 100644 --- a/modcc/msparse.hpp +++ b/modcc/msparse.hpp @@ -1,9 +1,7 @@ #pragma once // (Possibly augmented) matrix implementation, represented as a vector of sparse rows. - #include -#include #include #include #include @@ -47,14 +45,12 @@ class row { row(const row&) = default; row(std::initializer_list il): data_(il) { - if (!check_invariant()) - throw msparse_error("improper row element list"); + if (!check_invariant()) throw msparse_error("improper row element list"); } template row(InIter b, InIter e): data_(b, e) { - if (!check_invariant()) - throw msparse_error("improper row element list"); + if (!check_invariant()) throw msparse_error("improper row element list"); } unsigned size() const { return data_.size(); } @@ -67,22 +63,18 @@ class row { auto end() const -> decltype(data_.cend()) { return data_.cend(); } // Return column of first (left-most) entry. - unsigned mincol() const { - return empty()? npos: data_.front().col; - } + unsigned mincol() const { return empty()? npos: data_.front().col; } // Return column of first entry with column greater than `c`. unsigned mincol_after(unsigned c) const { - auto i = std::upper_bound(data_.begin(), data_.end(), c, - [](unsigned a, const entry& b) { return acol; } // Return column of last (right-most) entry. - unsigned maxcol() const { - return empty()? npos: data_.back().col; - } + unsigned maxcol() const { return empty()? npos: data_.back().col; } // As opposed to [] indexing (see below), retrieve `i'th entry from // the list of entries. @@ -91,27 +83,25 @@ class row { } void push_back(const entry& e) { - if (!empty() && e.col <= data_.back().col) - throw msparse_error("cannot push_back row elements out of order"); + if (!empty() && e.col <= data_.back().col) throw msparse_error("cannot push_back row elements out of order"); data_.push_back(e); } - void clear() { - data_.clear(); - } + void clear() { data_.clear();} // Return index into entry list which has column `c`. unsigned index(unsigned c) const { - auto i = std::lower_bound(data_.begin(), data_.end(), c, - [](const entry& a, unsigned b) { return a.colcol!=c)? npos: std::distance(data_.begin(), i); } // Remove all entries from column `c` onwards. void truncate(unsigned c) { - auto i = std::lower_bound(data_.begin(), data_.end(), c, - [](const entry& a, unsigned b) { return a.col& r, unsigned c): row_(r), c(c) {} operator X() const { return const_cast&>(row_)[c]; } + assign_proxy& operator=(const X& x) { - auto i = std::lower_bound(row_.data_.begin(), row_.data_.end(), c, - [](const entry& a, unsigned b) { return a.colcol!=c) { row_.data_.insert(i, {c, x}); } @@ -143,14 +134,11 @@ class row { else { i->value = x; } - return *this; } }; - assign_proxy operator[](unsigned c) { - return assign_proxy{*this, c}; - } + assign_proxy operator[](unsigned c) { return assign_proxy{*this, c}; } }; // `msparse::matrix` represents a matrix by a size (number of rows, diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index ea7a7e7b4d..1201e49bf7 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -481,6 +481,14 @@ class ConstantSimplifyVisitor: public Visitor { expression_ptr arg = result(); result_ = e->clone(); result_->is_unary()->replace_expression(std::move(arg)); + + // fold nested + if (auto res = result_->is_unary(); res && res->op() == tok::minus) { + if (auto arg = res->expression()->is_unary(); arg && arg->op() == tok::minus) { + result_ = arg->expression()->clone(); + } + } + } void visit(BinaryExpression* e) override { @@ -512,6 +520,34 @@ class ConstantSimplifyVisitor: public Visitor { else if (expr_value(rhs)==1) { result_ = std::move(lhs); } + // -1 * a = -a + else if (expr_value(lhs) == -1.0) { + result_ = make_expression(loc, std::move(rhs)); + + } + else if (expr_value(lhs) == -1.0) { + result_ = make_expression(loc, std::move(lhs)); + } + // -a * -b = a * b + else if (auto l = lhs->is_unary(), r = rhs->is_unary(); l && r && l->op() == tok::minus && r->op() == tok::minus) { + result_ = make_expression(loc, + l->expression()->clone(), + r->expression()->clone()); + } + // a * -b = - (a * b) + else if (auto r = rhs->is_unary(); r && r->op() == tok::minus) { + result_ = make_expression(loc, + make_expression(loc, + std::move(lhs), + r->expression()->clone())); + } + // -a * b = - (a * b) + else if (auto l = lhs->is_unary(); l && l->op() == tok::minus) { + result_ = make_expression(loc, + make_expression(loc, + std::move(rhs), + l->expression()->clone())); + } else { result_ = make_expression(loc, std::move(lhs), std::move(rhs)); } @@ -549,7 +585,7 @@ class ConstantSimplifyVisitor: public Visitor { expression_ptr rhs = result(); if (is_number(lhs) && is_number(rhs)) { - as_number(loc, expr_value(lhs)+expr_value(rhs)); + as_number(loc, expr_value(lhs) + expr_value(rhs)); } else if (expr_value(lhs)==0) { result_ = std::move(rhs); @@ -557,6 +593,26 @@ class ConstantSimplifyVisitor: public Visitor { else if (expr_value(rhs)==0) { result_ = std::move(lhs); } + // Peephole optimisations + // -a + -b = -(a + b) + else if (auto l = lhs->is_unary(), r = rhs->is_unary(); l && r && l->op() == tok::minus && r->op() == tok::minus) { + result_ = make_expression(loc, + make_expression(loc, + l->expression()->clone(), + r->expression()->clone())); + } + // a + -b = a - b + else if (auto r = rhs->is_unary(); r && r->op() == tok::minus) { + result_ = make_expression(loc, + std::move(lhs), + r->expression()->clone()); + } + // -a + b = b - a + else if (auto l = lhs->is_unary(); l && l->op() == tok::minus) { + result_ = make_expression(loc, + std::move(rhs), + l->expression()->clone()); + } else { result_ = make_expression(loc, std::move(lhs), std::move(rhs)); } @@ -578,6 +634,25 @@ class ConstantSimplifyVisitor: public Visitor { else if (expr_value(rhs)==0) { result_ = std::move(lhs); } + // -a - -b = b - a + else if (auto l = lhs->is_unary(), r = rhs->is_unary(); l && r && l->op() == tok::minus && r->op() == tok::minus) { + result_ = make_expression(loc, + r->expression()->clone(), + l->expression()->clone()); + } + // a - -b = a + b + else if (auto r = rhs->is_unary(); r && r->op() == tok::minus) { + result_ = make_expression(loc, + std::move(lhs), + r->expression()->clone()); + } + // -a - b = - (a + b) + else if (auto l = lhs->is_unary(); l && l->op() == tok::minus) { + result_ = make_expression(loc, + make_expression(loc, + l->expression()->clone(), + std::move(rhs))); + } else { result_ = make_expression(loc, std::move(lhs), std::move(rhs)); } diff --git a/modcc/symge.cpp b/modcc/symge.cpp index ccc2afade9..f14c9144a1 100644 --- a/modcc/symge.cpp +++ b/modcc/symge.cpp @@ -43,8 +43,8 @@ sym_row row_reduce(unsigned c, const sym_row& p, const sym_row& q, DefineSym def ++qiter; qj = qiter==q.end()? q.npos: qiter->col; } - if (j!=c) { - u.push_back({j, define_sym(t1-t2)}); + if (j != c) { + u.push_back({j, define_sym(t1 - t2)}); } } return u; @@ -117,12 +117,12 @@ ARB_LIBMODCC_API std::vector> gj_reduce(sym_matrix& A while (true) { auto pivots = get_pivots(remaining_rows); - for (unsigned i = 0; icost[r2.row]; }); + [&](pivot r1, pivot r2) { return cost[r1.row] > cost[r2.row]; }); pivot p = pivots.back(); remaining_rows.erase(std::lower_bound(remaining_rows.begin(), remaining_rows.end(), p.row)); diff --git a/modcc/symge.hpp b/modcc/symge.hpp index a1419b7e4f..e52ed544c0 100644 --- a/modcc/symge.hpp +++ b/modcc/symge.hpp @@ -41,7 +41,6 @@ class symbol { }; // A `symbol_term` is either zero or a product of symbols. - struct symbol_term { symbol left, right; @@ -62,17 +61,12 @@ struct symbol_term_diff { inline symbol_term operator*(symbol a, symbol b) { return symbol_term{a, b}; } -inline symbol_term_diff operator-(symbol_term l, symbol_term r) { - return symbol_term_diff{l, r}; -} +inline symbol_term_diff operator-(symbol_term l, symbol_term r) { return symbol_term_diff{l, r}; } -inline symbol_term_diff operator-(symbol_term r) { - return symbol_term_diff{symbol_term{}, r}; -} +inline symbol_term_diff operator-(symbol_term r) { return symbol_term_diff{symbol_term{}, r}; } // Symbols are not re-assignable; they are created as primitive, or // have a definition in terms of a `symbol_term_diff`. - class symbol_table { private: struct table_entry { @@ -98,9 +92,7 @@ class symbol_table { return s; } - symbol define(const symbol_term_diff& def) { - return define("", def); - } + symbol define(const symbol_term_diff& def) { return define("", def); } symbol_term_diff get(symbol s) const { if (!defined(s)) throw symbol_error("symbol is primitive"); @@ -134,9 +126,7 @@ class symbol_table { void clear() { entries_.clear(); } }; -inline std::string name(symbol s) { - return s? s.table()->name(s): ""; -} +inline std::string name(symbol s) { return s? s.table()->name(s): ""; } inline symbol_term_diff definition(symbol s) { if (!s) throw symbol_error("invalid symbol"); From 80286da50babaf101c34be7de63d96678db34a49 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Mon, 23 Feb 2026 15:13:15 +0100 Subject: [PATCH 02/15] polish formatting a bit --- modcc/printer/cexpr_emit.cpp | 44 +++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index df081195a3..e3f080d91e 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -49,14 +49,14 @@ void CExprEmitter::emit_as_call(const char* sub, Expression* e1, Expression* e2) } void CExprEmitter::visit(NumberExpression* e) { - out_ << " " << as_c_double(e->value()); + out_ << as_c_double(e->value()); } void CExprEmitter::visit(UnaryExpression* e) { // Place a space in front of minus sign to avoid invalid // expressions of the form: (v[i]--67) static std::unordered_map unaryop_tbl = { - {tok::minus, " -"}, + {tok::minus, "-"}, {tok::exp, "exp"}, {tok::cos, "cos"}, {tok::sin, "sin"}, @@ -84,9 +84,17 @@ void CExprEmitter::visit(UnaryExpression* e) { // No need to use parenthesis for unary minus if inner expression is // not binary. - if (e->op()==tok::minus && !inner->is_binary()) { - out_ << op_spelling; - inner->accept(this); + if (e->op()==tok::minus) { + if (auto bin = inner->is_binary(); bin) { + out_ << op_spelling; + bool need_paren = Lexer::binop_precedence(bin->op()) < Lexer::binop_precedence(tok::times); + if (need_paren) out_ << '('; + inner->accept(this); + if (need_paren) out_ << ')'; + } else { + out_ << op_spelling; + inner->accept(this); + } } else if (e->op()==tok::step_right) { out_ << "((arb_value_type)(("; @@ -134,18 +142,18 @@ void CExprEmitter::visit(AssignmentExpression* e) { void CExprEmitter::visit(BinaryExpression* e) { static std::unordered_map binop_tbl = { - {tok::minus, "-"}, - {tok::plus, "+"}, + {tok::minus, " - "}, + {tok::plus, " + "}, {tok::times, "*"}, {tok::divide, "/"}, - {tok::lt, "<"}, - {tok::lte, "<="}, - {tok::gt, ">"}, - {tok::gte, ">="}, - {tok::equality, "=="}, - {tok::land, "&&"}, - {tok::lor, "||"}, - {tok::ne, "!="}, + {tok::lt, " < "}, + {tok::lte, " <= "}, + {tok::gt, " > "}, + {tok::gte, " >= "}, + {tok::equality, " == "}, + {tok::land, " && "}, + {tok::lor, " || "}, + {tok::ne, " != "}, {tok::min, "min"}, {tok::max, "max"}, {tok::pow, "pow"}, @@ -167,7 +175,7 @@ void CExprEmitter::visit(BinaryExpression* e) { auto need_paren = [op_prec](Expression* subexpr, bool assoc_side) -> bool { if (auto b = subexpr->is_binary()) { int sub_prec = Lexer::binop_precedence(b->op()); - return sub_prec SimdExprEmitter::mask_names_; void SimdExprEmitter::visit(NumberExpression* e) { out_ << " (double)" << as_c_double(e->value()); -} +} void SimdExprEmitter::visit(UnaryExpression* e) { static std::unordered_map unaryop_tbl = { @@ -249,7 +257,7 @@ void SimdExprEmitter::visit(UnaryExpression* e) { Expression* inner = e->expression(); auto iden = inner->is_identifier(); - bool is_scalar = iden && scalars_.count(iden->name()); + bool is_scalar = iden && scalars_.count(iden->name()); if (e->op()==tok::minus && is_scalar) { out_ << "simd_cast(-"; inner->accept(this); From 2826abec2a203d10d46a055ba81e4924a32377cb Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:04:01 +0100 Subject: [PATCH 03/15] formatting --- arbor/include/arbor/simd/neon.hpp | 47 +++++++++++++------------------ 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/arbor/include/arbor/simd/neon.hpp b/arbor/include/arbor/simd/neon.hpp index 6743452f9c..fe5bbebeef 100644 --- a/arbor/include/arbor/simd/neon.hpp +++ b/arbor/include/arbor/simd/neon.hpp @@ -226,29 +226,25 @@ struct neon_double2 : implbase { static float64x2_t neg(const float64x2_t& a) { return vnegq_f64(a); } - static float64x2_t add(const float64x2_t& a, const float64x2_t& b) { - return vaddq_f64(a, b); - } + static float64x2_t add(const float64x2_t& a, const float64x2_t& b) { return vaddq_f64(a, b); } - static float64x2_t sub(const float64x2_t& a, const float64x2_t& b) { - return vsubq_f64(a, b); - } + static float64x2_t sub(const float64x2_t& a, const float64x2_t& b) { return vsubq_f64(a, b); } - static float64x2_t mul(const float64x2_t& a, const float64x2_t& b) { - return vmulq_f64(a, b); - } + static float64x2_t mul(const float64x2_t& a, const float64x2_t& b) { return vmulq_f64(a, b); } - static float64x2_t div(const float64x2_t& a, const float64x2_t& b) { - return vdivq_f64(a, b); - } + static float64x2_t fma(const float64x2_t& a, const float64x2_t& b, const float64x2_t& c) { return vfmaq_f64(c, a, b); } + + static float64x2_t fms(const float64x2_t& a, const float64x2_t& b, const float64x2_t& c) { return vfmsq_f64(c, a, b); } + + static float64x2_t div(const float64x2_t& a, const float64x2_t& b) { return vdivq_f64(a, b); } static float64x2_t logical_not(const float64x2_t& a) { return vreinterpretq_f64_u32(vmvnq_u32(vreinterpretq_u32_f64(a))); } static float64x2_t logical_and(const float64x2_t& a, const float64x2_t& b) { - return vreinterpretq_f64_u64( - vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); + return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a), + vreinterpretq_u64_f64(b))); } static float64x2_t logical_or(const float64x2_t& a, const float64x2_t& b) { @@ -280,7 +276,8 @@ struct neon_double2 : implbase { return vreinterpretq_f64_u64(vcleq_f64(a, b)); } - static float64x2_t ifelse(const float64x2_t& m, const float64x2_t& u, + static float64x2_t ifelse(const float64x2_t& m, + const float64x2_t& u, const float64x2_t& v) { return vbslq_f64(vreinterpretq_u64_f64(m), u, v); } @@ -289,22 +286,18 @@ struct neon_double2 : implbase { return vreinterpretq_f64_u64(vdupq_n_u64(-(int64)b)); } - static bool mask_element(const float64x2_t& u, int i) { - return static_cast(element(u, i)); - } + static bool mask_element(const float64x2_t& u, int i) { return static_cast(element(u, i)); } static float64x2_t mask_unpack(unsigned long long k) { // Only care about bottom two bits of k. - uint8x8_t b = vdup_n_u8((char)k); - uint8x8_t bl = vorr_u8(b, vdup_n_u8(0xfe)); - uint8x8_t bu = vorr_u8(b, vdup_n_u8(0xfd)); - uint8x16_t blu = vcombine_u8(bl, bu); - + uint8x8_t b = vdup_n_u8((char)k); + uint8x8_t bl = vorr_u8(b, vdup_n_u8(0xfe)); + uint8x8_t bu = vorr_u8(b, vdup_n_u8(0xfd)); + uint8x16_t blu = vcombine_u8(bl, bu); uint8x16_t ones = vdupq_n_u8(0xff); - uint64x2_t r = - vceqq_u64(vreinterpretq_u64_u8(ones), vreinterpretq_u64_u8(blu)); - - return vreinterpretq_f64_u64(r); + uint64x2_t res = vceqq_u64(vreinterpretq_u64_u8(ones), + vreinterpretq_u64_u8(blu)); + return vreinterpretq_f64_u64(res); } static void mask_set_element(float64x2_t& u, int i, bool b) { From 18370fe3b556601bf77d0371d0d97ddbdff48af0 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:06:02 +0100 Subject: [PATCH 04/15] Add and handle constraints --- modcc/printer/cexpr_emit.cpp | 73 +++++++++++++++++++++++++++++++++++- modcc/printer/cexpr_emit.hpp | 17 +++++---- modcc/printer/cprinter.hpp | 3 +- 3 files changed, 84 insertions(+), 9 deletions(-) diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index e3f080d91e..20d683f212 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -48,6 +48,16 @@ void CExprEmitter::emit_as_call(const char* sub, Expression* e1, Expression* e2) out_ << ')'; } +void CExprEmitter::emit_as_call(const char* sub, Expression* e1, Expression* e2, Expression* e3) { + out_ << sub << '('; + e1->accept(this); + out_ << ", "; + e2->accept(this); + out_ << ", "; + e3->accept(this); + out_ << ')'; +} + void CExprEmitter::visit(NumberExpression* e) { out_ << as_c_double(e->value()); } @@ -87,7 +97,7 @@ void CExprEmitter::visit(UnaryExpression* e) { if (e->op()==tok::minus) { if (auto bin = inner->is_binary(); bin) { out_ << op_spelling; - bool need_paren = Lexer::binop_precedence(bin->op()) < Lexer::binop_precedence(tok::times); + bool need_paren = true; //Lexer::binop_precedence(bin->op()) < Lexer::binop_precedence(tok::times); if (need_paren) out_ << '('; inner->accept(this); if (need_paren) out_ << ')'; @@ -168,6 +178,23 @@ void CExprEmitter::visit(BinaryExpression* e) { auto lhs = e->lhs(); const char* op_spelling = binop_tbl.at(e->op()); + if (e->op() == tok::plus) { + if (auto l = e->lhs()->is_binary(); l && l->op() == tok::times) { + emit_as_call("fma", l->lhs(), l->rhs(), rhs); + return; + } + if (auto r = e->rhs()->is_binary(); r && r->op() == tok::times) { + emit_as_call("fma", r->lhs(), r->rhs(), lhs); + return; + } + } + else if (e->op() == tok::minus) { + if (auto r = e->rhs()->is_binary(); r && r->op() == tok::times) { + emit_as_call("fms", r->lhs(), r->rhs(), lhs); + return; + } + } + if (e->is_infix()) { associativityKind assoc = Lexer::operator_associativity(e->op()); int op_prec = Lexer::binop_precedence(e->op()); @@ -281,6 +308,32 @@ std::string id_prefix(IdentifierExpression* id) { return id->name(); } +void SimdExprEmitter::emit_fused(const std::string& name, Expression* a, Expression* b, Expression* c) { + auto check_cast = [this](Expression* expr) { + return expr->is_number() || (expr->is_identifier() && scalars_.count(expr->is_identifier()->name())); + }; + + bool need_cast = false; + out_ << name << '('; + need_cast = check_cast(a); + if (need_cast) out_ << "simd_cast("; + a->accept(this); + if (need_cast) out_ << ')'; + out_ << ", "; + + need_cast = check_cast(b); + if (need_cast) out_ << "simd_cast("; + b->accept(this); + if (need_cast) out_ << ')'; + out_ << ", "; + + need_cast = check_cast(c); + if (need_cast) out_ << "simd_cast("; + c->accept(this); + if (need_cast) out_ << ')'; + out_ << ')'; +} + void SimdExprEmitter::visit(BinaryExpression* e) { static std::unordered_map func_tbl = { @@ -333,6 +386,24 @@ void SimdExprEmitter::visit(BinaryExpression* e) { const char *op_spelling = binop_tbl.at(e->op()); const char *func_spelling = func_tbl.at(e->op()); + if (e->op() == tok::plus) { + if (auto l = lhs->is_binary(); l && l->op() == tok::times) { + // emit_fused("S::fma", l->lhs(), l->rhs(), rhs); + // return; + } + if (auto r = rhs->is_binary(); r && r->op() == tok::times) { + // emit_fused("S::fma", r->lhs(), r->rhs(), lhs); + // return; + } + } + else if (e->op() == tok::minus) { + if (auto r = rhs->is_binary(); r && r->op() == tok::times) { + // emit_fused("S::fms", r->lhs(), r->rhs(), lhs); + // return; + } + } + + if (auto id = rhs->is_identifier()) { rhs_name = id->name(); rhs_pfxd = id_prefix(id); diff --git a/modcc/printer/cexpr_emit.hpp b/modcc/printer/cexpr_emit.hpp index 4e5fa873fa..40a0ba463c 100644 --- a/modcc/printer/cexpr_emit.hpp +++ b/modcc/printer/cexpr_emit.hpp @@ -32,6 +32,7 @@ class ARB_LIBMODCC_API CExprEmitter: public Visitor { void emit_as_call(const char* sub, Expression*); void emit_as_call(const char* sub, Expression*, Expression*); + void emit_as_call(const char* sub, Expression*, Expression*, Expression*); }; inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { @@ -41,13 +42,13 @@ inline void cexpr_emit(Expression* e, std::ostream& out, Visitor* fallback) { class ARB_LIBMODCC_API SimdExprEmitter: public CExprEmitter { public: - SimdExprEmitter( - std::ostream& out, - bool is_indirect, - std::string input_mask, - const std::unordered_set& scalars, - Visitor* fallback): - CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask), scalars_(scalars), fallback_(fallback) {} + SimdExprEmitter(std::ostream& out, + bool is_indirect, + std::string input_mask, + const std::unordered_set& scalars, + Visitor* fallback): + CExprEmitter(out, fallback), is_indirect_(is_indirect), input_mask_(input_mask), scalars_(scalars), fallback_(fallback) + {} using CExprEmitter::visit; void visit(BlockExpression *e) override; void visit(CallExpression *e) override; @@ -67,6 +68,8 @@ class ARB_LIBMODCC_API SimdExprEmitter: public CExprEmitter { Visitor* fallback_; private: + void emit_fused(const std::string&, Expression*, Expression*, Expression*); + std::string make_unique_var(scope_ptr scope, std::string prefix) { for (int i = 0;; ++i) { std::string name = prefix + std::to_string(i) + "_"; diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp index 725356ba81..84c48e485f 100644 --- a/modcc/printer/cprinter.hpp +++ b/modcc/printer/cprinter.hpp @@ -43,7 +43,8 @@ class ARB_LIBMODCC_API CPrinter: public Visitor { enum class simd_expr_constraint{ constant, contiguous, - other + independent, + none, }; struct ApiFlags { From c8f760221a3c28f37be7dd8db8767e03198f4234 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:07:45 +0100 Subject: [PATCH 05/15] Clean-up fms/fma --- arbor/include/arbor/simd/implbase.hpp | 12 ++++-------- arbor/include/arbor/simd/simd.hpp | 19 ++++++++++++++++--- test/unit/test_simd.cpp | 2 +- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index fd62ad91bf..e8a97546a3 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -239,15 +239,11 @@ struct implbase { } static vector_type fma(const vector_type& u, const vector_type& v, const vector_type& w) { - store a, b, c, r; - I::copy_to(u, a); - I::copy_to(v, b); - I::copy_to(w, c); + return I::add(w, I::mul(u, v)); + } - for (unsigned i = 0; i logical_not(const detail::simd_mask_impl& a) { } template -detail::simd_impl fma(const detail::simd_impl& a, detail::simd_impl b, detail::simd_impl c) { +detail::simd_impl fma(detail::simd_impl a, detail::simd_impl b, detail::simd_impl c) { return detail::simd_impl::wrap(T::fma(a.value_, b.value_, c.value_)); } +template +detail::simd_impl fms(detail::simd_impl a, detail::simd_impl b, detail::simd_impl c) { + return detail::simd_impl::wrap(T::fms(a.value_, b.value_, c.value_)); +} + namespace detail { /// Indirect Expressions template @@ -570,10 +575,14 @@ namespace detail { return simd_impl::wrap(Impl::div(a.value_, b.value_)); } - friend simd_impl fma(const simd_impl& a, simd_impl b, simd_impl c) { + friend simd_impl fma(simd_impl a, simd_impl b, simd_impl c) { return simd_impl::wrap(Impl::fma(a.value_, b.value_, c.value_)); } + friend simd_impl fms(simd_impl a, simd_impl b, simd_impl c) { + return simd_impl::wrap(Impl::fms(a.value_, b.value_, c.value_)); + } + // Lane-wise relational operations. friend simd_mask operator==(const simd_impl& a, const simd_impl& b) { @@ -692,7 +701,11 @@ namespace detail { #undef ARB_DECLARE_BINARY_COMPARISON_ template - friend simd_impl arb::simd::fma(const simd_impl& a, simd_impl b, simd_impl c); + friend simd_impl arb::simd::fma(simd_impl a, simd_impl b, simd_impl c); + + template + friend simd_impl arb::simd::fms(simd_impl a, simd_impl b, simd_impl c); + // Declare Indirect/Indirect indexed/Where Expression copy function as friends diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 0a6e116c0d..55a99d5188 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -330,7 +330,7 @@ TYPED_TEST_P(simd_value, arithmetic) { #endif (fma(us, vs, ws)).copy_to(r); - EXPECT_TRUE(testing::seq_eq(fma_u_v_w, r)); + EXPECT_TRUE(testing::seq_almost_eq(fma_u_v_w, r)); } } From f39eeb941c652154b8ed0d70b948fafac2e4ba22 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:08:56 +0100 Subject: [PATCH 06/15] handle new cases, more fmt --- modcc/printer/cprinter.cpp | 126 +++++++++++++++++++++++-------------- 1 file changed, 79 insertions(+), 47 deletions(-) diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index b3b87649ee..30933aa2a3 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -730,21 +730,30 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con } else { switch (d.index_var_kind) { - case index_kind::node: { + case index_kind::node: { switch (constraint) { - case simd_expr_constraint::contiguous: + case simd_expr_constraint::contiguous: { out << ";\n" << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) << " + " << node_index_i_name(d) << ", simd_width_));\n"; break; - case simd_expr_constraint::constant: + } + case simd_expr_constraint::constant: { out << " = simd_cast(" << data_via_ppack(d) << "[" << node_index_i_name(d) << "]);\n"; break; - default: + } + case simd_expr_constraint::independent: { out << ";\n" << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) - << ", " << node_index_i_name(d) << ", simd_width_, constraint_category_));\n"; + << ", " << node_index_i_name(d) << ", simd_width_, index_constraint::independent));\n"; + break; + } + case simd_expr_constraint::none: { + out << ";\n" + << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) + << ", " << node_index_i_name(d) << ", simd_width_, index_constraint::none));\n"; + } } break; } @@ -767,6 +776,17 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con EXIT(out); } +std::string constraint_category(simd_expr_constraint cat) { + switch (cat) { + case simd_expr_constraint::constant: return "index_constraint::constant"; + case simd_expr_constraint::contiguous: return "index_constraint::contiguous"; + case simd_expr_constraint::independent: return "index_constraint::independent"; + case simd_expr_constraint::none: return "index_constraint::none"; + } + // this can never be reached + throw std::runtime_error("Impossible"); +} + void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* external, simd_expr_constraint constraint, @@ -801,6 +821,8 @@ void emit_simd_state_update(std::ostream& out, std::string weight = (d.always_use_weight || !flags.is_point) ? "w_" : "simd_cast(1.0)"; + auto index_category = constraint_category(constraint); + if (d.additive && flags.use_additive) { if (d.index_var_kind == index_kind::node) { if (constraint == simd_expr_constraint::contiguous) { @@ -811,14 +833,14 @@ void emit_simd_state_update(std::ostream& out, // We need this instead of simple assignment! out << fmt::format("{{\n" " simd_value t_{}0_ = simd_cast(0.0);\n" - " assign(t_{}0_, indirect({}, simd_cast({}), simd_width_, constraint_category_));\n" + " assign(t_{}0_, indirect({}, simd_cast({}), simd_width_, {}));\n" " {} = S::sub({}, t_{}0_);\n" - " indirect({}, simd_cast({}), simd_width_, constraint_category_) += S::mul({}, {});\n" + " indirect({}, simd_cast({}), simd_width_, {}) += S::mul({}, {});\n" "}}\n", name, - name, data, node, + name, data, node, index_category, scaled, scaled, name, - data, node, weight, scaled); + data, node, index_category, weight, scaled); } } else { @@ -839,14 +861,14 @@ void emit_simd_state_update(std::ostream& out, // We need this instead of simple assignment! out << fmt::format("{{\n" " simd_value t_{}0_ = simd_cast(0.0);\n" - " assign(t_{}0_, indirect({}, simd_cast({}), simd_width_, constraint_category_));\n" + " assign(t_{}0_, indirect({}, simd_cast({}), simd_width_, {}));\n" " {} = S::sub({}, t_{}0_);\n" - " indirect({}, simd_cast({}), simd_width_, constraint_category_) += S::mul({}, {});\n" + " indirect({}, simd_cast({}), simd_width_, {}) += S::mul({}, {});\n" "}}\n", name, - name, data, node, + name, data, node, index_category, scaled, scaled, name, - data, node, weight, scaled); + data, node, index_category, weight, scaled); } } else { @@ -858,41 +880,49 @@ void emit_simd_state_update(std::ostream& out, if (d.index_var_kind == index_kind::node) { std::string tempvar = "t_" + external->name(); switch (constraint) { - case simd_expr_constraint::contiguous: - out << "simd_value " << tempvar << ";\n" - << "assign(" << tempvar << ", indirect(" << data << " + " << node << ", simd_width_));\n" - << tempvar << " = S::fma(" << weight << ", " << scaled << ", " << tempvar << ");\n" - << "indirect(" << data << " + " << node << ", simd_width_) = " << tempvar << ";\n"; - break; - case simd_expr_constraint::constant: - out << "indirect(" << data << ", simd_cast(" << node << "), simd_width_, constraint_category_) += S::mul(" << weight << ", " << scaled << ");\n"; - break; - default: - out << "indirect(" << data << ", " << node << ", simd_width_, constraint_category_) += S::mul(" << weight << ", " << scaled << ");\n"; + case simd_expr_constraint::contiguous: + out << "simd_value " << tempvar << ";\n" + << "assign(" << tempvar << ", indirect(" << data << " + " << node << ", simd_width_));\n" + << tempvar << " = S::fma(" << weight << ", " << scaled << ", " << tempvar << ");\n" + << "indirect(" << data << " + " << node << ", simd_width_) = " << tempvar << ";\n"; + break; + case simd_expr_constraint::constant: + out << fmt::format("indirect({}, simd_cast({}), simd_width_, {}) += S::mul({}, {});\n", + data, node, index_category, weight, scaled); + break; + default: + out << fmt::format("indirect({}, {}, simd_width_, {}) += S::mul({}, {});\n", + data, node, index_category, weight, scaled); } } else { - out << "indirect(" << data << ", " << index << ", simd_width_, index_constraint::none) += S::mul(" << weight << ", " << scaled << ");\n"; + out << fmt::format("indirect({}, {}, simd_width_, index_constraint::none) += S::mul({}, {});\n", + data, index, weight, scaled); } } else if (d.index_var_kind == index_kind::node) { switch (constraint) { case simd_expr_constraint::contiguous: - out << "indirect(" << data << " + " << node << ", simd_width_) = " << scaled << ";\n"; + out << fmt::format("indirect({} + {}, simd_width_) = {};\n", + data, node, scaled); break; case simd_expr_constraint::constant: - out << "indirect(" << data << ", simd_cast(" << node << "), simd_width_, constraint_category_) = " << scaled << ";\n"; + out << fmt::format("indirect({}, simd_cast({}), simd_width_, {}) = {};\n", + data, node, index_category, scaled); break; default: - out << "indirect(" << data << ", " << node << ", simd_width_, constraint_category_) = " << scaled << ";\n"; + out << fmt::format("indirect({}, {}, simd_width_, {}) = {};\n", + data, node, index_category, scaled); } } else { - out << "indirect(" << data << ", " << index << ", simd_width_, index_constraint::none) = " << scaled << ";\n"; + out << fmt::format("indirect({}, {}, simd_width_, index_constraint::none) = {};\n", + data, index, scaled); } EXIT(out); } -void emit_simd_index_initialize(std::ostream& out, const std::list& indices, +void emit_simd_index_initialize(std::ostream& out, + const std::list& indices, simd_expr_constraint constraint) { ENTER(out); for (auto& index: indices) { @@ -921,11 +951,17 @@ void emit_simd_index_initialize(std::ostream& out, const std::list& out << "auto " << source_index_i_name(index) << " = simd_cast(" << source_var(index) << "[" << index.index_name << "]);\n"; break; - default: + case simd_expr_constraint::independent:{ + out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(" << source_var(index) + << ", " << index.index_name << ", simd_width_, index_constraint::independent));\n"; + break; + } + case simd_expr_constraint::none: { out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(" << source_var(index) - << ", " << index.index_name << ", simd_width_, constraint_category_));\n"; + << ", " << index.index_name << ", simd_width_, index_constraint::none));\n"; break; } + } break; } default: { @@ -938,14 +974,13 @@ void emit_simd_index_initialize(std::ostream& out, const std::list& EXIT(out); } -void emit_simd_body_for_loop( - std::ostream& out, - BlockExpression* body, - const std::vector& indexed_vars, - const std::vector& scalars, - const std::list& indices, - const simd_expr_constraint& constraint, - const ApiFlags& flags) { +void emit_simd_body_for_loop(std::ostream& out, + BlockExpression* body, + const std::vector& indexed_vars, + const std::vector& scalars, + const std::list& indices, + const simd_expr_constraint& constraint, + const ApiFlags& flags) { ENTER(out); emit_simd_index_initialize(out, indices, constraint); @@ -972,8 +1007,7 @@ void emit_simd_for_loop_per_constraint(std::ostream& out, BlockExpression* body, std::string underlying_constraint_name, const ApiFlags& flags) { ENTER(out); - out << fmt::format("constraint_category_ = index_constraint::{1};\n" - "for (auto i_ = 0ul; i_ < {0}index_constraints_n_{1}; i_++) {{\n" + out << fmt::format("for (auto i_ = 0ul; i_ < {0}index_constraints_n_{1}; i_++) {{\n" " arb_index_type index_ = {0}index_constraints_{1}[i_];\n", pp_var_pfx, underlying_constraint_name) @@ -1006,8 +1040,6 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, out << "PPACK_IFACE_BLOCK;\n"; out << "assert(simd_width_ <= (unsigned)S::width(simd_cast(0)));\n"; if (!indices.empty()) { - out << "index_constraint constraint_category_;\n\n"; - //Generate for loop for all contiguous simd_vectors simd_expr_constraint constraint = simd_expr_constraint::contiguous; std::string underlying_constraint = "contiguous"; @@ -1015,13 +1047,13 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); //Generate for loop for all independent simd_vectors - constraint = simd_expr_constraint::other; + constraint = simd_expr_constraint::independent; underlying_constraint = "independent"; emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); //Generate for loop for all simd_vectors that have no optimizing constraints - constraint = simd_expr_constraint::other; + constraint = simd_expr_constraint::none; underlying_constraint = "none"; emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); @@ -1035,7 +1067,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, else { // We may nonetheless need to read a global scalar indexed variable. for (auto& sym: scalar_indexed_vars) { - emit_simd_state_read(out, sym, simd_expr_constraint::other, flags); + emit_simd_state_read(out, sym, simd_expr_constraint::none, flags); } out << fmt::format("for (arb_size_type i_ = 0; i_ < {}width; i_ += simd_width_) {{\n", From f8829096e2a21093a68492a5751ae95b0eceafbc Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:05:01 +0100 Subject: [PATCH 07/15] more fmt --- modcc/printer/cprinter.cpp | 48 +++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 30933aa2a3..a7d7e76797 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -716,8 +716,7 @@ void SimdPrinter::visit(BlockExpression* block) { void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_constraint constraint, const ApiFlags& flags) { ENTER(out); - out << "simd_value " << local->name(); - + auto name = local->name(); auto write_voltage = local->external_variable()->data_source() == sourceKind::voltage && flags.can_write_voltage; auto is_additive = local->is_write() && decode_indexed_variable(local->external_variable()).additive; @@ -725,53 +724,55 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con if (local->is_read() || is_additive || write_voltage) { auto d = decode_indexed_variable(local->external_variable()); if (d.scalar()) { - out << " = simd_cast(" << pp_var_pfx << d.data_var - << "[0]);\n"; + out << fmt::format("simd_value {} = simd_cast({}{}[0]);\n", + name, pp_var_pfx, d.data_var); } else { switch (d.index_var_kind) { case index_kind::node: { switch (constraint) { case simd_expr_constraint::contiguous: { - out << ";\n" - << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) - << " + " << node_index_i_name(d) << ", simd_width_));\n"; + out << fmt::format("simd_value {};\n" + "assign({}, indirect({} + {}, simd_width_));\n", + name, name, data_via_ppack(d), node_index_i_name(d)); + break; } case simd_expr_constraint::constant: { - out << " = simd_cast(" << data_via_ppack(d) - << "[" << node_index_i_name(d) << "]);\n"; + out << fmt::format("simd_value {} = simd_cast({}[{}]);\n", + name, data_via_ppack(d), node_index_i_name(d)); break; } case simd_expr_constraint::independent: { - out << ";\n" - << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) - << ", " << node_index_i_name(d) << ", simd_width_, index_constraint::independent));\n"; + out << fmt::format("simd_value {};\n" + "assign({}, indirect({}, {}, simd_width_, index_constraint::independent));\n", + name, name, data_via_ppack(d), node_index_i_name(d)); break; } case simd_expr_constraint::none: { - out << ";\n" - << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) - << ", " << node_index_i_name(d) << ", simd_width_, index_constraint::none));\n"; + out << fmt::format("simd_value {};\n" + "assign({}, indirect({}, {}, simd_width_, index_constraint::none));\n", + name, name, data_via_ppack(d), node_index_i_name(d)); } } break; } default: { - out << ";\n" - << "assign(" << local->name() << ", indirect(" << data_via_ppack(d) - << ", " << index_i_name(d.outer_index_var()) << ", simd_width_, index_constraint::none));\n"; + out << fmt::format("simd_value {};\n" + "assign({}, indirect({}, {}, simd_width_, index_constraint::none));\n", + name, name, data_via_ppack(d), index_i_name(d.outer_index_var())); break; } } } if (d.scale != 1) { - out << local->name() << " = S::mul(" << local->name() << ", simd_cast(" << d.scale << "));\n"; + out << fmt::format("{} = S::mul({}, simd_cast({}));\n", + name, name, d.scale); } } else { - out << " = simd_cast(0);\n"; + out << fmt::format("simd_value {} = simd_cast(0);\n", name); } EXIT(out); } @@ -1043,25 +1044,21 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, //Generate for loop for all contiguous simd_vectors simd_expr_constraint constraint = simd_expr_constraint::contiguous; std::string underlying_constraint = "contiguous"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); //Generate for loop for all independent simd_vectors constraint = simd_expr_constraint::independent; underlying_constraint = "independent"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); //Generate for loop for all simd_vectors that have no optimizing constraints constraint = simd_expr_constraint::none; underlying_constraint = "none"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); //Generate for loop for all constant simd_vectors constraint = simd_expr_constraint::constant; underlying_constraint = "constant"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); } else { @@ -1075,8 +1072,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, << indent << simdprint(body, scalars) << popindent - << - "}\n"; + << "}\n"; } } EXIT(out); From 5e5404f3c70efcc7b995ad3ae58657263dec6da0 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Mon, 2 Mar 2026 09:53:45 +0100 Subject: [PATCH 08/15] Largely clean-up stream usage. --- modcc/printer/cprinter.cpp | 61 ++++++++++++++------------------ test/unit-modcc/test_symdiff.cpp | 2 +- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index a7d7e76797..43013741f4 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -603,7 +603,7 @@ void SimdPrinter::visit(LocalVariable* sym) { void SimdPrinter::visit(VariableExpression *sym) { ENTERM(out_, "variable"); if (sym->is_range()) { - auto index = is_indirect_? "index_": "i_"; + auto index = is_indirect_? "index_" : "i_"; out_ << "simd_cast(indirect(" << pp_var_pfx << sym->name() << "+" << index << ", simd_width_))"; } else { @@ -637,8 +637,7 @@ void SimdPrinter::visit(AssignmentExpression* e) { // are scalars and read-only. if (lhs->is_variable() && lhs->is_variable()->is_range()) { out_ << "indirect(" << pfx << lhs->name() << "+" << index << ", simd_width_) = "; - if (!input_mask_.empty()) - out_ << "S::where(" << input_mask_ << ", "; + if (!input_mask_.empty()) out_ << "S::where(" << input_mask_ << ", "; // If the rhs is a scalar identifier or a number, it needs to be cast to a vector. auto id = e->rhs()->is_identifier(); @@ -648,9 +647,7 @@ void SimdPrinter::visit(AssignmentExpression* e) { if (cast) out_ << "simd_cast("; e->rhs()->accept(this); if (cast) out_ << ")"; - - if (!input_mask_.empty()) - out_ << ")"; + if (!input_mask_.empty()) out_ << ")"; } else if (lhs->is_variable() && !lhs->is_variable()->is_range()) { throw (compiler_exception("Should not be trying to assign a non-range variable " + lhs->to_string(), lhs->location())); @@ -676,10 +673,7 @@ void SimdPrinter::visit(AssignmentExpression* e) { void SimdPrinter::visit(CallExpression* e) { ENTERM(out_, "call"); - if(is_indirect_) - out_ << e->name() << "(pp, index_"; - else - out_ << e->name() << "(pp, i_"; + out_ << e->name() << "(pp, " << (is_indirect_ ? "index_" : "i_"); for (auto& arg: e->args()) { out_ << ", "; arg->accept(this); @@ -879,13 +873,13 @@ void emit_simd_state_update(std::ostream& out, } else if (d.accumulate) { if (d.index_var_kind == index_kind::node) { - std::string tempvar = "t_" + external->name(); switch (constraint) { case simd_expr_constraint::contiguous: - out << "simd_value " << tempvar << ";\n" - << "assign(" << tempvar << ", indirect(" << data << " + " << node << ", simd_width_));\n" - << tempvar << " = S::fma(" << weight << ", " << scaled << ", " << tempvar << ");\n" - << "indirect(" << data << " + " << node << ", simd_width_) = " << tempvar << ";\n"; + out << fmt::format("simd_value t_{0};\n" + "assign(t_{0}, indirect({1} + {2}, simd_width_));\n" + "t_{0} = S::fma({3}, {4}, t_{0});\n" + "indirect({1} + {2}, simd_width_) = t_{0};\n", + external->name(), data, node, weight, scaled); break; case simd_expr_constraint::constant: out << fmt::format("indirect({}, simd_cast({}), simd_width_, {}) += S::mul({}, {});\n", @@ -932,11 +926,12 @@ void emit_simd_index_initialize(std::ostream& out, switch (constraint) { case simd_expr_constraint::contiguous: case simd_expr_constraint::constant: - out << "auto " << source_index_i_name(index) << " = " << source_var(index) << "[" << index.index_name << "];\n"; + out << fmt::format("auto {} = {}[{}];\n", + source_index_i_name(index), source_var(index), index.index_name); break; default: - out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(&" << source_var(index) - << "[0] + " << index.index_name << ", simd_width_));\n"; + out << fmt::format("auto {} = simd_cast(indirect(&{}[0] + {}, simd_width_));\n", + source_index_i_name(index), source_var(index), index.index_name); break; } break; @@ -945,29 +940,27 @@ void emit_simd_index_initialize(std::ostream& out, // Treat like reading a state variable. switch (constraint) { case simd_expr_constraint::contiguous: - out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(" << source_var(index) - << " + " << index.index_name << ", simd_width_));\n"; + out << fmt::format("auto {} = simd_cast(indirect({} + {}, simd_width_));\n", + source_index_i_name(index), source_var(index), index.index_name); break; case simd_expr_constraint::constant: - out << "auto " << source_index_i_name(index) << " = simd_cast(" << source_var(index) - << "[" << index.index_name << "]);\n"; + out << fmt::format("auto {} = simd_cast({}[{}]);\n", + source_index_i_name(index), source_var(index), index.index_name); break; - case simd_expr_constraint::independent:{ - out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(" << source_var(index) - << ", " << index.index_name << ", simd_width_, index_constraint::independent));\n"; + case simd_expr_constraint::independent: + out << fmt::format("auto {} = simd_cast(indirect({}, {}, simd_width_, index_constraint::independent));\n", + source_index_i_name(index), source_var(index), index.index_name); break; - } - case simd_expr_constraint::none: { - out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(" << source_var(index) - << ", " << index.index_name << ", simd_width_, index_constraint::none));\n"; + case simd_expr_constraint::none: + out << fmt::format("auto {} = simd_cast(indirect({}, {}, simd_width_, index_constraint::none));\n", + source_index_i_name(index), source_var(index), index.index_name); break; } - } break; } default: { - out << "auto " << source_index_i_name(index) << " = simd_cast(indirect(&" << source_var(index) - << "[0] + " << index.index_name << ", simd_width_));\n"; + out << fmt::format("auto {} = simd_cast(indirect(&{}[0] + {}, simd_width_));\n", + source_index_i_name(index), source_var(index), index.index_name); break; } } @@ -1038,8 +1031,8 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, } } if (!body->statements().empty()) { - out << "PPACK_IFACE_BLOCK;\n"; - out << "assert(simd_width_ <= (unsigned)S::width(simd_cast(0)));\n"; + out << "PPACK_IFACE_BLOCK;\n" + << "assert(simd_width_ <= (unsigned)S::width(simd_cast(0)));\n"; if (!indices.empty()) { //Generate for loop for all contiguous simd_vectors simd_expr_constraint constraint = simd_expr_constraint::contiguous; diff --git a/test/unit-modcc/test_symdiff.cpp b/test/unit-modcc/test_symdiff.cpp index 60b2d31e7c..7dccd24359 100644 --- a/test/unit-modcc/test_symdiff.cpp +++ b/test/unit-modcc/test_symdiff.cpp @@ -320,7 +320,7 @@ TEST(linear_test, homogeneous) { EXPECT_TRUE(r.is_linear); EXPECT_TRUE(r.is_homogeneous); EXPECT_FALSE(r.monolinear()); - EXPECT_EXPR_EQ(r.coef["x"], "-a+2"_expr); + EXPECT_EXPR_EQ(r.coef["x"], "2 - a"_expr); EXPECT_EXPR_EQ(r.coef["y"], "1"_expr); } From 0bfa8c959d44e26d482a7cd873062b9187191af3 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Mon, 2 Mar 2026 11:26:40 +0100 Subject: [PATCH 09/15] Reenable FMA, _very_ slightly losen test bounds for fma. --- modcc/printer/cexpr_emit.cpp | 25 +++++++++++++------------ test/unit/common.hpp | 30 ++++++++++++++++++++++++++++++ test/unit/test_simd.cpp | 8 ++++---- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 20d683f212..b662dee246 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -188,12 +188,13 @@ void CExprEmitter::visit(BinaryExpression* e) { return; } } - else if (e->op() == tok::minus) { - if (auto r = e->rhs()->is_binary(); r && r->op() == tok::times) { - emit_as_call("fms", r->lhs(), r->rhs(), lhs); - return; - } - } + // there is no FMS in C++ + // else if (e->op() == tok::minus) { + // if (auto r = e->rhs()->is_binary(); r && r->op() == tok::times) { + // emit_as_call("fms", r->lhs(), r->rhs(), lhs); + // return; + // } + // } if (e->is_infix()) { associativityKind assoc = Lexer::operator_associativity(e->op()); @@ -388,18 +389,18 @@ void SimdExprEmitter::visit(BinaryExpression* e) { if (e->op() == tok::plus) { if (auto l = lhs->is_binary(); l && l->op() == tok::times) { - // emit_fused("S::fma", l->lhs(), l->rhs(), rhs); - // return; + emit_fused("S::fma", l->lhs(), l->rhs(), rhs); + return; } if (auto r = rhs->is_binary(); r && r->op() == tok::times) { - // emit_fused("S::fma", r->lhs(), r->rhs(), lhs); - // return; + emit_fused("S::fma", r->lhs(), r->rhs(), lhs); + return; } } else if (e->op() == tok::minus) { if (auto r = rhs->is_binary(); r && r->op() == tok::times) { - // emit_fused("S::fms", r->lhs(), r->rhs(), lhs); - // return; + emit_fused("S::fms", r->lhs(), r->rhs(), lhs); + return; } } diff --git a/test/unit/common.hpp b/test/unit/common.hpp index 23a08da975..aa8867159e 100644 --- a/test/unit/common.hpp +++ b/test/unit/common.hpp @@ -246,6 +246,36 @@ ::testing::AssertionResult seq_almost_eq(Seq1&& seq1, Seq2&& seq2) { return ::testing::AssertionSuccess(); } +template +::testing::AssertionResult seq_almost_eq(Seq1&& seq1, Seq2&& seq2, FPType eps) { + using std::begin; + using std::end; + + auto i1 = begin(seq1); + auto i2 = begin(seq2); + + auto e1 = end(seq1); + auto e2 = end(seq2); + + for (std::size_t j = 0; i1!=e1 && i2!=e2; ++i1, ++i2, ++j) { + + auto v1 = *i1; + auto v2 = *i2; + + // Cast to FPType to avoid warnings about lowering conversion + // if FPType has lower precision than Seq{12}::value_type. + if (std::abs(v1 - v2) > eps) { + ::testing::AssertionFailure() << "sequence values " << v1 << " and " << v2 << " exceed max deviation at index " << j; + } + } + + if (i1!=e1 || i2!=e2) { + return ::testing::AssertionFailure() << "sequences differ in length"; + } + return ::testing::AssertionSuccess(); +} + + template inline bool generic_isnan(const V& x) { return false; } inline bool generic_isnan(float x) { return std::isnan(x); } diff --git a/test/unit/test_simd.cpp b/test/unit/test_simd.cpp index 55a99d5188..65bd7bcbde 100644 --- a/test/unit/test_simd.cpp +++ b/test/unit/test_simd.cpp @@ -294,7 +294,7 @@ TYPED_TEST_P(simd_value, arithmetic) { for (unsigned i = 0; i::value) { EXPECT_TRUE(testing::seq_almost_eq(u_divide_v, r)); } @@ -329,8 +328,9 @@ TYPED_TEST_P(simd_value, arithmetic) { EXPECT_TRUE(testing::seq_eq(u_divide_v, r)); #endif - (fma(us, vs, ws)).copy_to(r); - EXPECT_TRUE(testing::seq_almost_eq(fma_u_v_w, r)); + fma(us, vs, ws).copy_to(r); + // this catches all deviations larger than machine delta + EXPECT_TRUE(testing::seq_almost_eq(fma_u_v_w, r, 0.0)); } } From eebc18933754c7f2a0c6aaef8e101c51e081b92c Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Mon, 2 Mar 2026 11:41:30 +0100 Subject: [PATCH 10/15] Neon's FMS is too weird. --- arbor/include/arbor/simd/avx.hpp | 8 ++++---- arbor/include/arbor/simd/implbase.hpp | 4 ---- arbor/include/arbor/simd/neon.hpp | 2 -- arbor/include/arbor/simd/simd.hpp | 15 --------------- modcc/printer/cexpr_emit.cpp | 17 +---------------- 5 files changed, 5 insertions(+), 41 deletions(-) diff --git a/arbor/include/arbor/simd/avx.hpp b/arbor/include/arbor/simd/avx.hpp index 1724021d70..320db3a920 100644 --- a/arbor/include/arbor/simd/avx.hpp +++ b/arbor/include/arbor/simd/avx.hpp @@ -898,6 +898,10 @@ struct avx2_double4: avx_double4 { r))); } + static __m256d fms(const __m256d& a, const __m256d& b, const __m256d& c) { + return _mm256_fmsub_pd(a, b, c); + } + protected: static __m128i lo_epi32(__m256i a) { a = _mm256_shuffle_epi32(a, 0x08); @@ -929,10 +933,6 @@ struct avx2_double4: avx_double4 { return fma(x, horner1(x, tail...), broadcast(a0)); } - static __m256d fms(const __m256d& a, const __m256d& b, const __m256d& c) { - return _mm256_fmsub_pd(a, b, c); - } - // Compute 2.0^n. // Overrides avx_double4::exp2int. static __m256d exp2int(__m128i n) { diff --git a/arbor/include/arbor/simd/implbase.hpp b/arbor/include/arbor/simd/implbase.hpp index e8a97546a3..969b1e481a 100644 --- a/arbor/include/arbor/simd/implbase.hpp +++ b/arbor/include/arbor/simd/implbase.hpp @@ -242,10 +242,6 @@ struct implbase { return I::add(w, I::mul(u, v)); } - static vector_type fms(const vector_type& u, const vector_type& v, const vector_type& w) { - return I::sub(w, I::mul(u, v)); - } - static mask_type cmp_eq(const vector_type& u, const vector_type& v) { store a, b; mask_store r; diff --git a/arbor/include/arbor/simd/neon.hpp b/arbor/include/arbor/simd/neon.hpp index fe5bbebeef..d1968b6135 100644 --- a/arbor/include/arbor/simd/neon.hpp +++ b/arbor/include/arbor/simd/neon.hpp @@ -233,8 +233,6 @@ struct neon_double2 : implbase { static float64x2_t mul(const float64x2_t& a, const float64x2_t& b) { return vmulq_f64(a, b); } static float64x2_t fma(const float64x2_t& a, const float64x2_t& b, const float64x2_t& c) { return vfmaq_f64(c, a, b); } - - static float64x2_t fms(const float64x2_t& a, const float64x2_t& b, const float64x2_t& c) { return vfmsq_f64(c, a, b); } static float64x2_t div(const float64x2_t& a, const float64x2_t& b) { return vdivq_f64(a, b); } diff --git a/arbor/include/arbor/simd/simd.hpp b/arbor/include/arbor/simd/simd.hpp index 9de9bd8642..b5bbf36696 100644 --- a/arbor/include/arbor/simd/simd.hpp +++ b/arbor/include/arbor/simd/simd.hpp @@ -108,11 +108,6 @@ detail::simd_impl fma(detail::simd_impl a, detail::simd_impl b, detail: return detail::simd_impl::wrap(T::fma(a.value_, b.value_, c.value_)); } -template -detail::simd_impl fms(detail::simd_impl a, detail::simd_impl b, detail::simd_impl c) { - return detail::simd_impl::wrap(T::fms(a.value_, b.value_, c.value_)); -} - namespace detail { /// Indirect Expressions template @@ -579,12 +574,7 @@ namespace detail { return simd_impl::wrap(Impl::fma(a.value_, b.value_, c.value_)); } - friend simd_impl fms(simd_impl a, simd_impl b, simd_impl c) { - return simd_impl::wrap(Impl::fms(a.value_, b.value_, c.value_)); - } - // Lane-wise relational operations. - friend simd_mask operator==(const simd_impl& a, const simd_impl& b) { return simd_impl::mask(Impl::cmp_eq(a.value_, b.value_)); } @@ -703,12 +693,7 @@ namespace detail { template friend simd_impl arb::simd::fma(simd_impl a, simd_impl b, simd_impl c); - template - friend simd_impl arb::simd::fms(simd_impl a, simd_impl b, simd_impl c); - - // Declare Indirect/Indirect indexed/Where Expression copy function as friends - template friend void compound_indexed_add(const simd_impl& s, V* p, const simd_impl& index, unsigned width, index_constraint constraint); diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index b662dee246..226ae7fb11 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -188,13 +188,6 @@ void CExprEmitter::visit(BinaryExpression* e) { return; } } - // there is no FMS in C++ - // else if (e->op() == tok::minus) { - // if (auto r = e->rhs()->is_binary(); r && r->op() == tok::times) { - // emit_as_call("fms", r->lhs(), r->rhs(), lhs); - // return; - // } - // } if (e->is_infix()) { associativityKind assoc = Lexer::operator_associativity(e->op()); @@ -379,8 +372,6 @@ void SimdExprEmitter::visit(BinaryExpression* e) { "CExprEmitter: unsupported binary operator " + token_string(e->op()), e->location()); } - std::string rhs_name, lhs_name, rhs_pfxd, lhs_pfxd; - auto rhs = e->rhs(); auto lhs = e->lhs(); @@ -397,14 +388,8 @@ void SimdExprEmitter::visit(BinaryExpression* e) { return; } } - else if (e->op() == tok::minus) { - if (auto r = rhs->is_binary(); r && r->op() == tok::times) { - emit_fused("S::fms", r->lhs(), r->rhs(), lhs); - return; - } - } - + std::string rhs_name, lhs_name, rhs_pfxd, lhs_pfxd; if (auto id = rhs->is_identifier()) { rhs_name = id->name(); rhs_pfxd = id_prefix(id); From 54d3976d62ca8ae3402607d8fcf2e50a0c1e1515 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:04:53 +0100 Subject: [PATCH 11/15] clean-up --- modcc/printer/cexpr_emit.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modcc/printer/cexpr_emit.cpp b/modcc/printer/cexpr_emit.cpp index 226ae7fb11..5062c410d7 100644 --- a/modcc/printer/cexpr_emit.cpp +++ b/modcc/printer/cexpr_emit.cpp @@ -245,9 +245,7 @@ void CExprEmitter::visit(IfExpression* e) { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// std::unordered_set SimdExprEmitter::mask_names_; -void SimdExprEmitter::visit(NumberExpression* e) { - out_ << " (double)" << as_c_double(e->value()); -} +void SimdExprEmitter::visit(NumberExpression* e) { out_ << "(double)" << as_c_double(e->value()); } void SimdExprEmitter::visit(UnaryExpression* e) { static std::unordered_map unaryop_tbl = { From b20dca27a068b99935c1faaf1f39eb76b67f507f Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:56:35 +0200 Subject: [PATCH 12/15] fix merge --- arbor/include/arbor/simd/neon.hpp | 4 ---- modcc/printer/cprinter.cpp | 29 ++++------------------------- 2 files changed, 4 insertions(+), 29 deletions(-) diff --git a/arbor/include/arbor/simd/neon.hpp b/arbor/include/arbor/simd/neon.hpp index 9319b0a31a..0a0bf0387b 100644 --- a/arbor/include/arbor/simd/neon.hpp +++ b/arbor/include/arbor/simd/neon.hpp @@ -235,10 +235,6 @@ struct neon_double2 : implbase { static float64x2_t div(const float64x2_t& a, const float64x2_t& b) { return vdivq_f64(a, b); } - static float64x2_t fma(const float64x2_t& a, const float64x2_t& b, const float64x2_t& c) { - return vfmaq_f64(c, a, b); - } - static float64x2_t logical_not(const float64x2_t& a) { return vreinterpretq_f64_u32(vmvnq_u32(vreinterpretq_u32_f64(a))); } diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 0cae8a71b1..ac2b9c7230 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -1006,16 +1006,6 @@ void emit_simd_for_loop_per_constraint(std::ostream& out, BlockExpression* body, simd_expr_constraint constraint, const ApiFlags& flags) { ENTER(out); - out << fmt::format("for (auto i_ = 0ul; i_ < {0}index_constraints_n_{1}; i_++) {{\n" - " arb_index_type index_ = {0}index_constraints_{1}[i_];\n", - pp_var_pfx, - underlying_constraint_name) - << indent - << fmt::format("simd_value w_;\n" - "assign(w_, indirect(({}weight+index_), simd_width_));\n", - pp_var_pfx); - - ENTER(out); if (constraint == simd_expr_constraint::contiguous) { out << fmt::format("for (auto i_ = 0ul; i_ < {0}index_constraints_n_{1}; i_ += 2) {{\n", @@ -1064,24 +1054,13 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, << "assert(simd_width_ <= (unsigned)S::width(simd_cast(0)));\n"; if (!indices.empty()) { //Generate for loop for all contiguous simd_vectors - simd_expr_constraint constraint = simd_expr_constraint::contiguous; - std::string underlying_constraint = "contiguous"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); - + emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, simd_expr_constraint::contiguous, flags); //Generate for loop for all independent simd_vectors - constraint = simd_expr_constraint::independent; - underlying_constraint = "independent"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); - + emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, simd_expr_constraint::independent, flags); //Generate for loop for all simd_vectors that have no optimizing constraints - constraint = simd_expr_constraint::none; - underlying_constraint = "none"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); - + emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, simd_expr_constraint::none, flags); //Generate for loop for all constant simd_vectors - constraint = simd_expr_constraint::constant; - underlying_constraint = "constant"; - emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, constraint, underlying_constraint, flags); + emit_simd_for_loop_per_constraint(out, body, indexed_vars, scalars, indices, simd_expr_constraint::constant, flags); } else { // We may nonetheless need to read a global scalar indexed variable. From e0f981b45abae396a4924cc99db8aa26f052ce5c Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:10:13 +0200 Subject: [PATCH 13/15] Fix typo --- modcc/symdiff.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp index c408cd6f90..aaa531a5a3 100644 --- a/modcc/symdiff.cpp +++ b/modcc/symdiff.cpp @@ -535,7 +535,7 @@ class ConstantSimplifyVisitor: public Visitor { result_ = make_expression(loc, std::move(rhs)); } - else if (expr_value(lhs) == -1.0) { + else if (expr_value(rhs) == -1.0) { result_ = make_expression(loc, std::move(lhs)); } // -a * -b = a * b From 3930abf11ecc1c0882d99863214714ad4a82cd3a Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:11:25 +0200 Subject: [PATCH 14/15] revert exposing fms --- arbor/include/arbor/simd/avx.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arbor/include/arbor/simd/avx.hpp b/arbor/include/arbor/simd/avx.hpp index 320db3a920..aa88087d25 100644 --- a/arbor/include/arbor/simd/avx.hpp +++ b/arbor/include/arbor/simd/avx.hpp @@ -897,10 +897,6 @@ struct avx2_double4: avx_double4 { ifelse(is_small, broadcast(-HUGE_VAL), r))); } - - static __m256d fms(const __m256d& a, const __m256d& b, const __m256d& c) { - return _mm256_fmsub_pd(a, b, c); - } protected: static __m128i lo_epi32(__m256i a) { @@ -933,6 +929,10 @@ struct avx2_double4: avx_double4 { return fma(x, horner1(x, tail...), broadcast(a0)); } + static __m256d fms(const __m256d& a, const __m256d& b, const __m256d& c) { + return _mm256_fmsub_pd(a, b, c); + } + // Compute 2.0^n. // Overrides avx_double4::exp2int. static __m256d exp2int(__m128i n) { From b790b7a7dbc6cd130648b43853dabf2320a2855a Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:14:17 +0200 Subject: [PATCH 15/15] whitespace --- arbor/include/arbor/simd/avx.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arbor/include/arbor/simd/avx.hpp b/arbor/include/arbor/simd/avx.hpp index aa88087d25..89bd824146 100644 --- a/arbor/include/arbor/simd/avx.hpp +++ b/arbor/include/arbor/simd/avx.hpp @@ -409,7 +409,7 @@ struct avx_double4: implbase { // e^g = 1 + 2·g·P(g^2) / (Q(g^2)-g·P(g^2)). // // Note that the coefficients for R are close to but not the same as those - // from the 6,6 Padé approximant to the exponential. + // from the 6,6 Padé approximant to the exponential. // // The exponents n and g are calculated by: // @@ -420,7 +420,7 @@ struct avx_double4: implbase { // // |g| = |x - n·ln(2)| // = |x - x + α·ln(2)| - // + // // for some fraction |α| ≤ 0.5, and thus |g| ≤ 0.5ln(2) ≈ 0.347. // // Tne subtraction x - n·ln(2) is performed in two parts, with @@ -897,7 +897,7 @@ struct avx2_double4: avx_double4 { ifelse(is_small, broadcast(-HUGE_VAL), r))); } - + protected: static __m128i lo_epi32(__m256i a) { a = _mm256_shuffle_epi32(a, 0x08); @@ -932,7 +932,7 @@ struct avx2_double4: avx_double4 { static __m256d fms(const __m256d& a, const __m256d& b, const __m256d& c) { return _mm256_fmsub_pd(a, b, c); } - + // Compute 2.0^n. // Overrides avx_double4::exp2int. static __m256d exp2int(__m128i n) {