diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 0bd5329a07..0820e01730 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -243,6 +243,14 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) return nothing end +function marked_const_along_path(res, path) + for p in path + res = traced_getfield(res, p) + res isa Enzyme.Const && return true + end + return res isa Enzyme.Const +end + function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Enzyme.Annotation,Nargs} ) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs} @@ -346,9 +354,15 @@ function overload_autodiff( act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed - attr = MLIR.IR.DenseElementsAttribute( - fill(one(unwrapped_eltype(a)), size(a)) - ) + fill_value = + if marked_const_along_path( + result, TracedUtils.get_idx(a, resprefix)[2:end] + ) + zero(unwrapped_eltype(a)) + else + one(unwrapped_eltype(a)) + end + attr = MLIR.IR.DenseElementsAttribute(fill(fill_value, size(a))) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end diff --git a/test/autodiff.jl b/test/autodiff.jl index cb6e1db957..f0af62c904 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -255,3 +255,23 @@ end contains(repr(hlo), "stablehlo.rng_bit_generator") end end + +function simple_grad_with_no_ret_annotation(x::AbstractArray{T}) where {T} + return sum(x; dims=1), sum(abs2, x) +end + +function simple_grad_with_ret_annotation(x::AbstractArray{T}) where {T} + return Const(sum(x; dims=1)), sum(abs2, x) +end + +@testset "return annotation" begin + x = Reactant.to_rarray(rand(Float32, 4, 4)) + + res1 = only(@jit(Enzyme.gradient(Reverse, simple_grad_with_no_ret_annotation, x))) + + @test res1 ≈ (2 .* Array(x) .+ 1) + + res2 = only(@jit(Enzyme.gradient(Reverse, simple_grad_with_ret_annotation, x))) + + @test res2 ≈ (2 .* Array(x)) +end