@@ -243,6 +243,14 @@ function set_act!(inp, path, reverse, tostore; emptypath=false)
243
243
return nothing
244
244
end
245
245
246
+ function marked_const_along_path (res, path)
247
+ for p in path
248
+ res = traced_getfield (res, p)
249
+ res isa Enzyme. Const && return true
250
+ end
251
+ return res isa Enzyme. Const
252
+ end
253
+
246
254
function overload_autodiff (
247
255
:: CMode , f:: FA , :: Type{A} , args:: Vararg{Enzyme.Annotation,Nargs}
248
256
) where {CMode<: Enzyme.Mode ,FA<: Enzyme.Annotation ,A<: Enzyme.Annotation ,Nargs}
@@ -346,9 +354,15 @@ function overload_autodiff(
346
354
act = act_from_type (A, reverse, needs_primal (CMode))
347
355
push! (ret_activity, act)
348
356
if act == enzyme_out || act == enzyme_outnoneed
349
- attr = MLIR. IR. DenseElementsAttribute (
350
- fill (one (unwrapped_eltype (a)), size (a))
351
- )
357
+ fill_value =
358
+ if marked_const_along_path (
359
+ result, TracedUtils. get_idx (a, resprefix)[2 : end ]
360
+ )
361
+ zero (unwrapped_eltype (a))
362
+ else
363
+ one (unwrapped_eltype (a))
364
+ end
365
+ attr = MLIR. IR. DenseElementsAttribute (fill (fill_value, size (a)))
352
366
cst = MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
353
367
push! (ad_inputs, cst)
354
368
end
0 commit comments