Skip to content

Commit 14508ff

Browse files
committed
feat: support return type annotations for autodiff calls
1 parent 39fe3e6 commit 14508ff

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/Enzyme.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,14 @@ function set_act!(inp, path, reverse, tostore; emptypath=false)
243243
return nothing
244244
end
245245

246+
function marked_const_along_path(res, path)
247+
for p in path
248+
res = traced_getfield(res, p)
249+
res isa Enzyme.Const && return true
250+
end
251+
return res isa Enzyme.Const
252+
end
253+
246254
function overload_autodiff(
247255
::CMode, f::FA, ::Type{A}, args::Vararg{Enzyme.Annotation,Nargs}
248256
) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs}
@@ -346,9 +354,15 @@ function overload_autodiff(
346354
act = act_from_type(A, reverse, needs_primal(CMode))
347355
push!(ret_activity, act)
348356
if act == enzyme_out || act == enzyme_outnoneed
349-
attr = MLIR.IR.DenseElementsAttribute(
350-
fill(one(unwrapped_eltype(a)), size(a))
351-
)
357+
fill_value =
358+
if marked_const_along_path(
359+
result, TracedUtils.get_idx(a, resprefix)[2:end]
360+
)
361+
zero(unwrapped_eltype(a))
362+
else
363+
one(unwrapped_eltype(a))
364+
end
365+
attr = MLIR.IR.DenseElementsAttribute(fill(fill_value, size(a)))
352366
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
353367
push!(ad_inputs, cst)
354368
end

0 commit comments

Comments
 (0)