diff --git a/apps/bgu/Makefile b/apps/bgu/Makefile index 8eb687ec064a..a75b623cfcc1 100644 --- a/apps/bgu/Makefile +++ b/apps/bgu/Makefile @@ -16,11 +16,11 @@ $(GENERATOR_BIN)/bgu.generator: bgu_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/bgu.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) - $< -g bgu -f bgu -o $(BIN)/$* target=$*-no_runtime + $< -g bgu -f bgu -o $(BIN)/$* -e $(GENERATOR_OUTPUTS) target=$*-no_runtime $(BIN)/%/bgu_auto_schedule.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) - $< -g bgu -f bgu_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 + $< -g bgu -f bgu_auto_schedule -o $(BIN)/$* -e $(GENERATOR_OUTPUTS) target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index b24626ad66a1..7114ad360135 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -36,7 +36,7 @@ class FindBuffers : public IRGraphVisitor { void visit(const For *op) override { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); bool old = in_device_loop; if (op->device_api != DeviceAPI::None && op->device_api != DeviceAPI::Host) { diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index 22c3425c02a4..ddb9bc1098c5 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -22,7 +22,7 @@ vector apply_split(const Split &split, const string &prefix, Expr inner = Variable::make(Int(32), prefix + split.inner); Expr old_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max"); Expr old_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min"); - Expr old_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); + Expr old_extent = (old_max - old_min) + 1; dim_extent_alignment[split.inner] = split.factor; @@ -135,10 +135,10 @@ vector apply_split(const Split &split, const string &prefix, // Define the inner and outer in terms of the fused var Expr fused = Variable::make(Int(32), prefix + split.old_var); Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min"); + Expr inner_max = Variable::make(Int(32), prefix + split.inner + ".loop_max"); Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min"); - Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent"); - const Expr &factor = inner_extent; + const Expr &factor = (inner_max - inner_min) + 1; Expr inner = fused % factor + inner_min; Expr outer = fused / factor + outer_min; @@ -169,7 +169,6 @@ vector> compute_loop_bounds_after_split(const Split &spl // Define the bounds on the split dimensions using the bounds on the function args. vector> let_stmts; - Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); Expr old_var_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max"); Expr old_var_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min"); switch (split.split_type) { @@ -178,24 +177,22 @@ vector> compute_loop_bounds_after_split(const Split &spl Expr outer_extent = (old_var_max - old_var_min + split.factor) / split.factor; let_stmts.emplace_back(prefix + split.inner + ".loop_min", 0); let_stmts.emplace_back(prefix + split.inner + ".loop_max", inner_extent - 1); - let_stmts.emplace_back(prefix + split.inner + ".loop_extent", inner_extent); let_stmts.emplace_back(prefix + split.outer + ".loop_min", 0); let_stmts.emplace_back(prefix + split.outer + ".loop_max", outer_extent - 1); - let_stmts.emplace_back(prefix + split.outer + ".loop_extent", outer_extent); } break; case Split::FuseVars: { // Define bounds on the fused var using the bounds on the inner and outer - Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent"); - Expr outer_extent = Variable::make(Int(32), prefix + split.outer + ".loop_extent"); - Expr fused_extent = inner_extent * outer_extent; + Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min"); + Expr inner_max = Variable::make(Int(32), prefix + split.inner + ".loop_max"); + Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min"); + Expr outer_max = Variable::make(Int(32), prefix + split.outer + ".loop_max"); + Expr fused_extent = (inner_max - inner_min + 1) * (outer_max - outer_min + 1); let_stmts.emplace_back(prefix + split.old_var + ".loop_min", 0); let_stmts.emplace_back(prefix + split.old_var + ".loop_max", fused_extent - 1); - let_stmts.emplace_back(prefix + split.old_var + ".loop_extent", fused_extent); } break; case Split::RenameVar: let_stmts.emplace_back(prefix + split.outer + ".loop_min", old_var_min); let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max); - let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent); break; } diff --git a/src/AsyncProducers.cpp b/src/AsyncProducers.cpp index bb5e4279d367..536327ebba3f 100644 --- a/src/AsyncProducers.cpp +++ b/src/AsyncProducers.cpp @@ -35,7 +35,7 @@ class NoOpCollapsingMutator : public IRMutator { if (is_no_op(body)) { return body; } else { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } @@ -752,11 +752,10 @@ class InjectRingBuffering : public IRMutator { struct Loop { std::string name; - Expr min; Expr extent; - Loop(std::string n, Expr m, Expr e) - : name(std::move(n)), min(std::move(m)), extent(std::move(e)) { + Loop(std::string n, Expr e) + : name(std::move(n)), extent(std::move(e)) { } }; @@ -778,8 +777,7 @@ class InjectRingBuffering : public IRMutator { int loop_index = hoist_storage_loop_index[op->name] + 1; Expr current_index = Variable::make(Int(32), loops[loop_index].name); while (++loop_index < (int)loops.size()) { - current_index = current_index * - (loops[loop_index].extent - loops[loop_index].min) + + current_index = current_index * loops[loop_index].extent + Variable::make(Int(32), loops[loop_index].name); } current_index = current_index % f.schedule().ring_buffer(); @@ -817,7 +815,7 @@ class InjectRingBuffering : public IRMutator { } Stmt visit(const For *op) override { - loops.emplace_back(op->name, op->min, op->extent); + loops.emplace_back(op->name, op->extent()); Stmt mutated = IRMutator::visit(op); loops.pop_back(); return mutated; diff --git a/src/BoundConstantExtentLoops.cpp b/src/BoundConstantExtentLoops.cpp index fc2fea5b9d41..1312508f8ddb 100644 --- a/src/BoundConstantExtentLoops.cpp +++ b/src/BoundConstantExtentLoops.cpp @@ -46,7 +46,8 @@ class BoundLoops : public IRMutator { } Stmt visit(const For *op) override { - if (is_const(op->extent)) { + Expr extent = simplify(op->extent()); + if (is_const(extent)) { // Nothing needs to be done return IRMutator::visit(op); } @@ -54,7 +55,6 @@ class BoundLoops : public IRMutator { if (op->for_type == ForType::Unrolled || op->for_type == ForType::Vectorized) { // Give it one last chance to simplify to an int - Expr extent = simplify(op->extent); Stmt body = op->body; const IntImm *e = extent.as(); @@ -82,8 +82,8 @@ class BoundLoops : public IRMutator { if (extent_upper.defined()) { e = extent_upper.as(); body = - IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) < - op->min + op->extent), + IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) <= + op->max), body); } } @@ -93,7 +93,7 @@ class BoundLoops : public IRMutator { // to a serial loop. user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n"; body = mutate(body); - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, ForType::Serial, op->partition_policy, op->device_api, std::move(body)); } @@ -103,7 +103,7 @@ class BoundLoops : public IRMutator { << "Loop over " << op->name << " has extent " << extent << ".\n"; body = mutate(body); - return For::make(op->name, op->min, e, + return For::make(op->name, op->min, (op->min + e) - 1, op->for_type, op->partition_policy, op->device_api, std::move(body)); } else { return IRMutator::visit(op); diff --git a/src/BoundSmallAllocations.cpp b/src/BoundSmallAllocations.cpp index 80f58889448c..f3347c0f47fd 100644 --- a/src/BoundSmallAllocations.cpp +++ b/src/BoundSmallAllocations.cpp @@ -59,7 +59,7 @@ class BoundSmallAllocations : public IRMutator { Stmt visit(const For *op) override { Interval min_bounds = find_constant_bounds(op->min, scope); - Interval max_bounds = find_constant_bounds(op->min + op->extent - 1, scope); + Interval max_bounds = find_constant_bounds(op->max, scope); Interval b = Interval::make_union(min_bounds, max_bounds); b.min = simplify(b.min); b.max = simplify(b.max); diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 670ccf11d177..1218eb50239a 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -2964,23 +2964,11 @@ class BoxesTouched : public IRGraphVisitor { TRACK_BOXES_TOUCHED_INFO("var:", op->name); if (consider_calls) { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); } - Expr min_val, max_val; - if (const Interval *in = scope.find(op->name + ".loop_min")) { - min_val = in->min; - } else { - min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min; - } - - if (const Interval *in = scope.find(op->name + ".loop_max")) { - max_val = in->max; - } else { - max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max; - max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max; - max_val -= 1; - } + Expr min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min; + Expr max_val = bounds_of_expr_in_scope(op->max, scope, func_bounds).max; push_var(op->name); { @@ -3819,7 +3807,7 @@ void bounds_test() { Buffer in(10); in.set_name("input"); - Stmt loop = For::make("x", 3, 10, ForType::Serial, Partition::Auto, DeviceAPI::Host, + Stmt loop = For::make("x", 3, 12, ForType::Serial, Partition::Auto, DeviceAPI::Host, Provide::make("output", {Add::make(Call::make(in, input_site_1), Call::make(in, input_site_2))}, diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 9ba1f1af2019..6e8c2f2a5241 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -114,10 +114,7 @@ class BoundsOfInnerVar : public IRVisitor { } void visit(const For *op) override { - // At this stage of lowering, loop_min and loop_max - // conveniently exist in scope. - Interval in(Variable::make(Int(32), op->name + ".loop_min"), - Variable::make(Int(32), op->name + ".loop_max")); + Interval in(op->min, op->max); if (op->name == var) { result = in; @@ -1308,7 +1305,7 @@ class BoundsInference : public IRMutator { } } - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } Scope<> let_vars_in_scope; @@ -1392,7 +1389,7 @@ Stmt bounds_inference(Stmt s, s = Block::make(Evaluate::make(marker), s); // Add a synthetic outermost loop to act as 'root'. - s = For::make("", 0, 1, ForType::Serial, Partition::Never, DeviceAPI::None, s); + s = For::make("", 0, 0, ForType::Serial, Partition::Never, DeviceAPI::None, s); s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups, outputs, func_bounds, target) diff --git a/src/CanonicalizeGPUVars.cpp b/src/CanonicalizeGPUVars.cpp index 4e70af965138..7ca9b7c4fbf5 100644 --- a/src/CanonicalizeGPUVars.cpp +++ b/src/CanonicalizeGPUVars.cpp @@ -90,22 +90,10 @@ class CanonicalizeGPUVars : public IRMutator { return name; } - std::string canonicalize_let(const std::string &name) { - if (ends_with(name, ".loop_max")) { - return find_replacement(".loop_max", name); - } else if (ends_with(name, ".loop_min")) { - return find_replacement(".loop_min", name); - } else if (ends_with(name, ".loop_extent")) { - return find_replacement(".loop_extent", name); - } else { - return name; - } - } - Stmt visit(const For *op) override { std::string name = op->name; Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt body = mutate(op->body); if ((op->for_type == ForType::GPUBlock) || @@ -130,44 +118,21 @@ class CanonicalizeGPUVars : public IRMutator { gpu_vars.emplace(op->name, name); Expr new_var = Variable::make(Int(32), name); min = substitute(op->name, new_var, min); - extent = substitute(op->name, new_var, extent); + max = substitute(op->name, new_var, max); body = substitute(op->name, new_var, body); } } if ((name == op->name) && min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body)) { return op; } else { - return For::make(name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(name, min, max, op->for_type, op->partition_policy, op->device_api, body); } } - Stmt visit(const LetStmt *op) override { - vector> lets; - Stmt result; - - do { - lets.emplace_back(op->name, mutate(op->value)); - result = op->body; - } while ((op = op->body.as())); - - result = mutate(result); - - for (const auto &[var, value] : reverse_view(lets)) { - std::string name = canonicalize_let(var); - if (name != var) { - Expr new_var = Variable::make(Int(32), name); - result = substitute(var, new_var, result); - } - result = LetStmt::make(name, value, result); - } - - return result; - } - Stmt visit(const IfThenElse *op) override { Expr condition = mutate(op->condition); diff --git a/src/Closure.cpp b/src/Closure.cpp index 5c5125a9b291..de1564526462 100644 --- a/src/Closure.cpp +++ b/src/Closure.cpp @@ -38,7 +38,7 @@ void Closure::visit(const LetStmt *op) { void Closure::visit(const For *op) { ScopedBinding<> p(ignore, op->name); op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); op->body.accept(this); } diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 44756204745a..a5dc3298be63 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2223,7 +2223,7 @@ void CodeGen_C::visit(const Atomic *op) { void CodeGen_C::visit(const For *op) { string id_min = print_expr(op->min); - string id_extent = print_expr(op->extent); + string id_max = print_expr(op->max); if (op->for_type == ForType::Parallel) { stream << get_indent() << "#pragma omp parallel for\n"; @@ -2237,8 +2237,7 @@ void CodeGen_C::visit(const For *op) { << " = " << id_min << "; " << print_name(op->name) - << " < " << id_min - << " + " << id_extent + << " <= " << id_max << "; " << print_name(op->name) << "++)\n"; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index ad4f6451f918..4ce641e680ad 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -1179,7 +1179,8 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, // generation time, emit code such that it can be patched some point // later when calling D3DCompile() / halide_d3d12compute_run() numthreads[index] = 0; // <-- 0 indicates 'undetermined' - const IntImm *int_limit = loop->extent.as(); + Expr extent = simplify(loop->extent()); + const IntImm *int_limit = extent.as(); if (nullptr != int_limit) { numthreads[index] = int_limit->value; user_assert(numthreads[index] > 0) << "For D3D12Compute, 'numthreads[" << index << "]' values must be greater than zero.\n"; diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 7bf29bd15aea..05b68447b6a4 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -363,7 +363,7 @@ class InjectHVXLocks : public IRMutator { if (uses_hvx) { body = acquire_hvx_context(body, target); body = substitute("uses_hvx", true, body); - Stmt new_for = For::make(op->name, op->min, op->extent, op->for_type, + Stmt new_for = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); Stmt prolog = IfThenElse::make(uses_hvx_var, call_halide_qurt_hvx_unlock()); @@ -408,7 +408,7 @@ class InjectHVXLocks : public IRMutator { // vector code // halide_qurt_unlock // } - s = For::make(op->name, op->min, op->extent, op->for_type, + s = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 7764c3d61264..2d33dd525039 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3722,7 +3722,7 @@ void CodeGen_LLVM::visit(const ProducerConsumer *op) { void CodeGen_LLVM::visit(const For *op) { Value *min = codegen(op->min); - Value *extent = codegen(op->extent); + Value *max = codegen(op->max); const Acquire *acquire = op->body.as(); // TODO(zvookin): remove this after validating it doesn't happen @@ -3733,8 +3733,6 @@ void CodeGen_LLVM::visit(const For *op) { if (op->for_type == ForType::Serial) { - Value *max = builder->CreateNSWAdd(min, extent); - BasicBlock *preheader_bb = builder->GetInsertBlock(); // Make a new basic block for the loop @@ -3745,8 +3743,8 @@ void CodeGen_LLVM::visit(const For *op) { BasicBlock *after_bb = BasicBlock::Create( *context, std::to_string(for_loop_id) + std::string("_end_for_") + op->name, function); - // If min < max, fall through to the loop bb - Value *enter_condition = builder->CreateICmpSLT(min, max); + // If min <= max, fall through to the loop bb + Value *enter_condition = builder->CreateICmpSLE(min, max); builder->CreateCondBr(enter_condition, loop_bb, after_bb, very_likely_branch); builder->SetInsertPoint(loop_bb); @@ -3767,7 +3765,7 @@ void CodeGen_LLVM::visit(const For *op) { phi->addIncoming(next_var, builder->GetInsertBlock()); // Maybe exit the loop - Value *end_condition = builder->CreateICmpNE(next_var, max); + Value *end_condition = builder->CreateICmpSLE(next_var, max); builder->CreateCondBr(end_condition, loop_bb, after_bb); builder->SetInsertPoint(after_bb); diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index de2c9cd0dd13..671f923ec183 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -418,7 +418,8 @@ struct FindWorkGroupSize : public IRVisitor { // Save & validate the workgroup size int index = thread_loop_workgroup_index(loop->name); if (index >= 0) { - const IntImm *literal = loop->extent.as(); + Expr extent = simplify(loop->extent()); + const IntImm *literal = extent.as(); if (literal != nullptr) { uint32_t new_wg_size = literal->value; user_assert(workgroup_size[index] == 0 || workgroup_size[index] == new_wg_size) @@ -1683,7 +1684,7 @@ std::pair simt_intrinsic(const std::string &name) { } // anonymous namespace void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { - debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(For): name=" << op->name << " min=" << op->min << " extent=" << op->extent << "\n"; + debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(For): name=" << op->name << " min=" << op->min << " max=" << op->max << "\n"; if (is_gpu(op->for_type)) { // This should always be true at this point in codegen @@ -1710,24 +1711,22 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { } } else { - debug(2) << " (serial for loop): min=" << op->min << " extent=" << op->extent << "\n"; + debug(2) << " (serial for loop): min=" << op->min << " max=" << op->max << "\n"; internal_assert(op->for_type == ForType::Serial) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit unhandled For type: " << op->for_type << "\n"; - user_assert(op->min.type() == op->extent.type()); + user_assert(op->min.type() == op->max.type()); user_assert(op->min.type().is_int() || op->min.type().is_uint()); op->min.accept(this); SpvId min_id = builder.current_id(); - op->extent.accept(this); - SpvId extent_id = builder.current_id(); + op->max.accept(this); + SpvId max_id = builder.current_id(); // Compute max. Type index_type = op->min.type(); SpvId index_type_id = builder.declare_type(index_type); SpvStorageClass storage_class = SpvStorageClassFunction; SpvId index_var_type_id = builder.declare_pointer_type(index_type_id, storage_class); - SpvId max_id = builder.reserve_id(SpvResultId); - builder.append(SpvFactory::integer_add(index_type_id, max_id, min_id, extent_id)); // Declare loop var const std::string loop_var_name = unique_name(std::string("k") + std::to_string(kernel_index) + "_loop_idx"); @@ -1757,7 +1756,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { SpvId loop_test_type_id = builder.declare_type(Bool()); SpvId loop_test_id = builder.reserve_id(SpvResultId); builder.append(SpvFactory::load(index_type_id, loop_index_id, loop_var_id)); - builder.append(SpvFactory::integer_less_than(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_int())); + builder.append(SpvFactory::integer_less_than_equal(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_int())); builder.append(SpvFactory::conditional_branch(loop_test_id, body_block_id, merge_block_id)); } builder.leave_block(); diff --git a/src/CodeGen_WebGPU_Dev.cpp b/src/CodeGen_WebGPU_Dev.cpp index ea43b3cefbd5..d0a9310856ca 100644 --- a/src/CodeGen_WebGPU_Dev.cpp +++ b/src/CodeGen_WebGPU_Dev.cpp @@ -663,11 +663,11 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const For *loop) { << "Can only use serial loops inside WebGPU shaders\n"; string id_min = print_expr(loop->min); - string id_extent = print_expr(loop->extent); + string id_max = print_expr(loop->max); string id_counter = print_name(loop->name); stream << get_indent() << "for (var " << id_counter << " = " << id_min << "; " - << id_counter << " < " << id_min << " + " << id_extent << "; " + << id_counter << " <= " << id_max << "; " // TODO: Use increment statement when supported by Chromium. << id_counter << " = " << id_counter << " + 1)\n"; open_scope(); diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index b6f49cb1bf43..1e3ee74f9f65 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -539,12 +539,12 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt) const auto *for_stmt = (const Serialize::For *)stmt; const auto name = deserialize_string(for_stmt->name()); const auto min = deserialize_expr(for_stmt->min_type(), for_stmt->min()); - const auto extent = deserialize_expr(for_stmt->extent_type(), for_stmt->extent()); + const auto max = deserialize_expr(for_stmt->max_type(), for_stmt->max()); const ForType for_type = deserialize_for_type(for_stmt->for_type()); const Partition partition_policy = deserialize_partition(for_stmt->partition_policy()); const DeviceAPI device_api = deserialize_device_api(for_stmt->device_api()); const auto body = deserialize_stmt(for_stmt->body_type(), for_stmt->body()); - return For::make(name, min, extent, for_type, partition_policy, device_api, body); + return For::make(name, min, max, for_type, partition_policy, device_api, body); } case Serialize::Stmt::Store: { const auto *store_stmt = (const Serialize::Store *)stmt; diff --git a/src/EarlyFree.cpp b/src/EarlyFree.cpp index 35de3c15cbcd..8b664c2bcf8d 100644 --- a/src/EarlyFree.cpp +++ b/src/EarlyFree.cpp @@ -30,7 +30,7 @@ class FindLastUse : public IRVisitor { void visit(const For *loop) override { loop->min.accept(this); - loop->extent.accept(this); + loop->max.accept(this); ScopedValue old_in_loop(in_loop, true); loop->body.accept(this); } diff --git a/src/Func.cpp b/src/Func.cpp index cf8904ec2d31..696b4353dc2b 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -708,15 +708,14 @@ pair project_rdom(const vector &dims, con for (const auto &[var, min, extent] : rdom.domain()) { add_let(bounds_projection, var + ".loop_min", min); add_let(bounds_projection, var + ".loop_max", min + extent - 1); - add_let(bounds_projection, var + ".loop_extent", extent); } // Build the new RDom from the bounds_projection. vector new_rvars; for (const Dim &dim : dims) { const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min")); - const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent")); - new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, new_extent}); + const Expr new_max = simplify(bounds_projection.at(dim.var + ".loop_max")); + new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, (new_max - new_min) + 1}); } ReductionDomain new_rdom{new_rvars}; new_rdom.where(rdom.predicate()); diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 1b7506d96c9b..88f9a542550f 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -39,7 +39,7 @@ class ExtractBlockSize : public IRVisitor { void found_thread_for(int dim, const string &name, const Expr &extent) { internal_assert(dim >= 0 && dim < 3); if (!block_extent[dim].defined()) { - block_extent[dim] = extent; + block_extent[dim] = simplify(extent); } else { block_extent[dim] = simplify(Max::make(extent, block_extent[dim])); } @@ -55,16 +55,16 @@ class ExtractBlockSize : public IRVisitor { void visit(const For *op) override { for (int i = 0; i < 3; i++) { if (ends_with(op->name, gpu_thread_name(i))) { - found_thread_for(i, op->name, op->extent); + found_thread_for(i, op->name, op->extent()); } else if (ends_with(op->name, gpu_block_name(i))) { - found_block_for(i, op->name, op->extent); + found_block_for(i, op->name, op->extent()); } } IRVisitor::visit(op); Scope scope; - scope.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + scope.push(op->name, Interval(op->min, op->max)); // For non-rectangular thread loops, use a bounding box. We'll inject if statements later. for (Expr &e : block_extent) { if (e.defined() && expr_uses_var(e, op->name)) { @@ -141,7 +141,7 @@ class NormalizeDimensionality : public IRMutator { return s; } while (max_depth < block_size.threads_dimensions()) { - s = For::make(gpu_thread_name(max_depth), 0, 1, ForType::GPUThread, + s = For::make(gpu_thread_name(max_depth), 0, 0, ForType::GPUThread, Partition::Never, device_api, s); max_depth++; } @@ -205,10 +205,10 @@ class ReplaceForWithIf : public IRMutator { Expr var = Variable::make(Int(32), gpu_thread_name(dim)); body = substitute(op->name, var + op->min, body); - if (equal(op->extent, block_size.num_threads(dim))) { + if (can_prove(op->extent() == block_size.num_threads(dim))) { return body; } else { - Expr cond = var < op->extent; + Expr cond = var <= op->max; return IfThenElse::make(cond, body, Stmt()); } } else { @@ -340,7 +340,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { // Expand any new shared allocations found in the body using the loop bounds. Scope scope; - scope.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + scope.push(op->name, Interval(op->min, op->max)); for (SharedAllocation &s : allocations) { // If the size depends on the loop variable, take the max // over all loop iterations @@ -366,7 +366,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { precompute_allocation_size(s); break; case Monotonic::Increasing: - s.size = substitute(op->name, simplify(op->min + op->extent - 1), s.size); + s.size = substitute(op->name, op->max, s.size); break; case Monotonic::Constant: // The size expression used the variable, but we @@ -381,7 +381,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { } if (in_threads && op->is_parallel()) { // For parallel inner loops, make a separate slice per loop iteration - s.size *= op->extent; + s.size *= op->extent(); } } @@ -393,13 +393,13 @@ class ExtractSharedAndHeapAllocations : public IRMutator { } Expr new_min = mutate(op->min); - Expr new_extent = mutate(op->extent); + Expr new_max = mutate(op->max); if (host_side_preamble.defined()) { string loop_name = unique_name('t'); Expr v = Variable::make(Int(32), loop_name); host_side_preamble = substitute(op->name, v, host_side_preamble); - host_side_preamble = For::make(loop_name, new_min, new_extent, + host_side_preamble = For::make(loop_name, new_min, new_max, ForType::Serial, Partition::Never, DeviceAPI::None, host_side_preamble); if (old_preamble.defined()) { host_side_preamble = Block::make(old_preamble, host_side_preamble); @@ -408,7 +408,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { host_side_preamble = old_preamble; } - return For::make(op->name, new_min, new_extent, + return For::make(op->name, new_min, new_max, op->for_type, op->partition_policy, op->device_api, body); } @@ -1082,7 +1082,7 @@ class ExtractRegisterAllocations : public IRMutator { // Expand any new register allocations found in the body using the loop bounds. Scope scope; - scope.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + scope.push(op->name, Interval(op->min, op->max)); // Expand the inner allocations using the loop bounds. for (RegisterAllocation &s : allocations) { @@ -1098,7 +1098,7 @@ class ExtractRegisterAllocations : public IRMutator { allocations.swap(old); } - return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, mutate(op->min), mutate(op->max), op->for_type, op->partition_policy, op->device_api, body); } } @@ -1258,7 +1258,7 @@ class InjectThreadBarriers : public IRMutator { // synchronizations within the block body = Block::make(body, make_barrier(0)); } - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } else { return IRMutator::visit(op); @@ -1410,13 +1410,13 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { string thread_id = gpu_thread_name(0); // Add back in any register-level allocations body = register_allocs.rewrap(body, thread_id); - body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->partition_policy, op->device_api, body); + body = For::make(thread_id, 0, block_size_x - 1, innermost_loop_type, op->partition_policy, op->device_api, body); // Rewrap the whole thing in other loops over threads for (int i = 1; i < block_size.threads_dimensions(); i++) { thread_id = gpu_thread_name(i); body = register_allocs.rewrap(body, thread_id); - body = For::make(thread_id, 0, block_size.num_threads(i), + body = For::make(thread_id, 0, block_size.num_threads(i) - 1, ForType::GPUThread, op->partition_policy, op->device_api, body); } thread_id.clear(); @@ -1433,7 +1433,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } else { return IRMutator::visit(op); @@ -1503,7 +1503,7 @@ class ZeroGPULoopMins : public IRMutator { internal_assert(op); Expr adjusted = Variable::make(Int(32), op->name) + op->min; Stmt body = substitute(op->name, adjusted, op->body); - stmt = For::make(op->name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, 0, simplify(op->max - op->min), op->for_type, op->partition_policy, op->device_api, body); } return stmt; } @@ -1547,7 +1547,7 @@ class AddConditionToALoop : public IRMutator { return IRMutator::visit(op); } - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, IfThenElse::make(condition, op->body, Stmt())); } diff --git a/src/HexagonOffload.cpp b/src/HexagonOffload.cpp index a7a305b5902e..e540a9697513 100644 --- a/src/HexagonOffload.cpp +++ b/src/HexagonOffload.cpp @@ -4,6 +4,7 @@ #include "Closure.h" #include "Elf.h" #include "HexagonOffload.h" +#include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" #include "InjectHostDevBufferCopies.h" @@ -745,10 +746,10 @@ class InjectHexagonRpc : public IRMutator { // After moving this to Hexagon, it doesn't need to be marked // Hexagon anymore. Stmt body; - if (is_const_one(loop->extent)) { + if (equal(loop->min, loop->max)) { body = LetStmt::make(loop->name, loop->min, loop->body); } else { - body = For::make(loop->name, loop->min, loop->extent, loop->for_type, loop->partition_policy, + body = For::make(loop->name, loop->min, loop->max, loop->for_type, loop->partition_policy, DeviceAPI::None, loop->body); } diff --git a/src/IR.cpp b/src/IR.cpp index e17818a8cc9e..c844c672656a 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -342,20 +342,20 @@ Stmt ProducerConsumer::make_consume(const std::string &name, Stmt body) { } Stmt For::make(const std::string &name, - Expr min, Expr extent, + Expr min, Expr max, ForType for_type, Partition partition_policy, DeviceAPI device_api, Stmt body) { internal_assert(min.defined()) << "For of undefined\n"; - internal_assert(extent.defined()) << "For of undefined\n"; + internal_assert(max.defined()) << "For of undefined\n"; internal_assert(min.type() == Int(32)) << "For with non-integer min\n"; - internal_assert(extent.type() == Int(32)) << "For with non-integer extent\n"; + internal_assert(max.type() == Int(32)) << "For with non-integer max\n"; internal_assert(body.defined()) << "For of undefined\n"; For *node = new For; node->name = name; node->min = std::move(min); - node->extent = std::move(extent); + node->max = std::move(max); node->for_type = for_type; node->partition_policy = partition_policy; node->device_api = device_api; diff --git a/src/IR.h b/src/IR.h index bdf42a75f7b1..53ae316a404b 100644 --- a/src/IR.h +++ b/src/IR.h @@ -833,28 +833,26 @@ struct Variable : public ExprNode { static const IRNodeType _node_type = IRNodeType::Variable; }; -/** A for loop. Execute the 'body' statement for all values of the - * variable 'name' from 'min' to 'min + extent'. There are four - * types of For nodes. A 'Serial' for loop is a conventional - * one. In a 'Parallel' for loop, each iteration of the loop - * happens in parallel or in some unspecified order. In a - * 'Vectorized' for loop, each iteration maps to one SIMD lane, - * and the whole loop is executed in one shot. For this case, - * 'extent' must be some small integer constant (probably 4, 8, or - * 16). An 'Unrolled' for loop compiles to a completely unrolled - * version of the loop. Each iteration becomes its own - * statement. Again in this case, 'extent' should be a small - * integer constant. */ +/** A for loop. Execute the 'body' statement for all values of the variable + * 'name' from 'min' to 'max' inclusive. There are four types of For nodes. A + * 'Serial' for loop is a conventional one. In a 'Parallel' for loop, each + * iteration of the loop happens in parallel or in some unspecified order. In a + * 'Vectorized' for loop, each iteration maps to one SIMD lane, and the whole + * loop is executed in one shot. For this case, the extent (max - min + 1) must + * be some small integer constant (probably 4, 8, or 16). An 'Unrolled' for loop + * compiles to a completely unrolled version of the loop. Each iteration becomes + * its own statement. Again in this case, the extent should be a small integer + * constant. */ struct For : public StmtNode { std::string name; - Expr min, extent; + Expr min, max; ForType for_type; DeviceAPI device_api; Stmt body; Partition partition_policy; static Stmt make(const std::string &name, - Expr min, Expr extent, + Expr min, Expr max, ForType for_type, Partition partition_policy, DeviceAPI device_api, Stmt body); @@ -866,6 +864,10 @@ struct For : public StmtNode { return Halide::Internal::is_parallel(for_type); } + Expr extent() const { + return Add::make(Sub::make(max, min), 1); + } + static const IRNodeType _node_type = IRNodeType::For; }; diff --git a/src/IREquality.cpp b/src/IREquality.cpp index 91863370d3c0..1e7ae422c549 100644 --- a/src/IREquality.cpp +++ b/src/IREquality.cpp @@ -426,7 +426,7 @@ struct Comparer { cmp(&For::device_api); cmp(&For::partition_policy); cmp(&For::min); - cmp(&For::extent); + cmp(&For::max); cmp(&For::body); break; case IRNodeType::Acquire: diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index f0b861f1d5c0..9eecd0579840 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -202,14 +202,14 @@ Stmt IRMutator::visit(const ProducerConsumer *op) { Stmt IRMutator::visit(const For *op) { Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt body = mutate(op->body); if (min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body)) { return op; } - return For::make(op->name, std::move(min), std::move(extent), + return For::make(op->name, std::move(min), std::move(max), op->for_type, op->partition_policy, op->device_api, std::move(body)); } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 97f6a409d6c9..e95286af03ee 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -1140,7 +1140,7 @@ void IRPrinter::visit(const For *op) { stream << var(op->name) << paren(", "); print_no_parens(op->min); stream << paren(", "); - print_no_parens(op->extent); + print_no_parens(op->max); closef(); stream << " "; diff --git a/src/IRVisitor.cpp b/src/IRVisitor.cpp index 9a5e6a8e0537..25fd7e608f27 100644 --- a/src/IRVisitor.cpp +++ b/src/IRVisitor.cpp @@ -167,7 +167,7 @@ void IRVisitor::visit(const ProducerConsumer *op) { void IRVisitor::visit(const For *op) { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); op->body.accept(this); } @@ -443,7 +443,7 @@ void IRGraphVisitor::visit(const ProducerConsumer *op) { void IRGraphVisitor::visit(const For *op) { include(op->min); - include(op->extent); + include(op->max); include(op->body); } diff --git a/src/LICM.cpp b/src/LICM.cpp index c73fcf5424ab..31f079175fad 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -318,7 +318,7 @@ class LICM : public IRMutator { const For *loop = new_stmt.as(); internal_assert(loop); - new_stmt = For::make(loop->name, loop->min, loop->extent, + new_stmt = For::make(loop->name, loop->min, loop->max, loop->for_type, loop->partition_policy, loop->device_api, mutate(loop->body)); // Wrap lets for the lifted invariants @@ -563,7 +563,7 @@ class HoistIfStatements : public IRMutator { if (!i->else_case.defined() && is_pure(i->condition) && !expr_uses_var(i->condition, op->name)) { - Stmt s = For::make(op->name, op->min, op->extent, + Stmt s = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, i->then_case); return IfThenElse::make(i->condition, s); } @@ -571,7 +571,7 @@ class HoistIfStatements : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 5349e9c316f9..fccc049ce5b6 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -550,7 +550,7 @@ class LoopCarry : public IRMutator { } Stmt visit(const For *op) override { - if (op->for_type == ForType::Serial && !is_const_one(op->extent)) { + if (op->for_type == ForType::Serial && !equal(op->min, op->max)) { Stmt stmt; Stmt body = mutate(op->body); LoopCarryOverLoop carry(op->name, in_consume, max_carried_values); @@ -558,7 +558,7 @@ class LoopCarry : public IRMutator { if (body.same_as(op->body)) { stmt = op; } else { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } // Inject the scratch buffer allocations. @@ -567,7 +567,7 @@ class LoopCarry : public IRMutator { stmt = Allocate::make(alloc.name, alloc.type, MemoryType::Stack, {alloc.size}, const_true(), stmt); } if (!carry.allocs.empty()) { - stmt = IfThenElse::make(op->extent > 0, stmt); + stmt = IfThenElse::make(op->min <= op->max, stmt); } return stmt; } else { diff --git a/src/Lower.cpp b/src/Lower.cpp index c6db1adfa33c..dae55055be6f 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -236,7 +236,7 @@ void lower_impl(const vector &output_funcs, log("Lowering after uniquifying variable names:", s); debug(1) << "Simplifying...\n"; - s = simplify(s, false); // Storage folding and allocation bounds inference needs .loop_max symbols + s = simplify(s); log("Lowering after first simplification:", s); debug(1) << "Simplifying correlated differences...\n"; diff --git a/src/LowerParallelTasks.cpp b/src/LowerParallelTasks.cpp index d6ed27ca0905..62e909136841 100644 --- a/src/LowerParallelTasks.cpp +++ b/src/LowerParallelTasks.cpp @@ -270,9 +270,11 @@ struct LowerParallelTasks : public IRMutator { std::string loop_min_name = unique_name('t'); std::string loop_extent_name = unique_name('t'); if (!t.loop_var.empty()) { + Expr min = Variable::make(Int(32), loop_min_name); + Expr extent = Variable::make(Int(32), loop_extent_name); t.body = For::make(t.loop_var, - Variable::make(Int(32), loop_min_name), - Variable::make(Int(32), loop_extent_name), + min, + min + extent - 1, ForType::Serial, t.partition_policy, DeviceAPI::None, @@ -380,7 +382,7 @@ struct LowerParallelTasks : public IRMutator { result.emplace_back(std::move(t)); } else if (loop && loop->for_type == ForType::Parallel) { add_suffix(prefix, ".par_for." + loop->name); - ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_false(), task_debug_name(prefix), loop->partition_policy}; + ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent(), const_false(), task_debug_name(prefix), loop->partition_policy}; result.emplace_back(std::move(t)); } else if (loop && loop->for_type == ForType::Serial && @@ -389,7 +391,7 @@ struct LowerParallelTasks : public IRMutator { const Variable *v = acquire->semaphore.as(); internal_assert(v); add_suffix(prefix, ".for." + v->name); - ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_true(), task_debug_name(prefix), loop->partition_policy}; + ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent(), const_true(), task_debug_name(prefix), loop->partition_policy}; while (acquire) { t.semaphores.push_back({acquire->semaphore, acquire->count}); t.body = acquire->body; diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index 2551fe0bffbb..e45f99f37e14 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -234,11 +234,11 @@ class DetermineAllocStride : public IRVisitor { void visit(const For *op) override { ScopedBinding - bind_bounds_if(is_const(op->min) && is_const(op->extent), - bounds, op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + bind_bounds_if(is_const(op->min) && is_const(op->max), + bounds, op->name, Interval(op->min, op->max)); ScopedBinding bound_dependent_if((expr_uses_vars(op->min, dependent_vars) || - expr_uses_vars(op->extent, dependent_vars)), + expr_uses_vars(op->max, dependent_vars)), dependent_vars, op->name, Expr()); IRVisitor::visit(op); } @@ -372,16 +372,17 @@ class LowerWarpShuffles : public IRMutator { Stmt visit(const For *op) override { ScopedBinding - bind_if(is_const(op->min) && is_const(op->extent), - bounds, op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + bind_if(is_const(op->min) && is_const(op->max), + bounds, op->name, Interval(op->min, op->max)); if (!this_lane.defined() && op->for_type == ForType::GPULane) { bool should_mask = false; ScopedValue old_warp_size(warp_size); + Expr extent = simplify(op->extent()); if (op->for_type == ForType::GPULane) { - auto loop_size = as_const_int(op->extent); + auto loop_size = as_const_int(extent); user_assert(loop_size && *loop_size <= 32) - << "CUDA gpu lanes loop must have constant extent of at most 32: " << op->extent << "\n"; + << "CUDA gpu lanes loop must have constant extent of at most 32: " << extent << "\n"; // Select a warp size - the smallest power of two that contains the loop size int64_t ws = 1; @@ -391,7 +392,7 @@ class LowerWarpShuffles : public IRMutator { should_mask = (ws != *loop_size); warp_size = make_const(Int(32), ws); } else { - warp_size = op->extent; + warp_size = extent; } this_lane_name = op->name; this_lane = Variable::make(Int(32), op->name); @@ -408,7 +409,8 @@ class LowerWarpShuffles : public IRMutator { // with storage striped across the warp lanes, so the // size required per-lane is the old size divided by // the number of lanes (rounded up). - Expr new_size = (alloc->extents[0] + op->extent - 1) / op->extent; + Expr extent = op->extent(); + Expr new_size = (alloc->extents[0] + extent - 1) / extent; new_size = simplify(new_size, true, bounds); new_size = find_constant_bound(new_size, Direction::Upper, bounds); auto sz = as_const_int(new_size); @@ -423,7 +425,7 @@ class LowerWarpShuffles : public IRMutator { if (should_mask) { // Mask off the excess lanes in the warp - body = IfThenElse::make(this_lane < op->extent, body, Stmt()); + body = IfThenElse::make(this_lane <= op->max, body, Stmt()); } // Wrap the hoisted warp-level allocations, at their new @@ -455,7 +457,7 @@ class LowerWarpShuffles : public IRMutator { } allocations.clear(); - return For::make(op->name, op->min, warp_size, + return For::make(op->name, op->min, op->min + warp_size - 1, op->for_type, op->partition_policy, op->device_api, body); } else { return IRMutator::visit(op); @@ -732,7 +734,7 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator { } else { debug(3) << "Successfully hoisted shuffle out of for loop\n"; } - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } Stmt visit(const Store *op) override { diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index 4a33c8f1bc00..e93f67e8bfef 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -15,6 +15,7 @@ #include "IRPrinter.h" #include "InjectHostDevBufferCopies.h" #include "OffloadGPULoops.h" +#include "Simplify.h" #include "Util.h" namespace Halide { @@ -55,10 +56,10 @@ class ExtractBounds : public IRVisitor { for (int i = 0; i < 3; i++) { if (ends_with(op->name, gpu_thread_name(i))) { - num_threads[i] = op->extent; + num_threads[i] = simplify(op->extent()); } if (ends_with(op->name, gpu_block_name(i))) { - num_blocks[i] = op->extent; + num_blocks[i] = simplify(op->extent()); } } diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index cc9bda77d586..1a8e001f0e1c 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -406,7 +406,7 @@ class FindSimplifications : public IRVisitor { for (Simplification &s : simplifications) { if (expr_uses_var(s.condition, op->name)) { Scope varying; - varying.push(op->name, Interval(op->min, op->min + op->extent - 1)); + varying.push(op->name, Interval(op->min, op->max)); Expr relaxed = and_condition_over_domain(s.condition, varying); internal_assert(!expr_uses_var(relaxed, op->name)) << "Should not have had used the loop var (" << op->name @@ -707,8 +707,8 @@ class PartitionLoops : public IRMutator { Stmt prologue = MakeSimplifications(prologue_simps).mutate(body); Stmt epilogue = MakeSimplifications(epilogue_simps).mutate(body); - bool make_prologue = !equal(prologue, simpler_body); - bool make_epilogue = !equal(epilogue, simpler_body); + const bool make_prologue = !equal(prologue, simpler_body); + const bool make_epilogue = !equal(epilogue, simpler_body); // Recurse on the middle section. simpler_body = mutate(simpler_body); @@ -721,10 +721,11 @@ class PartitionLoops : public IRMutator { } // Construct variables for the bounds of the simplified middle section - Expr min_steady = op->min, max_steady = op->extent + op->min; + const Expr original_max_plus_one = op->max + 1; + Expr min_steady = op->min, max_steady = original_max_plus_one; Expr prologue_val, epilogue_val; - string prologue_name = unique_name(op->name + ".prologue"); - string epilogue_name = unique_name(op->name + ".epilogue"); + const string prologue_name = unique_name(op->name + ".prologue"); + const string epilogue_name = unique_name(op->name + ".epilogue"); if (make_prologue) { // They'll simplify better if you put them in @@ -735,7 +736,7 @@ class PartitionLoops : public IRMutator { min_vals.push_back(op->min); prologue_val = fold_left(min_vals, Max::make); // Stop the prologue from running past the end of the loop - prologue_val = min(prologue_val, op->extent + op->min); + prologue_val = min(prologue_val, original_max_plus_one); // prologue_val = print(prologue_val, prologue_name); min_steady = Variable::make(Int(32), prologue_name); @@ -743,7 +744,7 @@ class PartitionLoops : public IRMutator { } if (make_epilogue) { std::sort(max_vals.begin(), max_vals.end(), IRDeepCompare()); - max_vals.push_back(op->min + op->extent - 1); + max_vals.push_back(op->max); epilogue_val = fold_left(max_vals, Min::make) + 1; // Stop the epilogue from running before the start of the loop/prologue if (make_prologue) { @@ -760,17 +761,17 @@ class PartitionLoops : public IRMutator { Stmt stmt; // Bust simple serial for loops up into three. if (op->for_type == ForType::Serial && !op->body.as()) { - stmt = For::make(op->name, min_steady, max_steady - min_steady, + stmt = For::make(op->name, min_steady, max_steady - 1, op->for_type, op->partition_policy, op->device_api, simpler_body); if (make_prologue) { - prologue = For::make(op->name, op->min, min_steady - op->min, + prologue = For::make(op->name, op->min, min_steady - 1, op->for_type, op->partition_policy, op->device_api, prologue); stmt = Block::make(prologue, stmt); mutated = true; } if (make_epilogue) { - epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady, + epilogue = For::make(op->name, max_steady, op->max, op->for_type, op->partition_policy, op->device_api, epilogue); stmt = Block::make(stmt, epilogue); mutated = true; @@ -803,19 +804,19 @@ class PartitionLoops : public IRMutator { mutated = true; } } - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, stmt); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, stmt); } if (make_epilogue) { // Uncomment to include code that prints the epilogue value - // epilogue_val = print(epilogue_val, op->name, "epilogue"); + // epilogue_val = print(epilogue_val, op->name, "epilogue", op->min, op->max); stmt = LetStmt::make(epilogue_name, epilogue_val, stmt); } else { - epilogue_val = op->min + op->extent; + epilogue_val = original_max_plus_one; } if (make_prologue) { // Uncomment to include code that prints the prologue value - // prologue_val = print(prologue_val, op->name, "prologue"); + // prologue_val = print(prologue_val, op->name, "prologue", op->min, op->max); stmt = LetStmt::make(prologue_name, prologue_val, stmt); } else { prologue_val = op->min; @@ -924,9 +925,9 @@ class RenormalizeGPULoops : public IRMutator { // Move lets in-between gpu loop levels inwards. if (f && in_gpu_loop && !in_thread_loop) { internal_assert(!expr_uses_var(f->min, op->name) && - !expr_uses_var(f->extent, op->name)); + !expr_uses_var(f->max, op->name)); Stmt inner = LetStmt::make(op->name, op->value, f->body); - inner = For::make(f->name, f->min, f->extent, f->for_type, f->partition_policy, f->device_api, inner); + inner = For::make(f->name, f->min, f->max, f->for_type, f->partition_policy, f->device_api, inner); return mutate(inner); } else if (a && in_gpu_loop && !in_thread_loop) { internal_assert(a->extents.size() == 1); @@ -1002,9 +1003,9 @@ class RenormalizeGPULoops : public IRMutator { } else if (for_a && for_b && for_a->name == for_b->name && for_a->min.same_as(for_b->min) && - for_a->extent.same_as(for_b->extent)) { + for_a->max.same_as(for_b->max)) { Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body); - inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->partition_policy, for_a->device_api, inner); + inner = For::make(for_a->name, for_a->min, for_a->max, for_a->for_type, for_a->partition_policy, for_a->device_api, inner); return mutate(inner); } else { internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n"; diff --git a/src/Prefetch.cpp b/src/Prefetch.cpp index a34e95f3c530..c0eedf50b817 100644 --- a/src/Prefetch.cpp +++ b/src/Prefetch.cpp @@ -246,7 +246,7 @@ class InjectPlaceholderPrefetch : public IRMutator { Stmt stmt; if (!body.same_as(op->body)) { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, std::move(body)); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, std::move(body)); } else { stmt = op; } @@ -300,7 +300,7 @@ class ReducePrefetchDimension : public IRMutator { stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic)); for (size_t i = 0; i < index_names.size(); ++i) { - stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 2], + stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 2] - 1, ForType::Serial, Partition::Auto, DeviceAPI::None, stmt); } debug(5) << "\nReduce prefetch to " << max_dim << " dim:\n" @@ -371,7 +371,7 @@ class SplitPrefetch : public IRMutator { vector args = {base, std::move(new_offset), std::move(new_extent), std::move(new_stride)}; stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic)); for (size_t i = 0; i < index_names.size(); ++i) { - stmt = For::make(index_names[i], 0, extents[i], + stmt = For::make(index_names[i], 0, extents[i] - 1, ForType::Serial, Partition::Auto, DeviceAPI::None, stmt); } debug(5) << "\nSplit prefetch to max of " << max_byte_size << " bytes:\n" diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index 9d38efaaf80a..1c76965ac2be 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -91,24 +91,23 @@ class PrintLoopNest : public IRVisitor { // If the min or extent are constants, print them. At this // stage they're all variables. - Expr min_val = op->min, extent_val = op->extent; + Expr min_val = op->min, max_val = op->max; const Variable *min_var = min_val.as(); - const Variable *extent_var = extent_val.as(); + const Variable *max_var = max_val.as(); if (min_var) { if (const Expr *e = constants.find(min_var->name)) { min_val = *e; } } - if (extent_var) { - if (const Expr *e = constants.find(extent_var->name)) { - extent_val = *e; + if (max_var) { + if (const Expr *e = constants.find(max_var->name)) { + max_val = *e; } } - if (extent_val.defined() && is_const(extent_val) && + if (max_val.defined() && is_const(max_val) && min_val.defined() && is_const(min_val)) { - Expr max_val = simplify(min_val + extent_val - 1); out << " in [" << min_val << ", " << max_val << "]"; } diff --git a/src/Profiling.cpp b/src/Profiling.cpp index 3c7b5d6e0090..c8054e83544e 100644 --- a/src/Profiling.cpp +++ b/src/Profiling.cpp @@ -502,7 +502,7 @@ class InjectProfiling : public IRMutator { most_recently_set_func = -1; } - Stmt stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + Stmt stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); if (update_active_threads) { if (Internal::is_gpu(op->for_type)) { diff --git a/src/RebaseLoopsToZero.cpp b/src/RebaseLoopsToZero.cpp index d20c1e42ce3a..49f97126bb93 100644 --- a/src/RebaseLoopsToZero.cpp +++ b/src/RebaseLoopsToZero.cpp @@ -31,7 +31,6 @@ class RebaseLoopsToZero : public IRMutator { Stmt body = mutate(op->body); string name = op->name; if (!is_const_zero(op->min)) { - // Renaming the loop (intentionally) invalidates any .loop_min/.loop_max lets. name = op->name + ".rebased"; Expr loop_var = Variable::make(Int(32), name); body = LetStmt::make(op->name, loop_var + op->min, body); @@ -39,7 +38,7 @@ class RebaseLoopsToZero : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(name, 0, op->max - op->min, op->for_type, op->partition_policy, op->device_api, body); } } }; diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index c6036b99ec5a..9667aafe891a 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -355,8 +355,8 @@ class RemoveUndef : public IRMutator { if (!min.defined()) { return Stmt(); } - Expr extent = mutate(op->extent); - if (!extent.defined()) { + Expr max = mutate(op->max); + if (!max.defined()) { return Stmt(); } Stmt body = mutate(op->body); @@ -364,11 +364,11 @@ class RemoveUndef : public IRMutator { return Stmt(); } if (min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body)) { return op; } else { - return For::make(op->name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, min, max, op->for_type, op->partition_policy, op->device_api, body); } } diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 19c3de055001..69001feebf92 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -418,8 +418,8 @@ Stmt build_loop_nest( internal_assert(container.type == Container::For); const Dim &dim = stage_s.dims()[container.dim_idx]; Expr min = Variable::make(Int(32), container.name + ".loop_min"); - Expr extent = Variable::make(Int(32), container.name + ".loop_extent"); - stmt = For::make(container.name, min, extent, dim.for_type, dim.partition_policy, dim.device_api, stmt); + Expr max = Variable::make(Int(32), container.name + ".loop_max"); + stmt = For::make(container.name, min, max, dim.for_type, dim.partition_policy, dim.device_api, stmt); } } @@ -437,17 +437,18 @@ Stmt build_loop_nest( string o = prefix + Var::outermost().name(); stmt = LetStmt::make(o + ".loop_min", 0, stmt); stmt = LetStmt::make(o + ".loop_max", 0, stmt); - stmt = LetStmt::make(o + ".loop_extent", 1, stmt); } - // Define the loop mins and extents in terms of the mins and maxs produced by bounds inference + // Define the loop mins and extents in terms of the mins and maxs produced + // by bounds inference. These are simple new_var = old_var lets, but we + // can't just substitute because there are shadowed copies of .min/.max and + // the loop_min and loop_max must be in terms of the .min/.max at *this* + // loop level. + for (const std::string &i : dims) { string var = prefix + i; Expr max = Variable::make(Int(32), var + ".max"); Expr min = Variable::make(Int(32), var + ".min"); // Inject instance name here? (compute instance names during lowering) - stmt = LetStmt::make(var + ".loop_extent", - (max + 1) - min, - stmt); stmt = LetStmt::make(var + ".loop_min", min, stmt); stmt = LetStmt::make(var + ".loop_max", max, stmt); } @@ -460,7 +461,6 @@ Stmt build_loop_nest( Expr rmax = Variable::make(Int(32), p + ".max"); stmt = LetStmt::make(p + ".loop_min", rmin, stmt); stmt = LetStmt::make(p + ".loop_max", rmax, stmt); - stmt = LetStmt::make(p + ".loop_extent", rmax - rmin + 1, stmt); } return stmt; @@ -983,7 +983,7 @@ class InjectStmt : public IRMutator { } else { return For::make(for_loop->name, for_loop->min, - for_loop->extent, + for_loop->max, for_loop->for_type, for_loop->partition_policy, for_loop->device_api, @@ -1013,23 +1013,19 @@ Stmt inject_stmt(Stmt root, Stmt injected, const LoopLevel &level) { class CollectBounds : public IRVisitor { public: template - static map collect_bounds(const T &node) { + static map collect_bounds(const T &node) { CollectBounds bounds; node.accept(&bounds); return bounds.bounds; } private: - map bounds; + map bounds; using IRVisitor::visit; - void visit(const LetStmt *op) override { - if (ends_with(op->name, ".loop_min") || - ends_with(op->name, ".loop_max") || - ends_with(op->name, ".loop_extent")) { - bounds.emplace(op->name, Variable::make(Int(32), op->name)); - } + void visit(const For *op) override { + bounds.emplace(op->name, Interval{op->min, op->max}); IRVisitor::visit(op); } }; @@ -1047,71 +1043,50 @@ string fused_name(const string &var) { // The bounds of every loop exist in 'replacements' should be replaced. The // loop is also renamed by adding '.fused' in the original name before the // variable name. -Stmt substitute_fused_bounds(Stmt s, const map &replacements) { +Stmt substitute_fused_bounds(Stmt s, const map &replacements) { if (!s.defined() || replacements.empty()) { return s; } class SubstituteFusedBounds : public IRMutator { - const map &replacements; + const map &replacements; using IRMutator::visit; Stmt visit(const For *op) override { - const auto *min_var = op->min.as(); - const auto *extent_var = op->extent.as(); - if (min_var && extent_var) { - Expr min_val, extent_val; - { - const auto &it = replacements.find(min_var->name); - if (it != replacements.end()) { - min_val = it->second; - } - } - { - const auto &it = replacements.find(extent_var->name); - if (it != replacements.end()) { - extent_val = it->second; - } - } - if (!min_val.defined() || !extent_val.defined()) { - return IRMutator::visit(op); - } + auto it = replacements.find(op->name); + if (it == replacements.end()) { + return IRMutator::visit(op); + } + const Interval &i = it->second; - Stmt body = mutate(op->body); + Stmt body = mutate(op->body); - string new_var = fused_name(op->name); + string new_var = fused_name(op->name); + + ForType for_type = op->for_type; + DeviceAPI device_api = op->device_api; + if (equal(i.min, i.max)) { + // This is the child loop of a fused group. The real loop of the + // fused group is the loop of the parent function of the fused + // group. This child loop is just a scheduling point, and should + // never be a device transition, so we rewrite it to be a simple + // serial loop of extent 1." + for_type = ForType::Serial; + device_api = DeviceAPI::None; + } - ForType for_type = op->for_type; - DeviceAPI device_api = op->device_api; - if (is_const_one(extent_val)) { - // This is the child loop of a fused group. The real loop of the - // fused group is the loop of the parent function of the fused - // group. This child loop is just a scheduling point, and should - // never be a device transition, so we rewrite it to be a simple - // serial loop of extent 1." - for_type = ForType::Serial; - device_api = DeviceAPI::None; - } + Stmt stmt = For::make(new_var, i.min, i.max, + for_type, op->partition_policy, + device_api, body); - Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), - Variable::make(Int(32), new_var + ".loop_extent"), - for_type, op->partition_policy, device_api, body); - - // Add let stmts defining the bound of the renamed for-loop. - stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); - stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); - stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); - // Replace any reference to the old loop name with the new one. - stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); - return stmt; - } else { - return IRMutator::visit(op); - } + // Replace any reference to the old loop name with the new one. + stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); + return stmt; } public: - explicit SubstituteFusedBounds(const map &r) + explicit SubstituteFusedBounds(const map &r) : replacements(r) { } } subs(replacements); @@ -1143,7 +1118,7 @@ Stmt add_loop_var_aliases(Stmt s, const map> &loop_var_alias body = LetStmt::make(alias, var, body); } - return For::make(op->name, op->min, op->extent, op->for_type, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, std::move(body)); } @@ -1171,7 +1146,7 @@ class ShiftLoopNest : public IRMutator { internal_assert(op); Expr adjusted = Variable::make(Int(32), op->name) + iter->second; Stmt body = substitute(op->name, adjusted, op->body); - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } return stmt; } @@ -1347,7 +1322,7 @@ class InjectFunctionRealization : public IRMutator { } else { return For::make(for_loop->name, for_loop->min, - for_loop->extent, + for_loop->max, for_loop->for_type, for_loop->partition_policy, for_loop->device_api, @@ -1443,7 +1418,7 @@ class InjectFunctionRealization : public IRMutator { // Compute the shift factor required to align iteration of // a function stage with its fused parent loop nest. void compute_shift_factor(const Function &f, const string &prefix, const Definition &def, - map &bounds, map &shifts) { + map &bounds, map &shifts) { if (!def.defined()) { return; } @@ -1496,29 +1471,28 @@ class InjectFunctionRealization : public IRMutator { internal_assert(parent_var_index >= 0); string parent_var = parent_dims[parent_var_index].var; - auto it_min = bounds.find(prefix + var + ".loop_min"); - auto it_max = bounds.find(prefix + var + ".loop_max"); - internal_assert((it_min != bounds.end()) && (it_max != bounds.end())); + auto it = bounds.find(prefix + var); + internal_assert(it != bounds.end()); if (iter->second == LoopAlignStrategy::AlignStart) { - auto parent_min = bounds.find(parent_prefix + parent_var + ".loop_min"); + auto parent_min = bounds.find(parent_prefix + parent_var); internal_assert(parent_min != bounds.end()); - shift_val = parent_min->second - it_min->second; + shift_val = parent_min->second.min - it->second.min; } else { - auto parent_max = bounds.find(parent_prefix + parent_var + ".loop_max"); + auto parent_max = bounds.find(parent_prefix + parent_var); internal_assert(parent_max != bounds.end()); - shift_val = parent_max->second - it_max->second; + shift_val = parent_max->second.max - it->second.max; } internal_assert(shift_val.defined()); shifts.emplace(prefix + var, simplify(-shift_val)); - it_min->second = simplify(shift_val + it_min->second); - it_max->second = simplify(shift_val + it_max->second); + it->second.min = simplify(shift_val + it->second.min); + it->second.max = simplify(shift_val + it->second.max); } } Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update, - map &replacements, + map &replacements, vector> &add_lets, map> &aliases) { const vector &dims = def.schedule().dims(); // From inner to outer @@ -1556,9 +1530,7 @@ class InjectFunctionRealization : public IRMutator { internal_assert(dim2_idx < (int)dims_2.size()); string var = pair.func_2 + ".s" + std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; - replacements.emplace(var + ".loop_extent", make_const(Int(32), 1)); - replacements.emplace(var + ".loop_min", val); - replacements.emplace(var + ".loop_max", val); + replacements.emplace(var, Interval::single_point(val)); string var_fused = fused_name(var_orig); aliases[var_fused].emplace(std::move(var_orig)); @@ -1616,8 +1588,8 @@ class InjectFunctionRealization : public IRMutator { // realized in the group) with union of the bounds of the fused group. Stmt replace_parent_bound_with_union_bound(const string &func, int stage, const Definition &def, Stmt produce, - const map &bounds, - map &replacements) { + const map &bounds, + map &replacements) { if (def.schedule().fused_pairs().empty()) { return produce; @@ -1650,33 +1622,20 @@ class InjectFunctionRealization : public IRMutator { string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; - internal_assert(bounds.count(var_2 + ".loop_min")); - internal_assert(bounds.count(var_2 + ".loop_max")); - internal_assert(bounds.count(var_2 + ".loop_extent")); - Expr min_2 = bounds.find(var_2 + ".loop_min")->second; - Expr max_2 = bounds.find(var_2 + ".loop_max")->second; - Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second; - - internal_assert(bounds.count(var_1 + ".loop_min")); - internal_assert(bounds.count(var_1 + ".loop_max")); - internal_assert(bounds.count(var_1 + ".loop_extent")); - - Expr min_1, max_1; - const auto &it = replacements.find(var_1 + ".loop_min"); + + Interval i_1; + Interval i_2 = bounds.find(var_2)->second; + + const auto &it = replacements.find(var_1); if (it == replacements.end()) { - min_1 = bounds.find(var_1 + ".loop_min")->second; - max_1 = bounds.find(var_1 + ".loop_max")->second; + i_1 = bounds.find(var_1)->second; } else { - min_1 = replacements.find(var_1 + ".loop_min")->second; - max_1 = replacements.find(var_1 + ".loop_max")->second; + i_1 = it->second; } // Extent is computed from min/max, so we don't find() it earlier. - replacements[var_1 + ".loop_min"] = simplify(min(min_1, min_2)); - replacements[var_1 + ".loop_max"] = simplify(max(max_1, max_2)); - replacements[var_1 + ".loop_extent"] = - simplify((replacements[var_1 + ".loop_max"] + 1) - - replacements[var_1 + ".loop_min"]); + replacements[var_1] = Interval{simplify(min(i_1.min, i_2.min)), + simplify(max(i_1.max, i_2.max))}; } } @@ -1690,8 +1649,8 @@ class InjectFunctionRealization : public IRMutator { } Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, - const map &bounds) { - map replacements; + const map &bounds) { + map replacements; int stage = 0; produce = replace_parent_bound_with_union_bound(f.name(), stage++, f.definition(), produce, bounds, replacements); @@ -1828,7 +1787,7 @@ class InjectFunctionRealization : public IRMutator { // Build the loops. Stmt producer; - map replacements; + map replacements; vector> add_lets; map> aliases; @@ -1954,7 +1913,7 @@ class ComputeLegalSchedules : public IRVisitor { sites.push_back({f->is_parallel(), is_gpu_block, loop_level}); f->min.accept(this); - f->extent.accept(this); + f->max.accept(this); f->body.accept(this); sites.pop_back(); @@ -2556,7 +2515,7 @@ class RemoveLoopsOverOutermost : public IRMutator { Stmt visit(const For *op) override { if (ends_with(op->name, ".__outermost") && - is_const_one(simplify(op->extent)) && + can_prove(op->min == op->max) && op->device_api == DeviceAPI::None) { return mutate(substitute(op->name, op->min, op->body)); } else { @@ -2565,8 +2524,7 @@ class RemoveLoopsOverOutermost : public IRMutator { } Stmt visit(const LetStmt *op) override { - if (ends_with(op->name, ".__outermost.loop_extent") || - ends_with(op->name, ".__outermost.loop_min") || + if (ends_with(op->name, ".__outermost.loop_min") || ends_with(op->name, ".__outermost.loop_max")) { return mutate(substitute(op->name, simplify(op->value), op->body)); } else { @@ -2590,7 +2548,7 @@ Stmt schedule_functions(const vector &outputs, const Target &target, bool &any_memoized) { string root_var = LoopLevel::root().lock().to_string(); - Stmt s = For::make(root_var, 0, 1, ForType::Serial, Partition::Never, DeviceAPI::Host, Evaluate::make(0)); + Stmt s = For::make(root_var, 0, 0, ForType::Serial, Partition::Never, DeviceAPI::Host, Evaluate::make(0)); any_memoized = false; diff --git a/src/SelectGPUAPI.cpp b/src/SelectGPUAPI.cpp index 504bda8f3bb8..ec73c883e955 100644 --- a/src/SelectGPUAPI.cpp +++ b/src/SelectGPUAPI.cpp @@ -35,7 +35,7 @@ class SelectGPUAPI : public IRMutator { internal_assert(op); if (op->device_api != selected_api) { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, selected_api, op->body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, selected_api, op->body); } return stmt; } diff --git a/src/Serialization.cpp b/src/Serialization.cpp index d731d9c9d85c..33de404edffe 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -444,7 +444,7 @@ std::pair> Serializer::serialize_stmt(FlatBufferBu const auto *const for_stmt = stmt.as(); const auto name_serialized = serialize_string(builder, for_stmt->name); const auto min_serialized = serialize_expr(builder, for_stmt->min); - const auto extent_serialized = serialize_expr(builder, for_stmt->extent); + const auto max_serialized = serialize_expr(builder, for_stmt->max); const Serialize::ForType for_type = serialize_for_type(for_stmt->for_type); const Serialize::Partition partition_policy = serialize_partition(for_stmt->partition_policy); const Serialize::DeviceAPI device_api = serialize_device_api(for_stmt->device_api); @@ -452,7 +452,7 @@ std::pair> Serializer::serialize_stmt(FlatBufferBu return std::make_pair(Serialize::Stmt::For, Serialize::CreateFor(builder, name_serialized, min_serialized.first, min_serialized.second, - extent_serialized.first, extent_serialized.second, + max_serialized.first, max_serialized.second, for_type, partition_policy, device_api, body_serialized.first, body_serialized.second) .Union()); diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index cd2c440de6ba..eaed86fa9a2e 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -145,7 +145,7 @@ Stmt Simplify::visit(const IfThenElse *op) { then_case); } else if (then_for && !else_case.defined() && - equal(unwrapped_condition, 0 < then_for->extent)) { + equal(unwrapped_condition, then_for->min <= then_for->max)) { // This guard is redundant return then_case; } else if (then_if && @@ -203,21 +203,21 @@ Stmt Simplify::visit(const AssertStmt *op) { } Stmt Simplify::visit(const For *op) { - ExprInfo min_info, extent_info; + ExprInfo min_info, max_info; Expr new_min = mutate(op->min, &min_info); if (in_unreachable) { return Evaluate::make(new_min); } - Expr new_extent = mutate(op->extent, &extent_info); + Expr new_max = mutate(op->max, &max_info); if (in_unreachable) { - return Evaluate::make(new_extent); + return Evaluate::make(new_max); } ScopedValue old_in_vector_loop(in_vector_loop, (in_vector_loop || op->for_type == ForType::Vectorized)); - Expr extent_positive = mutate(0 < new_extent, nullptr); + Expr extent_positive = mutate(new_min <= new_max, nullptr); if (is_const_zero(extent_positive)) { // This loop never runs return Evaluate::make(0); @@ -229,8 +229,8 @@ Stmt Simplify::visit(const For *op) { // at least one, so we can throw a max around the extent bounds. loop_var_info.bounds = - ConstantInterval::make_union(min_info.bounds, - min_info.bounds + max(extent_info.bounds, 1) - 1); + ConstantInterval::make_union(min(min_info.bounds, max_info.bounds), + max(min_info.bounds, max_info.bounds)); Stmt new_body; { ScopedBinding bind_if((loop_var_info.bounds.max_defined || @@ -244,10 +244,8 @@ Stmt Simplify::visit(const For *op) { // The loop variable will never exceed the loop bound. Expr loop_var = Variable::make(Int(32), op->name); - Expr new_max = mutate(new_min + new_extent, nullptr); - ScopedFact fact_loop_var_less_than_extent = scoped_truth(loop_var < new_max); - - ScopedFact fact_loop_var_ge_than_min = scoped_truth(new_min <= loop_var); + ScopedFact fact_loop_var_le_max = scoped_truth(loop_var <= new_max); + ScopedFact fact_loop_var_ge_min = scoped_truth(new_min <= loop_var); new_body = mutate(op->body); } @@ -258,38 +256,45 @@ Stmt Simplify::visit(const For *op) { // extent is greater than zero, then the code *outside* the loop must be // unreachable too, because if it weren't, it'd run the unreachable body // at least once. - in_unreachable = extent_info.bounds > 0; + in_unreachable = max_info.bounds >= min_info.bounds; return Evaluate::make(0); } if (const Acquire *acquire = new_body.as()) { if (is_no_op(acquire->body)) { // Rewrite iterated no-op acquires as a single acquire. - return Acquire::make(acquire->semaphore, mutate(acquire->count * new_extent, nullptr), acquire->body); + return Acquire::make(acquire->semaphore, mutate(acquire->count * ((new_max - new_min) + 1), nullptr), acquire->body); } } if (is_no_op(new_body)) { return new_body; - } else if (extent_info.bounds <= 0) { + } else if (max_info.bounds < min_info.bounds) { return Evaluate::make(0); - } else if (extent_info.bounds <= 1 && + } else if (equal(new_min, new_max) && + op->device_api == DeviceAPI::None) { + // Loop body runs exactly once + return mutate(LetStmt::make(op->name, new_min, new_body)); + } else if (max_info.bounds <= min_info.bounds && op->device_api == DeviceAPI::None) { // Loop body runs at most once Stmt s = LetStmt::make(op->name, new_min, new_body); - if (extent_info.bounds.contains(0)) { + if (!(max_info.bounds >= min_info.bounds)) { // Loop body might not run at all - s = IfThenElse::make(0 < new_extent, s); + s = IfThenElse::make(new_min <= new_max, s); } return mutate(s); - } else if (!stmt_uses_var(new_body, op->name) && !is_const_zero(op->min)) { - return For::make(op->name, make_zero(Int(32)), new_extent, op->for_type, op->partition_policy, op->device_api, new_body); + } else if (Expr shifted_max; + !stmt_uses_var(new_body, op->name) && + !is_const_zero(new_min) && + is_const(shifted_max = mutate((new_max - new_min), nullptr))) { + return For::make(op->name, make_zero(Int(32)), shifted_max, op->for_type, op->partition_policy, op->device_api, new_body); } else if (op->min.same_as(new_min) && - op->extent.same_as(new_extent) && + op->max.same_as(new_max) && op->body.same_as(new_body)) { return op; } else { - return For::make(op->name, new_min, new_extent, op->for_type, op->partition_policy, op->device_api, new_body); + return For::make(op->name, new_min, new_max, op->for_type, op->partition_policy, op->device_api, new_body); } } diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index 63640492de36..4a82456689c0 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -721,7 +721,7 @@ class SkipStages : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, std::move(body)); } } diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 69fa3198ceaf..3f779cb0bca2 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -198,7 +198,7 @@ class RollFunc : public IRMutator { Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body); // use op->name *before* the re-assignment of result, which will clobber it loops_to_rebase.erase(op->name); - result = For::make(new_name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body); + result = For::make(new_name, 0, op->max - op->min, op->for_type, op->partition_policy, op->device_api, body); } return result; } @@ -561,21 +561,21 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { // It's not safe to enter an inner loop whose bounds depend on // the var we're sliding over. Expr min = expand_expr(op->min, scope); - Expr extent = expand_expr(op->extent, scope); + Expr max = expand_expr(op->max, scope); ScopedBinding<> bind(enclosing_loops, op->name); - if (is_const_one(extent)) { + if (equal(min, max)) { // Just treat it like a let Stmt s = LetStmt::make(op->name, min, op->body); s = mutate(s); // Unpack it back into the for const LetStmt *l = s.as(); internal_assert(l); - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, l->body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, l->body); } else if (is_monotonic(min, loop_var) != Monotonic::Constant || - is_monotonic(extent, loop_var) != Monotonic::Constant) { + is_monotonic(max, loop_var) != Monotonic::Constant) { debug(3) << "Not entering loop over " << op->name << " because the bounds depend on the var we're sliding over: " - << min << ", " << extent << "\n"; + << min << ", " << max << "\n"; return op; } else { return IRMutator::visit(op); @@ -793,8 +793,7 @@ class SlidingWindow : public IRMutator { string name = op->name; Stmt body = op->body; Expr loop_min = op->min; - Expr loop_extent = op->extent; - Expr loop_max = Variable::make(Int(32), op->name + ".loop_max"); + Expr loop_max = op->max; list> prev_loop_mins; list> new_lets; @@ -841,21 +840,18 @@ class SlidingWindow : public IRMutator { // Update the loop body to use the adjusted loop min. string new_name = name + ".$n"; loop_min = Variable::make(Int(32), new_name + ".loop_min"); - loop_extent = Variable::make(Int(32), new_name + ".loop_extent"); body = substitute({ {name, Variable::make(Int(32), new_name)}, {name + ".loop_min", loop_min}, - {name + ".loop_extent", loop_extent}, }, body); body = SubstitutePrefetchVar(name, new_name).mutate(body); name = new_name; - // The new loop interval is the new loop min to the loop max. + // The new loop interval is the new loop min to the old loop max. new_lets.emplace_front(name + ".loop_min", new_loop_min); new_lets.emplace_front(name + ".loop_min.orig", loop_min); - new_lets.emplace_front(name + ".loop_extent", (loop_max - loop_min) + 1); } if (slid_dims.size() > old_slid_dims_size) { @@ -870,13 +866,10 @@ class SlidingWindow : public IRMutator { body = mutate(body); - if (body.same_as(op->body) && loop_min.same_as(op->min) && loop_extent.same_as(op->extent) && name == op->name) { + if (body.same_as(op->body) && loop_min.same_as(op->min) && loop_max.same_as(op->max) && name == op->name) { return op; } else { - Stmt result = For::make(name, loop_min, loop_extent, op->for_type, op->partition_policy, op->device_api, body); - if (!new_lets.empty()) { - result = LetStmt::make(name + ".loop_max", loop_max, result); - } + Stmt result = For::make(name, loop_min, loop_max, op->for_type, op->partition_policy, op->device_api, body); for (const auto &i : new_lets) { result = LetStmt::make(i.first, i.second, result); } @@ -913,13 +906,13 @@ class AddLoopMinOrig : public IRMutator { Stmt visit(const For *op) override { Stmt body = mutate(op->body); Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt result; - if (body.same_as(op->body) && min.same_as(op->min) && extent.same_as(op->extent)) { + if (body.same_as(op->body) && min.same_as(op->min) && max.same_as(op->max)) { result = op; } else { - result = For::make(op->name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + result = For::make(op->name, min, max, op->for_type, op->partition_policy, op->device_api, body); } return LetStmt::make(op->name + ".loop_min.orig", Variable::make(Int(32), op->name + ".loop_min"), result); } diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index 5880ea3b4008..85691921bc8d 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -114,7 +114,7 @@ class FindStridedLoads : public IRVisitor { } void visit(const For *op) override { - if (can_prove(op->extent > 0)) { + if (can_prove(op->min <= op->max)) { // The loop body definitely runs IRVisitor::visit(op); } else { diff --git a/src/StmtToHTML.cpp b/src/StmtToHTML.cpp index fbb327fc2c68..766ea61973b2 100644 --- a/src/StmtToHTML.cpp +++ b/src/StmtToHTML.cpp @@ -424,8 +424,8 @@ class IRCostModel : public IRVisitor { // The cost of a loop-node essentially depends on its iteration // count. The cost model currently ignores such costs. IRVisitor::visit(op); - set_compute_costs(op, 0, {op->min.get(), op->extent.get(), op->body.get()}, {op->min.get(), op->extent.get()}); - set_data_costs(op, 0, {op->min.get(), op->extent.get(), op->body.get()}, {op->min.get(), op->extent.get()}); + set_compute_costs(op, 0, {op->min.get(), op->max.get(), op->body.get()}, {op->min.get(), op->max.get()}); + set_data_costs(op, 0, {op->min.get(), op->max.get(), op->body.get()}, {op->min.get(), op->max.get()}); } void visit(const Acquire *op) override { @@ -1694,7 +1694,7 @@ class HTMLCodePrinter : public IRVisitor { print_html_element("span", "matched", ", "); print(op->min); print_html_element("span", "matched", ", "); - print(op->extent); + print(op->max); print_html_element("span", "matched", ")"); // Open code block to hold function body diff --git a/src/StorageFlattening.cpp b/src/StorageFlattening.cpp index 11f5e8e20d24..a7c8f5208c22 100644 --- a/src/StorageFlattening.cpp +++ b/src/StorageFlattening.cpp @@ -569,12 +569,12 @@ class HoistStorage : public IRMutator { Stmt visit(const For *op) override { Expr expanded_min = op->min; - Expr expanded_extent = op->extent; + Expr expanded_max = op->max; // Iterate from innermost outwards for (auto &storage : reverse_view(hoisted_storages)) { expanded_min = simplify(expand_expr(expanded_min, storage.scope)); - expanded_extent = expand_expr(expanded_extent, storage.scope); - auto loop_bounds = Interval(expanded_min, simplify(expanded_min + expanded_extent - 1)); + expanded_max = expand_expr(expanded_max, storage.scope); + auto loop_bounds = Interval(expanded_min, expanded_max); storage.loop_vars.emplace_back(op->name, loop_bounds); } diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index 04e743e33fbd..d7e2db52e17e 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -536,16 +536,14 @@ class AttemptStorageFoldingOfFunction : public IRMutator { Box box = box_union(provided, required); Expr loop_var = Variable::make(Int(32), op->name); - Expr loop_min = Variable::make(Int(32), op->name + ".loop_min"); - Expr loop_max = Variable::make(Int(32), op->name + ".loop_max"); string dynamic_footprint; Scope bounds; - bounds.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + bounds.push(op->name, Interval(op->min, op->max)); Scope steady_bounds; - steady_bounds.push(op->name, Interval(simplify(op->min + 1), simplify(op->min + op->extent - 1))); + steady_bounds.push(op->name, Interval(simplify(op->min + 1), op->max)); HasExternConsumer has_extern_consumer(func.name()); body.accept(&has_extern_consumer); @@ -735,7 +733,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { } else { // The max of the extent over all values of the loop variable must be a constant Scope scope; - scope.push(op->name, Interval(loop_min, loop_max)); + scope.push(op->name, Interval(op->min, op->max)); Expr max_extent = find_constant_bound(extent, Direction::Upper, scope); scope.pop(op->name); @@ -825,8 +823,8 @@ class AttemptStorageFoldingOfFunction : public IRMutator { // On the first iteration, we need to acquire the extent of the region shared // between the producer and consumer, and we need to release it on the last // iteration. - to_acquire = select(loop_var > loop_min, to_acquire, extent); - to_release = select(loop_var < loop_max, to_release, extent); + to_acquire = select(loop_var > op->min, to_acquire, extent); + to_release = select(loop_var < op->max, to_release, extent); // We may need dynamic assertions that a positive // amount of the semaphore is acquired/released, @@ -881,7 +879,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { // for further folding opportunities // recursively. } else if (!body.same_as(op->body)) { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); break; } else { stmt = op; @@ -900,7 +898,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { if (body.same_as(op->body)) { stmt = op; } else { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } if (func.schedule().async() && !dynamic_footprint.empty()) { diff --git a/src/Substitute.cpp b/src/Substitute.cpp index 6a7cba7fd589..603a7cc8aa5a 100644 --- a/src/Substitute.cpp +++ b/src/Substitute.cpp @@ -83,17 +83,17 @@ class Substitute : public IRMutator { Stmt visit(const For *op) override { Expr new_min = mutate(op->min); - Expr new_extent = mutate(op->extent); + Expr new_max = mutate(op->max); hidden.push(op->name); Stmt new_body = mutate(op->body); hidden.pop(op->name); if (new_min.same_as(op->min) && - new_extent.same_as(op->extent) && + new_max.same_as(op->max) && new_body.same_as(op->body)) { return op; } else { - return For::make(op->name, new_min, new_extent, op->for_type, op->partition_policy, op->device_api, new_body); + return For::make(op->name, new_min, new_max, op->for_type, op->partition_policy, op->device_api, new_body); } } }; diff --git a/src/TrimNoOps.cpp b/src/TrimNoOps.cpp index b4ec415072c9..1842a702fab4 100644 --- a/src/TrimNoOps.cpp +++ b/src/TrimNoOps.cpp @@ -125,12 +125,12 @@ class IsNoOp : public IRVisitor { condition = const_true(); op->body.accept(this); Scope varying; - varying.push(op->name, Interval(op->min, op->min + op->extent - 1)); + varying.push(op->name, Interval(op->min, op->max)); condition = simplify(common_subexpression_elimination(condition)); debug(3) << "About to relax over " << op->name << " : " << condition << "\n"; condition = and_condition_over_domain(condition, varying); debug(3) << "Relaxed: " << condition << "\n"; - condition = make_and(old_condition, make_or(condition, simplify(op->extent <= 0))); + condition = make_and(old_condition, make_or(condition, simplify(op->max < op->min))); } void visit(const IfThenElse *op) override { @@ -334,11 +334,11 @@ class SimplifyUsingBounds : public IRMutator { Stmt visit(const For *op) override { // Simplify the loop bounds. Expr min = mutate(op->min); - Expr extent = mutate(op->extent); - containing_loops.push_back({op->name, {min, min + extent - 1}}); + Expr max = mutate(op->max); + containing_loops.push_back({op->name, {min, max}}); Stmt body = mutate(op->body); containing_loops.pop_back(); - return For::make(op->name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, min, max, op->for_type, op->partition_policy, op->device_api, body); } public: @@ -380,7 +380,7 @@ class TrimNoOps : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } @@ -393,7 +393,7 @@ class TrimNoOps : public IRMutator { if (i.is_everything()) { // Nope. - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } if (i.is_empty()) { @@ -414,29 +414,22 @@ class TrimNoOps : public IRMutator { Expr new_max_var = Variable::make(Int(32), new_max_name); Expr old_max_var = Variable::make(Int(32), old_max_name); - // Convert max to max-plus-one - if (i.has_upper_bound()) { - i.max = i.max + 1; - } - // Truncate the loop bounds to the region over which it's not // a no-op. - Expr old_max = op->min + op->extent; + Expr old_max = op->max; Expr new_min, new_max; if (i.has_lower_bound()) { - new_min = clamp(i.min, op->min, old_max_var); + new_min = clamp(i.min, op->min, old_max_var + 1); } else { new_min = op->min; } if (i.has_upper_bound()) { - new_max = clamp(i.max, new_min_var, old_max_var); + new_max = clamp(i.max, new_min_var - 1, old_max_var); } else { new_max = old_max; } - Expr new_extent = new_max_var - new_min_var; - - Stmt stmt = For::make(op->name, new_min_var, new_extent, op->for_type, op->partition_policy, op->device_api, body); + Stmt stmt = For::make(op->name, new_min_var, new_max_var, op->for_type, op->partition_policy, op->device_api, body); stmt = LetStmt::make(new_max_name, new_max, stmt); stmt = LetStmt::make(new_min_name, new_min, stmt); stmt = LetStmt::make(old_max_name, old_max, stmt); diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 9dc92c780b3a..91f0279de04c 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -88,7 +88,7 @@ class UniquifyVariableNames : public IRMutator { Stmt visit(const For *op) override { Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); string new_name = make_new_name(op->name); Stmt body = mutate(op->body); renaming.pop(op->name); @@ -96,10 +96,10 @@ class UniquifyVariableNames : public IRMutator { if (new_name == op->name && body.same_as(op->body) && min.same_as(op->min) && - extent.same_as(op->extent)) { + max.same_as(op->max)) { return op; } else { - return For::make(new_name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(new_name, min, max, op->for_type, op->partition_policy, op->device_api, body); } } @@ -153,7 +153,7 @@ class FindFreeVars : public IRVisitor { void visit(const For *op) override { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); { ScopedBinding<> bind(scope, op->name); op->body.accept(this); diff --git a/src/UnrollLoops.cpp b/src/UnrollLoops.cpp index 2823c8b9ac9f..ffcba564966a 100644 --- a/src/UnrollLoops.cpp +++ b/src/UnrollLoops.cpp @@ -16,7 +16,8 @@ class UnrollLoops : public IRMutator { Stmt visit(const For *for_loop) override { if (for_loop->for_type == ForType::Unrolled) { Stmt body = for_loop->body; - const IntImm *e = for_loop->extent.as(); + Expr extent = simplify(for_loop->extent()); + const IntImm *e = extent.as(); internal_assert(e) << "Loop over " << for_loop->name << " should have had a constant extent\n"; diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index d06cc0815300..2d149adbaf20 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -353,7 +353,7 @@ class SerializeLoops : public IRMutator { Stmt visit(const For *op) override { if (op->for_type == ForType::Vectorized) { - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, ForType::Serial, op->partition_policy, op->device_api, mutate(op->body)); } @@ -950,7 +950,7 @@ class VectorSubs : public IRMutator { ForType for_type = op->for_type; Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt body = op->body; @@ -958,21 +958,22 @@ class VectorSubs : public IRMutator { // Rebase the loop to zero and try again Expr var = Variable::make(Int(32), op->name); Stmt body = substitute(op->name, var + op->min, op->body); - Stmt transformed = For::make(op->name, 0, op->extent, for_type, op->partition_policy, op->device_api, body); + Stmt transformed = For::make(op->name, 0, simplify(op->max - op->min), for_type, op->partition_policy, op->device_api, body); return mutate(transformed); } - if (extent.type().is_vector()) { + if (max.type().is_vector()) { // We'll iterate up to the max over the lanes, but // inject an if statement inside the loop that stops // each lane from going too far. - extent = bounds_of_lanes(extent).max; + max = bounds_of_lanes(max).max; Expr var = Variable::make(Int(32), op->name); - body = IfThenElse::make(likely(var < op->min + op->extent), body); + body = IfThenElse::make(likely(var <= max), body); } if (op->for_type == ForType::Vectorized) { + Expr extent = simplify((max - min) + 1); const IntImm *extent_int = extent.as(); internal_assert(extent_int) << "Vectorized for loop extent should have been rewritten to a constant\n"; @@ -980,7 +981,11 @@ class VectorSubs : public IRMutator { user_error << "Loop over " << op->name << " has extent " << extent << ". Can only vectorize loops over a " - << "constant extent > 1\n"; + << "constant extent > 1\n" + << "Original min: " << op->min << "\n" + << "Original max: " << op->max << "\n" + << "Mutated min: " << min << "\n" + << "Mutated max: " << max << "\n"; } vectorized_vars.push_back({op->name, min, (int)extent_int->value}); @@ -1022,12 +1027,12 @@ class VectorSubs : public IRMutator { body = mutate(body); if (min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body) && for_type == op->for_type) { return op; } else { - return For::make(op->name, min, extent, for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, min, max, for_type, op->partition_policy, op->device_api, body); } } } @@ -1318,7 +1323,8 @@ class VectorSubs : public IRMutator { for (int ix = vectorized_vars.size() - 1; ix >= 0; ix--) { s = For::make(vectorized_vars[ix].name, vectorized_vars[ix].min, - vectorized_vars[ix].lanes, ForType::Serial, Partition::Auto, DeviceAPI::None, s); + vectorized_vars[ix].min + vectorized_vars[ix].lanes - 1, + ForType::Serial, Partition::Auto, DeviceAPI::None, s); } return s; @@ -1581,10 +1587,11 @@ class VectorizeLoops : public IRMutator { Stmt visit(const For *for_loop) override { Stmt stmt; if (for_loop->for_type == ForType::Vectorized) { - const IntImm *extent = for_loop->extent.as(); + Expr loop_extent = simplify(for_loop->extent()); + const IntImm *extent = loop_extent.as(); if (!extent || extent->value <= 1) { user_error << "Loop over " << for_loop->name - << " has extent " << for_loop->extent + << " has extent " << loop_extent << ". Can only vectorize loops over a " << "constant extent > 1\n"; } diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index 499488ce8b95..7f3492684743 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -7,7 +7,7 @@ file_identifier "HLDE"; file_extension "hlpipe"; enum SerializationVersionMajor: int { - Value = 18 + Value = 21 } enum SerializationVersionMinor: int { // 0 = Unstable @@ -15,7 +15,7 @@ enum SerializationVersionMinor: int { Value = 0 } enum SerializationVersionPatch: int { - Value = 1 + Value = 0 } // from src/IR.cpp @@ -143,7 +143,7 @@ table ProducerConsumer { table For { name: string; min: Expr; - extent: Expr; + max: Expr; for_type: ForType; partition_policy: Partition; device_api: DeviceAPI; diff --git a/test/correctness/fuse_gpu_threads.cpp b/test/correctness/fuse_gpu_threads.cpp index efd690c4ef4c..5c203846f0d9 100644 --- a/test/correctness/fuse_gpu_threads.cpp +++ b/test/correctness/fuse_gpu_threads.cpp @@ -9,9 +9,9 @@ class CheckThreadExtent : public IRVisitor { if (op->for_type == ForType::GPUThread) { // Assert the min and extent to be 0 and 16 for this particular test case auto min = as_const_int(op->min); - auto extent = as_const_int(op->extent); + auto max = as_const_int(op->max); assert(min && (*min == 0)); - assert(extent && (*extent == 16)); + assert(max && (*max == 15)); } IRVisitor::visit(op); } diff --git a/test/correctness/out_constraint.cpp b/test/correctness/out_constraint.cpp index 87dfa1a70df8..60fbeb8de5a9 100644 --- a/test/correctness/out_constraint.cpp +++ b/test/correctness/out_constraint.cpp @@ -26,9 +26,9 @@ class CheckLoops : public IRVisitor { using IRVisitor::visit; void visit(const For *op) override { - std::cout << "for(" << op->name << ", " << op->min << ", " << op->extent << ")\n"; + std::cout << "for(" << op->name << ", " << op->min << ", " << op->max << ")\n"; check_int(op->min, 0); - check_int(op->extent, size); + check_int(op->max, size - 1); ++count; IRVisitor::visit(op); } diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 052762184f46..08119357fb4b 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1736,10 +1736,10 @@ void check_boolean() { // A for loop is also an if statement that the extent is greater than zero Stmt body = AssertStmt::make(y == z, y); Stmt loop = For::make("t", 0, x, ForType::Serial, Partition::Auto, DeviceAPI::None, body); - check(IfThenElse::make(0 < x, loop), loop); + check(IfThenElse::make(0 <= x, loop), loop); - // A for loop where the extent is exactly one is just the body - check(IfThenElse::make(x == 1, loop), IfThenElse::make(x == 1, body)); + // A for loop where the min equals the max is just the body + check(IfThenElse::make(x == 0, loop), IfThenElse::make(x == 0, body)); // Check we can learn from conditions on variables check(IfThenElse::make(x < 5, not_no_op(min(x, 17))), @@ -2419,7 +2419,7 @@ int main(int argc, char **argv) { } { - Stmt body = AssertStmt::make(x > 0, y); + Stmt body = AssertStmt::make(x >= 0, y); check(For::make("t", 0, x, ForType::Serial, Partition::Auto, DeviceAPI::None, body), Evaluate::make(0)); }