From 63c7908cf9bc01f409523e7467e09b820504e3c4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 28 Jul 2025 14:47:11 -0600 Subject: [PATCH 1/5] Test against Enzyme --- docs/src/api.md | 3 ++- src/Turing.jl | 3 ++- test/ad.jl | 12 ++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 0b8351eb3..604718b0e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -93,9 +93,10 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au | Exported symbol | Documentation | Description | |:----------------- |:------------------------------------ |:---------------------- | +| `AutoEnzyme` | [`ADTypes.AutoEnzyme`](@extref) | Enzyme.jl backend | | `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend | -| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | | `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend | +| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | ### Debugging diff --git a/src/Turing.jl b/src/Turing.jl index 1ff231017..a5d0a543c 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -23,7 +23,7 @@ using Printf: Printf using Random: Random using LinearAlgebra: I -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake +using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake, AutoEnzyme const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff() @@ -123,6 +123,7 @@ export AutoForwardDiff, AutoReverseDiff, AutoMooncake, + AutoEnzyme, # Debugging - Turing setprogress!, # Distributions diff --git a/test/ad.jl b/test/ad.jl index 2f645fab5..5d265ca27 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -20,6 +20,14 @@ if INCLUDE_MOONCAKE using Mooncake: Mooncake end +const INCLUDE_ENZYME = !IS_PRERELEASE + +if INCLUDE_ENZYME + import Pkg + Pkg.add("Enzyme") + using Enzyme: Enzyme +end + """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) @@ -193,6 +201,10 @@ ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)] if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end +if INCLUDE_ENZYME + push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Forward))) + push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Reverse))) +end # Check that ADTypeCheckContext itself works as expected. @testset "ADTypeCheckContext" begin From 982427ebeaa66b61b6cba8bb2c95d92735ebf13f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Jul 2025 16:29:43 -0500 Subject: [PATCH 2/5] Update ad.jl --- test/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 5d265ca27..581a54537 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -202,8 +202,8 @@ if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end if INCLUDE_ENZYME - push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Forward))) - push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Reverse))) + push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward))) + push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse))) end # Check that ADTypeCheckContext itself works as expected. From 9ac5b0e86d3dce9a30a3b99599fa7e9a76e7b67f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Jul 2025 17:32:08 -0500 Subject: [PATCH 3/5] Update ad.jl --- test/ad.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/ad.jl b/test/ad.jl index 581a54537..fb3e6c7f6 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -47,6 +47,9 @@ eltypes_by_adtype = Dict( if INCLUDE_MOONCAKE eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,) end +if INCLUDE_ENZYME + eltypes_by_adtype[AutoEnzyme] = () +end """ AbstractWrongADBackendError From d634e388fca4cc423667b963a674b6706cf7bcf2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 Aug 2025 23:01:17 +0100 Subject: [PATCH 4/5] Fix dictionary type --- test/ad.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index fb3e6c7f6..d4d5f6990 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -32,7 +32,7 @@ end const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) """A dictionary mapping ADTypes to the element types they use.""" -eltypes_by_adtype = Dict( +eltypes_by_adtype = Dict{Type,Tuple}( AutoForwardDiff => (ForwardDiff.Dual,), AutoReverseDiff => ( ReverseDiff.TrackedArray, @@ -205,8 +205,8 @@ if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end if INCLUDE_ENZYME - push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward))) - push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse))) + push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))) + push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))) end # Check that ADTypeCheckContext itself works as expected. From ce08d69f1ed0a869c38bd793a63292f57d3904a2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 Aug 2025 23:02:06 +0100 Subject: [PATCH 5/5] mark function as const for good measure --- test/ad.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index d4d5f6990..bb787425d 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -205,8 +205,20 @@ if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end if INCLUDE_ENZYME - push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))) - push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))) + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation=Enzyme.Const, + ), + ) + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) end # Check that ADTypeCheckContext itself works as expected.