Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 18 additions & 69 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2307,11 +2307,11 @@ bool CodeGen_ARM::codegen_pairwise_vector_reduce(const VectorReduce *op, const E
bool CodeGen_ARM::codegen_across_vector_reduce(const VectorReduce *op, const Expr &init) {
if (target_vscale() == 0) {
// Leave this to vanilla codegen to emit "llvm.vector.reduce." intrinsic,
// which doesn't support scalable vector in LLVM 14
return false;
}

if (op->op != VectorReduce::Add &&
op->op != VectorReduce::Mul &&
op->op != VectorReduce::Max &&
op->op != VectorReduce::Min) {
return false;
Expand All @@ -2321,88 +2321,37 @@ bool CodeGen_ARM::codegen_across_vector_reduce(const VectorReduce *op, const Exp
const int output_lanes = op->type.lanes();
const int native_lanes = target.natural_vector_size(op->type);
const int input_lanes = val.type().lanes();
const int input_bits = op->type.bits();
Type elt = op->type.element_of();

if (output_lanes != 1 || input_lanes < 2) {
return false;
}

Expr (*binop)(Expr, Expr) = nullptr;
std::string op_name;
switch (op->op) {
case VectorReduce::Add:
binop = Add::make;
op_name = "add";
break;
case VectorReduce::Min:
binop = Min::make;
op_name = "min";
break;
case VectorReduce::Max:
binop = Max::make;
op_name = "max";
break;
default:
internal_error << "unreachable";
}

if (input_lanes == native_lanes) {
std::stringstream name; // e.g. llvm.aarch64.sve.sminv.nxv4i32
name << "llvm.aarch64.sve."
<< (op->type.is_float() ? "f" : op->type.is_int() ? "s" :
"u")
<< op_name << "v"
<< ".nxv" << (native_lanes / target_vscale()) << (op->type.is_float() ? "f" : "i") << input_bits;

// Integer add accumulation output is 64 bit only
const bool type_upgraded = op->op == VectorReduce::Add && op->type.is_int_or_uint();
const int output_bits = type_upgraded ? 64 : input_bits;
Type intrin_ret_type = op->type.with_bits(output_bits);

const string intrin_name = name.str();

Expr pred = const_true(native_lanes);
vector<Expr> args{pred, op->value};

// Make sure the declaration exists, or the codegen for
// call will assume that the args should scalarize.
if (!module->getFunction(intrin_name)) {
vector<llvm::Type *> arg_types;
arg_types.reserve(args.size());
for (const Expr &e : args) {
arg_types.push_back(llvm_type_with_constraint(e.type(), false, VectorTypeConstraint::VScale));
}
FunctionType *func_t = FunctionType::get(llvm_type_with_constraint(intrin_ret_type, false, VectorTypeConstraint::VScale),
arg_types, false);
llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, intrin_name, module.get());
}

Expr equiv = Call::make(intrin_ret_type, intrin_name, args, Call::PureExtern);
if (type_upgraded) {
equiv = Cast::make(op->type, equiv);
}
if (init.defined()) {
equiv = binop(init, equiv);
}
equiv = common_subexpression_elimination(equiv);
equiv.accept(this);
return true;

} else if (input_lanes < native_lanes) {
// Create equivalent where lanes==native_lanes by padding data which doesn't affect the result
if (input_lanes % native_lanes == 0) {
// Leave this to vanilla codegen to emit "llvm.vector.reduce." intrinsic
return false;
} else {
// Create equivalent where lanes==native_lanes*n by padding data which doesn't affect the result
Expr padding;
const int inactive_lanes = native_lanes - input_lanes;
const int inactive_lanes = align_up(input_lanes, native_lanes) - input_lanes;
Expr (*binop)(Expr, Expr) = nullptr;
Type elt = op->type.element_of();

switch (op->op) {
case VectorReduce::Add:
padding = make_zero(elt.with_lanes(inactive_lanes));
binop = Add::make;
break;
case VectorReduce::Mul:
padding = make_one(elt.with_lanes(inactive_lanes));
binop = Mul::make;
break;
case VectorReduce::Min:
padding = elt.with_lanes(inactive_lanes).min();
padding = elt.with_lanes(inactive_lanes).max();
binop = Min::make;
break;
case VectorReduce::Max:
padding = elt.with_lanes(inactive_lanes).max();
padding = elt.with_lanes(inactive_lanes).min();
binop = Max::make;
break;
default:
internal_error << "unreachable";
Expand Down
12 changes: 4 additions & 8 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4283,6 +4283,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini
if (output_lanes == 1) {
const int input_lanes = val.type().lanes();
const int input_bytes = input_lanes * val.type().bytes();
const int vscale = std::max(effective_vscale, 1);
const bool llvm_has_intrinsic =
// Must be one of these ops
((op->op == VectorReduce::Add ||
Expand All @@ -4291,20 +4292,15 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini
op->op == VectorReduce::Max) &&
(use_llvm_vp_intrinsics ||
// Must be a power of two lanes
((input_lanes >= 2) &&
((input_lanes >= 2 * vscale) &&
((input_lanes & (input_lanes - 1)) == 0) &&
// int versions exist up to 1024 bits
((!op->type.is_float() && input_bytes <= 1024) ||
// float versions exist up to 16 lanes
input_lanes <= 16) &&
// As of the release of llvm 10, the 64-bit experimental total
// reductions don't seem to be done yet on arm.
(val.type().bits() != 64 ||
target.arch != Target::ARM))));
input_lanes <= 16 * vscale))));

if (llvm_has_intrinsic) {
const char *name = "<err>";
const int bits = op->type.bits();
bool takes_initial_value = use_llvm_vp_intrinsics;
Expr initial_value = init;
if (op->type.is_float()) {
Expand Down Expand Up @@ -4384,7 +4380,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini
std::stringstream build_name;
build_name << "llvm.vector.reduce.";
build_name << name;
build_name << ".v" << val.type().lanes() << (op->type.is_float() ? 'f' : 'i') << bits;
build_name << mangle_llvm_type(get_vector_type(llvm_type_of(elt), val.type().lanes()));

string intrin_name = build_name.str();

Expand Down
Loading