From 2cd90bde2ad4c838c9a888e7a2e78784411fcad2 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Wed, 17 Sep 2025 14:27:11 -0400 Subject: [PATCH 1/2] Fix lowering of pow() in D3D12Compute when the base is provably positive. --- src/CodeGen_D3D12Compute_Dev.cpp | 6 ++++-- test/correctness/math.cpp | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index ad4f6451f918..e40dd1b316ef 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -388,10 +388,12 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) { } else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) { // If we know pow(x, y) is called with x > 0, we can use HLSL's pow // directly. - stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")"; + Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern); + equiv.accept(this); } else if (op->is_intrinsic(Call::round)) { // HLSL's round intrinsic has the correct semantics for our rounding. - print_assignment(op->type, "round(" + print_expr(op->args[0]) + ")"); + Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern); + equiv.accept(this); } else { CodeGen_GPU_C::visit(op); } diff --git a/test/correctness/math.cpp b/test/correctness/math.cpp index 68ff3c0e56e8..685e855d9905 100644 --- a/test/correctness/math.cpp +++ b/test/correctness/math.cpp @@ -265,6 +265,33 @@ fun_2(uint32_t, uint32_t, absd, absd) call_2(double, name, steps, start1, end1, start2, end2); \ } while (0) +// For D3D12Compute, we lower directly to HLSL's pow function if the base is provably positive. +// This test ensures that lowering is correct. 
+void test_pow_positive() { + printf("Testing pow(x, y) where x > 0\n"); + TestArgs<float> args(256, 0.0f, 10.0f); + Func test_pow_positive("test_pow_positive"); + Var x("x"), xi("xi"); + test_pow_positive(x) = pow(1.5f, args.data(x)); + + Target target = get_jit_target_from_environment(); + if (target.has_gpu_feature()) { + test_pow_positive.gpu_tile(x, xi, 16); //.vectorize(xi, 2); + } else if (target.has_feature(Target::HVX)) { + test_pow_positive.hexagon(); + } + Buffer<float> result = test_pow_positive.realize({args.data.extent(0)}, target); + for (int i = 0; i < args.data.extent(0); i++) { + float c_result = pow(1.5f, args.data(i)); + if (!relatively_equal(c_result, result(i), target)) { + fprintf(stderr, "For pow(1.5f, %.20f) == %.20f from C and %.20f from %s.\n", + (double)args.data(i), (double)c_result, (double)result(i), + target.to_string().c_str()); + num_errors++; + } + } +} + } // namespace int main(int argc, char **argv) { @@ -299,6 +326,7 @@ int main(int argc, char **argv) { call_1_float_types(trunc, 256, -25, 25); call_2_float_types(pow, 256, -10.0, 10.0, -4.0f, 4.0f); + test_pow_positive(); const int8_t int8_min = std::numeric_limits<int8_t>::min(); const int16_t int16_min = std::numeric_limits<int16_t>::min(); From 3f2b3a34ed6519a7e92370fc510238e7d487f076 Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Wed, 17 Sep 2025 14:40:44 -0400 Subject: [PATCH 2/2] Fix formatting --- test/correctness/math.cpp | 54 +++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/test/correctness/math.cpp b/test/correctness/math.cpp index 685e855d9905..9d09f19c9193 100644 --- a/test/correctness/math.cpp +++ b/test/correctness/math.cpp @@ -118,6 +118,33 @@ struct TestArgs { } }; +// For D3D12Compute, we lower directly to HLSL's pow function if the base is provably positive. +// This test ensures that lowering is correct. 
+void test_pow_positive() { + printf("Testing pow(x, y) where x > 0\n"); + TestArgs<float> args(256, 0.0f, 10.0f); + Func test_pow_positive("test_pow_positive"); + Var x("x"), xi("xi"); + test_pow_positive(x) = pow(1.5f, args.data(x)); + + Target target = get_jit_target_from_environment(); + if (target.has_gpu_feature()) { + test_pow_positive.gpu_tile(x, xi, 16); + } else if (target.has_feature(Target::HVX)) { + test_pow_positive.hexagon(); + } + Buffer<float> result = test_pow_positive.realize({args.data.extent(0)}, target); + for (int i = 0; i < args.data.extent(0); i++) { + float c_result = pow(1.5f, args.data(i)); + if (!relatively_equal(c_result, result(i), target)) { + fprintf(stderr, "For pow(1.5f, %.20f) == %.20f from C and %.20f from %s.\n", + (double)args.data(i), (double)c_result, (double)result(i), + target.to_string().c_str()); + num_errors++; + } + } +} + // Using macros to expand name as both a C function and an Expr fragment. // It may well be possible to do this without macros, but that is left // for another day. @@ -265,33 +292,6 @@ fun_2(uint32_t, uint32_t, absd, absd) call_2(double, name, steps, start1, end1, start2, end2); \ } while (0) -// For D3D12Compute, we lower directly to HLSL's pow function if the base is provably positive. -// This test ensures that lowering is correct. 
-void test_pow_positive() { - printf("Testing pow(x, y) where x > 0\n"); - TestArgs<float> args(256, 0.0f, 10.0f); - Func test_pow_positive("test_pow_positive"); - Var x("x"), xi("xi"); - test_pow_positive(x) = pow(1.5f, args.data(x)); - - Target target = get_jit_target_from_environment(); - if (target.has_gpu_feature()) { - test_pow_positive.gpu_tile(x, xi, 16); //.vectorize(xi, 2); - } else if (target.has_feature(Target::HVX)) { - test_pow_positive.hexagon(); - } - Buffer<float> result = test_pow_positive.realize({args.data.extent(0)}, target); - for (int i = 0; i < args.data.extent(0); i++) { - float c_result = pow(1.5f, args.data(i)); - if (!relatively_equal(c_result, result(i), target)) { - fprintf(stderr, "For pow(1.5f, %.20f) == %.20f from C and %.20f from %s.\n", - (double)args.data(i), (double)c_result, (double)result(i), - target.to_string().c_str()); - num_errors++; - } - } -} - } // namespace int main(int argc, char **argv) {