From cd9eda5f8053d738842c15140680688ba7a27537 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 5 Jun 2024 10:46:08 +0200 Subject: [PATCH 01/14] Implement value_gradient_and_hessian --- DifferentiationInterface/docs/src/api.md | 2 + .../docs/src/operators.md | 18 +++--- .../src/DifferentiationInterface.jl | 1 + .../src/second_order/hessian.jl | 59 ++++++++++++++++++- .../src/tests/benchmark.jl | 32 ++++++---- .../src/tests/correctness.jl | 33 +++++++++++ .../src/tests/sparsity.jl | 6 ++ .../src/tests/type_stability.jl | 5 ++ 8 files changed, 135 insertions(+), 21 deletions(-) diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 61a75e8fa..c37435907 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -93,6 +93,8 @@ hvp! prepare_hessian hessian hessian! +value_gradient_and_hessian +value_gradient_and_hessian! ``` ## Utilities diff --git a/DifferentiationInterface/docs/src/operators.md b/DifferentiationInterface/docs/src/operators.md index 65d2703e4..dde8132c9 100644 --- a/DifferentiationInterface/docs/src/operators.md +++ b/DifferentiationInterface/docs/src/operators.md @@ -45,16 +45,16 @@ These operators are computed using the input `x` and a "seed" `v`, which lives e Several variants of each operator are defined. -| out-of-place | in-place | out-of-place + primal | in-place + primal | -| :-------------------------- | :--------------------------- | :----------------------------------------------- | :----------------------------------------------- | -| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) | +| out-of-place | in-place | out-of-place + primal | in-place + primal | +| :-------------------------- | :--------------------------- | :----------------------------------------------- | :------------------------------------------------ | +| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) | | [`second_derivative`](@ref) | [`second_derivative!`](@ref) | [`value_derivative_and_second_derivative`](@ref) | [`value_derivative_and_second_derivative!`](@ref) | -| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) | -| [`hessian`](@ref) | [`hessian!`](@ref) | NA | NA | -| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | -| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | -| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | -| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA | +| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) | +| [`hessian`](@ref) | [`hessian!`](@ref) | [`value_gradient_and_hessian`](@ref) | [`value_gradient_and_hessian!`](@ref) NA | +| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | +| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | +| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | +| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA | ## Mutation and signatures diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 73b04c890..b534a4663 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -99,6 +99,7 @@ export second_derivative!, second_derivative export value_derivative_and_second_derivative, value_derivative_and_second_derivative! export hvp!, hvp export hessian!, hessian +export value_gradient_and_hessian, value_gradient_and_hessian! export prepare_pushforward, prepare_pushforward_same_point export prepare_pullback, prepare_pullback_same_point diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 7d083297c..20e3421c9 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -24,6 +24,20 @@ Compute the Hessian matrix of the function `f` at point `x`, overwriting `hess`. """ function hessian! end +""" + value_gradient_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) + +Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`. +""" +function value_gradient_and_hessian end + +""" + value_gradient_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) + +Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. +""" +function value_gradient_and_hessian! end + ## Preparation """ @@ -35,7 +49,7 @@ abstract type HessianExtras <: Extras end struct NoHessianExtras <: HessianExtras end -struct HVPHessianExtras{E<:HVPExtras} <: HessianExtras +struct HVPGradientHessianExtras{E<:HVPExtras} <: HessianExtras hvp_extras::E end @@ -46,7 +60,8 @@ end function prepare_hessian(f::F, backend::SecondOrder, x) where {F} v = basis(backend, x, first(CartesianIndices(x))) hvp_extras = prepare_hvp(f, backend, x, v) - return HVPHessianExtras(hvp_extras) + gradient_extras = prepare_gradient(f, inner(backend), x) + return HVPGradientHessianExtras(hvp_extras, gradient_extras) end ## One argument @@ -96,3 +111,43 @@ function hessian!( end return hess end + +function value_gradient_and_hessian( + f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) +) where {F} + return value_gradient_and_hessian(f, SecondOrder(backend, backend), x, extras) +end + +function value_gradient_and_hessian( + f::F, backend::SecondOrder, x, extras::HessianExtras=prepare_hessian(f, backend, x) +) where {F} + y, grad = value_and_gradient(f, inner(backend), x, extras.gradient_extras) + hess = hessian(f, backend, x, extras) + return y, grad, hess +end + +function value_gradient_and_hessian!( + f::F, + grad, + hess, + backend::AbstractADType, + x, + extras::HessianExtras=prepare_hessian(f, backend, x), +) where {F} + return value_gradient_and_hessian!( + f, grad, hess, SecondOrder(backend, backend), x, extras + ) +end + +function value_gradient_and_hessian!( + f::F, + grad, + hess, + backend::SecondOrder, + x, + extras::HessianExtras=prepare_hessian(f, backend, x), +) where {F} + y, _ = value_and_gradient!(f, grad, inner(backend), x, extras.gradient_extras) + hessian!(f, hess, backend, extras) + return y, grad, hess +end diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index 4888ef5c5..61994231b 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -975,27 +975,31 @@ function run_benchmark!( logging::Bool, ) @compat (; f, x, y) = deepcopy(scen) - @compat (; bench0, bench1, calls0, calls1) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try # benchmark extras = prepare_hessian(f, ba, x) bench0 = @be prepare_hessian(f, ba, x) samples = 1 evals = 1 bench1 = @be deepcopy(extras) hessian(f, ba, x, _) + bench2 = @be deepcopy(extras) value_gradient_and_hessian(f, ba, x, _) # count cc = CallCounter(f) extras = prepare_hessian(cc, ba, x) calls0 = reset_count!(cc) hessian(cc, ba, x, extras) calls1 = reset_count!(cc) - (; bench0, bench1, calls0, calls1) + value_gradient_and_hessian(cc, ba, x, extras) + calls2 = reset_count!(cc) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1 = failed_benchs(2) - calls0, calls1 = -1, -1 - (; bench0, bench1, calls0, calls1) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_hessian, bench0, calls0) record!(data, ba, scen, :hessian, bench1, calls1) + record!(data, ba, scen, :value_gradient_and_hessian, bench2, calls2) return nothing end @@ -1006,7 +1010,7 @@ function run_benchmark!( logging::Bool, ) @compat (; f, x, y) = deepcopy(scen) - @compat (; bench0, bench1, calls0, calls1) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try hess_template = Matrix{typeof(y)}(undef, length(x), length(x)) # benchmark extras = prepare_hessian(f, ba, x) @@ -1014,21 +1018,29 @@ function run_benchmark!( bench1 = @be (hess=mysimilar(hess_template), ext=deepcopy(extras)) hessian!( f, _.hess, ba, x, _.ext ) evals = 1 + bench2 = @be ( + grad=mysimilar(x), hess=mysimilar(hess_template), ext=deepcopy(extras) + ) value_gradient_and_hessian!(f, _.grad, _.hess, ba, x, _.ext) evals = 1 # count cc = CallCounter(f) extras = prepare_hessian(cc, ba, x) calls0 = reset_count!(cc) hessian!(cc, mysimilar(hess_template), ba, x, extras) calls1 = reset_count!(cc) - (; bench0, bench1, calls0, calls1) + value_gradient_and_hessian!( + cc, mysimilar(x), mysimilar(hess_template), ba, x, extras + ) + calls2 = reset_count!(cc) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1 = failed_benchs(2) - calls0, calls1 = -1, -1 - (; bench0, bench1, calls0, calls1) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_hessian, bench0, calls0) record!(data, ba, scen, :hessian!, bench1, calls1) + record!(data, ba, scen, :value_gradient_and_hessian!, bench2, calls2) return nothing end diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index b443a79e2..6f30b7b0d 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -972,6 +972,13 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_hessian(f, ba, mycopy_random(x)) + grad_true = if ref_backend isa SecondOrder + gradient(f, inner(ref_backend), x) + elseif ref_backend isa AbstractADType + gradient(f, ref_backend, x) + else + new_scen.ref(x) + end hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) else @@ -979,13 +986,21 @@ function test_correctness( end hess1 = hessian(f, ba, x, extras) + y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin @test extras isa HessianExtras end + @testset "Primal value" begin + @test y2 ≈ y_true + end + @testset "Gradient value" begin + @test grad2 ≈ grad_true + end @testset "Hessian value" begin @test hess1 ≈ hess_true + @test hess2 ≈ hess_true end end test_scen_intact(new_scen, scen) @@ -1002,6 +1017,13 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_hessian(f, ba, mycopy_random(x)) + grad_true = if ref_backend isa SecondOrder + gradient(f, inner(ref_backend), x) + elseif ref_backend isa AbstractADType + gradient(f, ref_backend, x) + else + new_scen.ref(x) + end hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) else @@ -1010,14 +1032,25 @@ function test_correctness( hess1_in = mysimilar(hess_true) hess1 = hessian!(f, hess1_in, ba, x, extras) + grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) + y2, grad2, hess2 = value_gradient_and_hessian!(f, grad2_in, hess2_in, ba, x, extras) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin @test extras isa HessianExtras end + @testset "Primal value" begin + @test y2 ≈ y_true + end + @testset "Gradient value" begin + @test grad2_in ≈ grad_true + @test grad2 ≈ grad_true + end @testset "Hessian value" begin @test hess1_in ≈ hess_true + @test hess2_in ≈ hess_true @test hess1 ≈ hess_true + @test hess2 ≈ hess_true end end test_scen_intact(new_scen, scen) diff --git a/DifferentiationInterfaceTest/src/tests/sparsity.jl b/DifferentiationInterfaceTest/src/tests/sparsity.jl index 492fbc734..f129015d9 100644 --- a/DifferentiationInterfaceTest/src/tests/sparsity.jl +++ b/DifferentiationInterfaceTest/src/tests/sparsity.jl @@ -99,9 +99,11 @@ function test_sparsity( end hess1 = hessian(f, ba, x, extras) + _, _, hess2 = value_gradient_and_hessian(f, ba, x, extras) @testset "Sparsity pattern" begin @test mynnz(hess1) == mynnz(hess_true) + @test mynnz(hess2) == mynnz(hess_true) end return nothing end @@ -116,9 +118,13 @@ function test_sparsity(ba::AbstractADType, scen::HessianScenario{1,:inplace}; re end hess1 = hessian!(f, mysimilar(hess_true), ba, x, extras) + _, _, hess2 = value_gradient_and_hessian!( + f, mysimilar(x), mysimilar(hess_true), ba, x, extras + ) @testset "Sparsity pattern" begin @test mynnz(hess1) == mynnz(hess_true) + @test mynnz(hess2) == mynnz(hess_true) end return nothing end diff --git a/DifferentiationInterfaceTest/src/tests/type_stability.jl b/DifferentiationInterfaceTest/src/tests/type_stability.jl index 60e6358a1..e87ee6acd 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability.jl @@ -283,14 +283,19 @@ function test_jet(ba::AbstractADType, scen::HessianScenario{1,:outofplace}; ref_ extras = prepare_hessian(f, ba, x) JET.@test_opt function_filter = filt hessian(f, ba, x, extras) + JET.@test_opt function_filter = filt value_gradient_and_hessian(f, ba, x, extras) return nothing end function test_jet(ba::AbstractADType, scen::HessianScenario{1,:inplace}; ref_backend) @compat (; f, x, y) = deepcopy(scen) extras = prepare_hessian(f, ba, x) + grad_in = mysimilar(x) hess_in = Matrix{typeof(y)}(undef, length(x), length(x)) JET.@test_opt function_filter = filt hessian!(f, hess_in, ba, x, extras) + JET.@test_opt function_filter = filt value_gradient_and_hessian!( + f, grad_in, hess_in, ba, x, extras + ) return nothing end From bf239cb58ff891ef96f6c8e0b34f93d4ab64efd3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 5 Jun 2024 10:49:50 +0200 Subject: [PATCH 02/14] Fix extras --- DifferentiationInterface/src/second_order/hessian.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 20e3421c9..b10cb5bf3 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -49,8 +49,9 @@ abstract type HessianExtras <: Extras end struct NoHessianExtras <: HessianExtras end -struct HVPGradientHessianExtras{E<:HVPExtras} <: HessianExtras - hvp_extras::E +struct HVPGradientHessianExtras{E1<:HVPExtras,E2<:GradientExtras} <: HessianExtras + hvp_extras::E1 + gradient_extras::E2 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} From 74394123ba46804c77a15e8326e1885ef8ed69d0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:06:35 +0200 Subject: [PATCH 03/14] Fix sparse --- .../src/second_order/hessian.jl | 47 ++----------------- .../src/second_order/second_order.jl | 5 ++ .../src/sparse/hessian.jl | 31 ++++++++++-- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index b10cb5bf3..7681959d1 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -49,19 +49,15 @@ abstract type HessianExtras <: Extras end struct NoHessianExtras <: HessianExtras end -struct HVPGradientHessianExtras{E1<:HVPExtras,E2<:GradientExtras} <: HessianExtras +struct HVPGradientHessianExtras{E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras hvp_extras::E1 gradient_extras::E2 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} - return prepare_hessian(f, SecondOrder(backend, backend), x) -end - -function prepare_hessian(f::F, backend::SecondOrder, x) where {F} v = basis(backend, x, first(CartesianIndices(x))) hvp_extras = prepare_hvp(f, backend, x, v) - gradient_extras = prepare_gradient(f, inner(backend), x) + gradient_extras = prepare_gradient(f, maybe_inner(backend), x) return HVPGradientHessianExtras(hvp_extras, gradient_extras) end @@ -69,12 +65,6 @@ end function hessian( f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) -) where {F} - return hessian(f, SecondOrder(backend, backend), x, extras) -end - -function hessian( - f::F, backend::SecondOrder, x, extras::HessianExtras=prepare_hessian(f, backend, x) ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -92,16 +82,6 @@ function hessian!( backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x), -) where {F} - return hessian!(f, hess, SecondOrder(backend, backend), x, extras) -end - -function hessian!( - f::F, - hess, - backend::SecondOrder, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -116,13 +96,7 @@ end function value_gradient_and_hessian( f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) ) where {F} - return value_gradient_and_hessian(f, SecondOrder(backend, backend), x, extras) -end - -function value_gradient_and_hessian( - f::F, backend::SecondOrder, x, extras::HessianExtras=prepare_hessian(f, backend, x) -) where {F} - y, grad = value_and_gradient(f, inner(backend), x, extras.gradient_extras) + y, grad = value_and_gradient(f, maybe_inner(backend), x, extras.gradient_extras) hess = hessian(f, backend, x, extras) return y, grad, hess end @@ -135,20 +109,7 @@ function value_gradient_and_hessian!( x, extras::HessianExtras=prepare_hessian(f, backend, x), ) where {F} - return value_gradient_and_hessian!( - f, grad, hess, SecondOrder(backend, backend), x, extras - ) -end - -function value_gradient_and_hessian!( - f::F, - grad, - hess, - backend::SecondOrder, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), -) where {F} - y, _ = value_and_gradient!(f, grad, inner(backend), x, extras.gradient_extras) + y, _ = value_and_gradient!(f, grad, maybe_inner(backend), x, extras.gradient_extras) hessian!(f, hess, backend, extras) return y, grad, hess end diff --git a/DifferentiationInterface/src/second_order/second_order.jl b/DifferentiationInterface/src/second_order/second_order.jl index 5c9cae11f..2e3816198 100644 --- a/DifferentiationInterface/src/second_order/second_order.jl +++ b/DifferentiationInterface/src/second_order/second_order.jl @@ -54,3 +54,8 @@ Return a possibly modified `backend` that can work while nested inside another d At the moment, this is only useful for Enzyme, which needs `autodiff_deferred` to be compatible with higher-order differentiation. """ nested(backend::AbstractADType) = backend + +maybe_inner(backend::SecondOrder) = inner(backend) +maybe_outer(backend::SecondOrder) = outer(backend) +maybe_inner(backend::AbstractADType) = backend +maybe_outer(backend::AbstractADType) = backend diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/src/sparse/hessian.jl index 222f773cc..96b2651b9 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/src/sparse/hessian.jl @@ -4,14 +4,16 @@ Base.@kwdef struct SparseHessianExtras{ K<:AbstractVector{<:Integer}, D<:AbstractVector, P<:AbstractVector, - E<:Extras, + E2<:HVPExtras, + E1<:GradientExtras, } <: HessianExtras sparsity::S compressed::C colors::K seeds::D products::P - hvp_extras::E + hvp_extras::E2 + gradient_extras::E1 end ## Hessian, one argument @@ -32,7 +34,10 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} similar(x) end compressed = stack(vec, products; dims=2) - return SparseHessianExtras(; sparsity, compressed, colors, seeds, products, hvp_extras) + gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x) + return SparseHessianExtras(; + sparsity, compressed, colors, seeds, products, hvp_extras, gradient_extras + ) end function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras) where {F} @@ -56,3 +61,23 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) wher end return decompress_symmetric(sparsity, compressed, colors) end + +function value_gradient_and_hessian!( + f::F, grad, hess, backend::AutoSparse, x, extras::SparseHessianExtras +) where {F} + y, _ = value_and_gradient!( + f, grad, maybe_inner(dense_ad(backend)), x, extras.gradient_extras + ) + hessian!(f, hess, backend, x, extras) + return y, grad, hess +end + +function value_gradient_and_hessian( + f::F, backend::AutoSparse, x, extras::SparseHessianExtras +) where {F} + y, grad = value_and_gradient( + f, maybe_inner(dense_ad(backend)), x, extras.gradient_extras + ) + hess = hessian(f, hess, backend, x, extras) + return y, grad, hess +end From f3549c0fe0c9bd501504607394df9cd5970ac10c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:12:30 +0200 Subject: [PATCH 04/14] Typo --- DifferentiationInterface/src/second_order/hessian.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 7681959d1..cebfde88e 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -50,8 +50,8 @@ abstract type HessianExtras <: Extras end struct NoHessianExtras <: HessianExtras end struct HVPGradientHessianExtras{E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras - hvp_extras::E1 - gradient_extras::E2 + hvp_extras::E2 + gradient_extras::E1 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} From 94154d39e81567c42b04250e3486d29f0862a6e3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 06:58:26 +0200 Subject: [PATCH 05/14] Implement in extensions --- .../onearg.jl | 112 ++++++++++-------- .../twoarg.jl | 16 +-- .../onearg.jl | 34 ++++-- .../onearg.jl | 37 +++++- .../onearg.jl | 26 +++- .../onearg.jl | 25 ++++ .../onearg.jl | 51 ++++++-- .../twoarg.jl | 12 +- .../DifferentiationInterfaceZygoteExt.jl | 14 +++ 9 files changed, 234 insertions(+), 93 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 084c42dd4..34ef21057 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -1,9 +1,9 @@ ## Pushforward -struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras +struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E1!} <: PushforwardExtras y_prototype::Y jvp_exe::E1 - jvp_exe!::E2 + jvp_exe!::E1! end function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx) @@ -70,9 +70,9 @@ end ## Pullback -struct FastDifferentiationOneArgPullbackExtras{E1,E2} <: PullbackExtras +struct FastDifferentiationOneArgPullbackExtras{E1,E1!} <: PullbackExtras vjp_exe::E1 - vjp_exe!::E2 + vjp_exe!::E1! end function DI.prepare_pullback(f, ::AutoFastDifferentiation, x, dy) @@ -133,10 +133,10 @@ end ## Derivative -struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras +struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E1!} <: DerivativeExtras y_prototype::Y der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f, ::AutoFastDifferentiation, x) @@ -190,13 +190,12 @@ end ## Gradient -struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras +struct FastDifferentiationOneArgGradientExtras{E1,E1!} <: GradientExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x) - y_prototype = f(x) x_var = make_variables(:x, size(x)...) y_var = f(x_var) @@ -241,10 +240,10 @@ end ## Jacobian -struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2} <: JacobianExtras +struct FastDifferentiationOneArgJacobianExtras{Y,E1,E1!} <: JacobianExtras y_prototype::Y jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( @@ -307,16 +306,15 @@ end ## Second derivative -struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1!,E2,E2!} <: +struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,D,E2,E2!} <: SecondDerivativeExtras y_prototype::Y - der_exe::E1 - der_exe!::E1! + derivative_extras::D der2_exe::E2 der2_exe!::E2! end -function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x) +function DI.prepare_second_derivative(f, backend::AutoFastDifferentiation, x) y_prototype = f(x) x_var = only(make_variables(:x)) y_var = f(x_var) @@ -324,17 +322,13 @@ function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x) x_vec_var = monovec(x_var) y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) - der_vec_var = derivative(y_vec_var, x_var) der2_vec_var = derivative(y_vec_var, x_var, x_var) - - der_exe = make_function(der_vec_var, x_vec_var; in_place=false) - der_exe! = make_function(der_vec_var, x_vec_var; in_place=true) - der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false) der2_exe! = make_function(der2_vec_var, x_vec_var; in_place=true) + derivative_extras = DI.prepare_derivative(f, backend, x) return FastDifferentiationAllocatingSecondDerivativeExtras( - y_prototype, der_exe, der_exe!, der2_exe, der2_exe! + y_prototype, derivative_extras, der2_exe, der2_exe! ) end @@ -364,20 +358,13 @@ end function DI.value_derivative_and_second_derivative( f, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, extras::FastDifferentiationAllocatingSecondDerivativeExtras, ) - y = f(x) - if extras.y_prototype isa Number - der = only(extras.der_exe(monovec(x))) - der2 = only(extras.der2_exe(monovec(x))) - return y, der, der2 - else - der = reshape(extras.der_exe(monovec(x)), size(extras.y_prototype)) - der2 = reshape(extras.der2_exe(monovec(x)), size(extras.y_prototype)) - return y, der, der2 - end + y, der = DI.value_and_derivative(f, backend, x, extras.derivative_extras) + der2 = DI.second_derivative(f, backend, x, extras) + return y, der, der2 end function DI.value_derivative_and_second_derivative!( @@ -388,17 +375,16 @@ function DI.value_derivative_and_second_derivative!( x, extras::FastDifferentiationAllocatingSecondDerivativeExtras, ) - y = f(x) - extras.der_exe!(vec(der), monovec(x)) - extras.der2_exe!(vec(der2), monovec(x)) + y, _ = DI.value_and_derivative!(f, der, backend, x, extras.derivative_extras) + DI.second_derivative!(f, der2, backend, x, extras) return y, der, der2 end ## HVP -struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras - hvp_exe::E1 - hvp_exe!::E2 +struct FastDifferentiationHVPExtras{E2,E2!} <: HVPExtras + hvp_exe::E2 + hvp_exe!::E2! end function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v) @@ -428,24 +414,30 @@ end ## Hessian -struct FastDifferentiationHessianExtras{E1,E2} <: HessianExtras - hess_exe::E1 - hess_exe!::E2 +struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras + grad_extras::G + hess_exe::E2 + hess_exe!::E2! end function DI.prepare_hessian( f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x ) - x_vec_var = make_variables(:x, size(x)...) - y_vec_var = f(x_vec_var) + x_var = make_variables(:x, size(x)...) + y_var = f(x_vec_var) + + x_vec_var = vec(x_var) + hess_var = if backend isa AutoSparse - sparse_hessian(y_vec_var, vec(x_vec_var)) + sparse_hessian(y_var, x_vec_var) else - hessian(y_vec_var, vec(x_vec_var)) + hessian(y_var, x_vec_var) end - hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false) - hess_exe! = make_function(hess_var, vec(x_vec_var); in_place=true) - return FastDifferentiationHessianExtras(hess_exe, hess_exe!) + hess_exe = make_function(hess_var, x_vec_var; in_place=false) + hess_exe! = make_function(hess_var, x_vec_var; in_place=true) + + gradient_extras = DI.prepare_gradient(f, backend, x) + return FastDifferentiationHessianExtras(gradient_extras, hess_exe, hess_exe!) end function DI.hessian( @@ -467,3 +459,27 @@ function DI.hessian!( extras.hess_exe!(hess, vec(x)) return hess end + +function DI.value_gradient_and_hessian( + f, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + extras::FastDifferentiationHessianExtras, +) + y, grad = DI.value_and_gradient(f, backend, x, extras.gradient_extras) + hess = DI.hessian(f, backend, x, extras) + return y, grad, hess +end + +function DI.value_gradient_and_hessian!( + f, + grad, + hess, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + extras::FastDifferentiationHessianExtras, +) + y, _ = DI.value_and_gradient!(f, grad, backend, x, extras.gradient_extras) + DI.hessian!(f, hess, backend, x, extras) + return y, grad, hess +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index b3141f506..5d7059de8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -1,8 +1,8 @@ ## Pushforward -struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras +struct FastDifferentiationTwoArgPushforwardExtras{E1,E1!} <: PushforwardExtras jvp_exe::E1 - jvp_exe!::E2 + jvp_exe!::E1! end function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx) @@ -80,9 +80,9 @@ end ## Pullback -struct FastDifferentiationTwoArgPullbackExtras{E1,E2} <: PullbackExtras +struct FastDifferentiationTwoArgPullbackExtras{E1,E1!} <: PullbackExtras vjp_exe::E1 - vjp_exe!::E2 + vjp_exe!::E1! end function DI.prepare_pullback(f!, y, ::AutoFastDifferentiation, x, dy) @@ -156,9 +156,9 @@ end ## Derivative -struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras +struct FastDifferentiationTwoArgDerivativeExtras{E1,E1!} <: DerivativeExtras der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x) @@ -216,9 +216,9 @@ end ## Jacobian -struct FastDifferentiationTwoArgJacobianExtras{E1,E2} <: JacobianExtras +struct FastDifferentiationTwoArgJacobianExtras{E1,E1!} <: JacobianExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 8f30c94be..96a9e2f00 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -153,21 +153,39 @@ end ## Hessian -struct FiniteDiffHessianExtras{C} <: HessianExtras - cache::C +struct FiniteDiffHessianExtras{C1,C2} <: HessianExtras + gradient_cache::C1 + hessian_cache::C2 end function DI.prepare_hessian(f, backend::AutoFiniteDiff, x) - cache = HessianCache(x, fdhtype(backend)) - return FiniteDiffHessianExtras(cache) + y = f(x) + df = zero(y) .* x + gradient_cache = GradientCache(df, x, fdtype(backend)) + hessian_cache = HessianCache(x, fdhtype(backend)) + return FiniteDiffHessianExtras(gradient_cache, hessian_cache) end -# cache cannot be reused because of https://github.com/JuliaDiff/FiniteDiff.jl/issues/185 - function DI.hessian(f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras) - return finite_difference_hessian(f, x, extras.cache) + return finite_difference_hessian(f, x, extras.hessian_cache) end function DI.hessian!(f, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras) - return finite_difference_hessian!(hess, f, x, extras.cache) + return finite_difference_hessian!(hess, f, x, extras.hessian_cache) +end + +function DI.value_gradient_and_hessian( + f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras +) + grad = finite_difference_gradient(f, x, extras.gradient_cache) + hess = finite_difference_hessian(f, x, extras.hessian_cache) + return f(x), grad, hess +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras +) + finite_difference_gradient!(grad, f, x, extras.gradient_cache) + finite_difference_hessian!(hess, f, x, extras.hessian_cache) + return f(x), grad, hess end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 6326cbc4f..34bc99ad9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -137,20 +137,47 @@ end ## Hessian -struct ForwardDiffHessianExtras{C} <: HessianExtras - config::C +struct ForwardDiffHessianExtras{C1,C2} <: HessianExtras + array_config::C1 + result_config::C2 end function DI.prepare_hessian(f, backend::AutoForwardDiff, x) - return ForwardDiffHessianExtras(HessianConfig(f, x, choose_chunk(backend, x))) + example_result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) + chunk = choose_chunk(backend, x) + array_config = HessianConfig(f, x, chunk) + result_config = HessianConfig(f, example_result, x, chunk) + return ForwardDiffHessianExtras(array_config, result_config) end function DI.hessian!( f::F, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras ) where {F} - return hessian!(hess, f, x, extras.config) + return hessian!(hess, f, x, extras.array_config) end function DI.hessian(f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras) where {F} - return hessian(f, x, extras.config) + return hessian(f, x, extras.array_config) +end + +function DI.value_gradient_and_hessian!( + f::F, grad, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras +) where {F} + result = MutableDiffResult(y, (grad, hess)) + result = hessian!(result, f, x, extras.config) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end + +function DI.value_gradient_and_hessian( + f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras +) where {F} + result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) + result = hessian!(result, f, x, extras.config) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 672bdb963..99cfa2f2a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -42,9 +42,9 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f, dy, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras ) - return DI.value_and_derivative!(f, dy, single_threaded(backend), x, extras) + return DI.value_and_derivative!(f, der, single_threaded(backend), x, extras) end function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras) @@ -52,9 +52,9 @@ function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::Derivati end function DI.derivative!( - f, dy, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras ) - return DI.derivative!(f, dy, single_threaded(backend), x, extras) + return DI.derivative!(f, der, single_threaded(backend), x, extras) end ## Gradient @@ -149,6 +149,20 @@ function DI.hessian(f, backend::AutoPolyesterForwardDiff, x, extras::HessianExtr return DI.hessian(f, single_threaded(backend), x, extras) end -function DI.hessian!(f, dy, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras) - return DI.hessian!(f, dy, single_threaded(backend), x, extras) +function DI.hessian!(f, hess, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras) + return DI.hessian!(f, hess, single_threaded(backend), x, extras) +end + +function DI.value_gradient_and_hessian( + f, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras +) + return DI.value_gradient_and_hessian(f, single_threaded(backend), x, extras) +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras +) + return DI.value_gradient_and_hessian!( + f, grad, hess, single_threaded(backend), x, extras + ) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index c940db730..4780b2b1c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -163,3 +163,28 @@ function DI.hessian( ) return hessian!(extras.tape, x) end + +function DI.value_gradient_and_hessian!( + _f, + grad::AbstractVector, + hess::AbstractMatrix, + ::AutoReverseDiff, + x::AbstractArray, + extras::ReverseDiffHessianExtras, +) + result = MutableDiffResult(y, (grad, hess)) + result = hessian!(result, extras.tape, x) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end + +function DI.value_gradient_and_hessian( + _f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras +) + result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) + result = hessian!(result, extras.tape, x) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index da809b60f..86c815bd1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -1,8 +1,8 @@ ## Pushforward -struct SymbolicsOneArgPushforwardExtras{E1,E2} <: PushforwardExtras +struct SymbolicsOneArgPushforwardExtras{E1,E1!} <: PushforwardExtras pf_exe::E1 - pf_exe!::E2 + pf_exe!::E1! end function DI.prepare_pushforward(f, ::AutoSymbolics, x, dx) @@ -57,9 +57,9 @@ end ## Derivative -struct SymbolicsOneArgDerivativeExtras{E1,E2} <: DerivativeExtras +struct SymbolicsOneArgDerivativeExtras{E1,E1!} <: DerivativeExtras der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f, ::AutoSymbolics, x) @@ -98,9 +98,9 @@ end ## Gradient -struct SymbolicsOneArgGradientExtras{E1,E2} <: GradientExtras +struct SymbolicsOneArgGradientExtras{E1,E1!} <: GradientExtras grad_exe::E1 - grad_exe!::E2 + grad_exe!::E1! end function DI.prepare_gradient(f, ::AutoSymbolics, x) @@ -136,9 +136,9 @@ end ## Jacobian -struct SymbolicsOneArgJacobianExtras{E1,E2} <: JacobianExtras +struct SymbolicsOneArgJacobianExtras{E1,E1!} <: JacobianExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( @@ -197,9 +197,10 @@ end ## Hessian -struct SymbolicsOneArgHessianExtras{E1,E2} <: HessianExtras - hess_exe::E1 - hess_exe!::E2 +struct SymbolicsOneArgHessianExtras{G,E2,E2!} <: HessianExtras + gradient_extras::G + hess_exe::E2 + hess_exe!::E2! end function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x) @@ -213,7 +214,9 @@ function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSym res = build_function(hess_var, vec(x_var); expression=Val(false)) (hess_exe, hess_exe!) = res - return SymbolicsOneArgHessianExtras(hess_exe, hess_exe!) + + gradient_extras = DI.prepare_gradient(f, backend, x) + return SymbolicsOneArgHessianExtras(gradient_extras, hess_exe, hess_exe!) end function DI.hessian( @@ -235,3 +238,27 @@ function DI.hessian!( extras.hess_exe!(hess, vec(x)) return hess end + +function DI.value_gradient_and_hessian( + f, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + extras::SymbolicsOneArgHessianExtras, +) + y, grad = DI.value_and_gradient(f, backend, x, extras.gradient_extras) + hess = DI.hessian(f, backend, x, extras) + return y, grad, hess +end + +function DI.value_gradient_and_hessian!( + f, + grad, + hess, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + extras::SymbolicsOneArgHessianExtras, +) + y, _ = DI.value_and_gradient!(f, grad, backend, x, extras.gradient_extras) + DI.hessian!(f, hess, backend, x, extras) + return y, grad, hess +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 14ee93f96..a6e37f6df 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -1,8 +1,8 @@ ## Pushforward -struct SymbolicsTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras +struct SymbolicsTwoArgPushforwardExtras{E1,E1!} <: PushforwardExtras pushforward_exe::E1 - pushforward_exe!::E2 + pushforward_exe!::E1! end function DI.prepare_pushforward(f!, y, ::AutoSymbolics, x, dx) @@ -61,9 +61,9 @@ end ## Derivative -struct SymbolicsTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras +struct SymbolicsTwoArgDerivativeExtras{E1,E1!} <: DerivativeExtras der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f!, y, ::AutoSymbolics, x) @@ -106,9 +106,9 @@ end ## Jacobian -struct SymbolicsTwoArgJacobianExtras{E1,E2} <: JacobianExtras +struct SymbolicsTwoArgJacobianExtras{E1,E1!} <: JacobianExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 04d375850..49e54eb8d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -131,4 +131,18 @@ function DI.hessian!(f, hess, backend::AutoZygote, x, extras::NoHessianExtras) return copyto!(hess, DI.hessian(f, backend, x, extras)) end +function DI.value_gradient_and_hessian(f, ::AutoZygote, x, ::NoHessianExtras) + y, grad = DI.value_and_gradient(f, backend, x, NoGradientExtras()) + hess = DI.hessian(f, backend, x, extras) + return y, grad, hess +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, backend::AutoZygote, x, extras::NoHessianExtras +) + y, _ = DI.value_and_gradient!(f, backend, x, NoGradientExtras()) + DI.hessian!(f, hess, backend, x, extras) + return y, grad, hess +end + end From 022bc2c1dc9d088104a1e796278f2c2a9d4a501a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:07:16 +0200 Subject: [PATCH 06/14] Fixes --- .../DifferentiationInterfaceFastDifferentiationExt.jl | 2 +- .../onearg.jl | 8 ++++---- .../ext/DifferentiationInterfaceReverseDiffExt/onearg.jl | 2 +- .../DifferentiationInterfaceSymbolicsExt.jl | 2 +- .../ext/DifferentiationInterfaceSymbolicsExt/onearg.jl | 6 +++--- DifferentiationInterface/src/sparse/hessian.jl | 2 +- DifferentiationInterface/test/runtests.jl | 2 +- DifferentiationInterfaceTest/src/tests/correctness.jl | 2 +- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 499592b03..2088495fc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceFastDifferentiationExt -using ADTypes: ADTypes, AutoFastDifferentiation, AutoSparse +using ADTypes: ADTypes, AutoFastDifferentiation, AutoSparse, dense_ad import DifferentiationInterface as DI using DifferentiationInterface: DerivativeExtras, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 34ef21057..2c63dd2fe 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -424,7 +424,7 @@ function DI.prepare_hessian( f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x ) x_var = make_variables(:x, size(x)...) - y_var = f(x_vec_var) + y_var = f(x_var) x_vec_var = vec(x_var) @@ -436,7 +436,7 @@ function DI.prepare_hessian( hess_exe = make_function(hess_var, x_vec_var; in_place=false) hess_exe! = make_function(hess_var, x_vec_var; in_place=true) - gradient_extras = DI.prepare_gradient(f, backend, x) + gradient_extras = DI.prepare_gradient(f, dense_ad(backend), x) return FastDifferentiationHessianExtras(gradient_extras, hess_exe, hess_exe!) end @@ -466,7 +466,7 @@ function DI.value_gradient_and_hessian( x, extras::FastDifferentiationHessianExtras, ) - y, grad = DI.value_and_gradient(f, backend, x, extras.gradient_extras) + y, grad = DI.value_and_gradient(f, dense_ad(backend), x, extras.gradient_extras) hess = DI.hessian(f, backend, x, extras) return y, grad, hess end @@ -479,7 +479,7 @@ function DI.value_gradient_and_hessian!( x, extras::FastDifferentiationHessianExtras, ) - y, _ = DI.value_and_gradient!(f, grad, backend, x, extras.gradient_extras) + y, _ = DI.value_and_gradient!(f, grad, dense_ad(backend), x, extras.gradient_extras) DI.hessian!(f, hess, backend, x, extras) return y, grad, hess end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 4780b2b1c..ad7ee8d10 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -166,7 +166,7 @@ end function DI.value_gradient_and_hessian!( _f, - grad::AbstractVector, + grad, hess::AbstractMatrix, ::AutoReverseDiff, x::AbstractArray, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 0074c6d79..89019aed2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceSymbolicsExt -using ADTypes: ADTypes, AutoSymbolics, AutoSparse +using ADTypes: ADTypes, AutoSymbolics, AutoSparse, dense_ad import DifferentiationInterface as DI using DifferentiationInterface: DerivativeExtras, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 86c815bd1..f5712d4df 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -215,7 +215,7 @@ function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSym res = build_function(hess_var, vec(x_var); expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_extras = DI.prepare_gradient(f, backend, x) + gradient_extras = DI.prepare_gradient(f, dense_ad(backend), x) return SymbolicsOneArgHessianExtras(gradient_extras, hess_exe, hess_exe!) end @@ -245,7 +245,7 @@ function DI.value_gradient_and_hessian( x, extras::SymbolicsOneArgHessianExtras, ) - y, grad = DI.value_and_gradient(f, backend, x, extras.gradient_extras) + y, grad = DI.value_and_gradient(f, dense_ad(backend), x, extras.gradient_extras) hess = DI.hessian(f, backend, x, extras) return y, grad, hess end @@ -258,7 +258,7 @@ function DI.value_gradient_and_hessian!( x, extras::SymbolicsOneArgHessianExtras, ) - y, _ = DI.value_and_gradient!(f, grad, backend, x, extras.gradient_extras) + y, _ = DI.value_and_gradient!(f, grad, dense_ad(backend), x, extras.gradient_extras) DI.hessian!(f, hess, backend, x, extras) return y, grad, hess end diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/src/sparse/hessian.jl index 96b2651b9..074c6fdf0 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/src/sparse/hessian.jl @@ -78,6 +78,6 @@ function value_gradient_and_hessian( y, grad = value_and_gradient( f, maybe_inner(dense_ad(backend)), x, extras.gradient_extras ) - hess = hessian(f, hess, backend, x, extras) + hess = hessian(f, backend, x, extras) return y, grad, hess end diff --git a/DifferentiationInterface/test/runtests.jl b/DifferentiationInterface/test/runtests.jl index 89211b9d9..a4910e175 100644 --- a/DifferentiationInterface/test/runtests.jl +++ b/DifferentiationInterface/test/runtests.jl @@ -33,7 +33,7 @@ ALL_BACKENDS = [ @testset verbose = true "DifferentiationInterface.jl" begin if GROUP == "Formalities" || GROUP == "All" @testset "Formalities/$file" for file in readdir(joinpath(@__DIR__, "Formalities")) - @info "Testing Formalities/$file)" + @info "Testing Formalities/$file" include(joinpath(@__DIR__, "Formalities", file)) end end diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index 6f30b7b0d..4005d9919 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -1040,7 +1040,7 @@ function test_correctness( @test extras isa HessianExtras end @testset "Primal value" begin - @test y2 ≈ y_true + @test y2 ≈ y end @testset "Gradient value" begin @test grad2_in ≈ grad_true From 848456104b43c4e54bf2fa104f3fec2fc1b585c1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:09:32 +0200 Subject: [PATCH 07/14] Fix --- DifferentiationInterfaceTest/src/tests/correctness.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index 4005d9919..68d7ab23b 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -993,7 +993,7 @@ function test_correctness( @test extras isa HessianExtras end @testset "Primal value" begin - @test y2 ≈ y_true + @test y2 ≈ y end @testset "Gradient value" begin @test grad2 ≈ grad_true From 44d9f429cf26561a41fe509103e79a3f96881f53 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:26:39 +0200 Subject: [PATCH 08/14] Fix ref --- .../DifferentiationInterfaceZygoteExt.jl | 2 +- .../src/DifferentiationInterface.jl | 1 + .../src/second_order/second_order.jl | 5 ---- DifferentiationInterface/src/utils/maybe.jl | 7 ++++++ .../src/DifferentiationInterfaceTest.jl | 2 ++ .../src/tests/correctness.jl | 24 +++++++------------ 6 files changed, 19 insertions(+), 22 deletions(-) create mode 100644 DifferentiationInterface/src/utils/maybe.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 49e54eb8d..e393633a5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -140,7 +140,7 @@ end function DI.value_gradient_and_hessian!( f, grad, hess, backend::AutoZygote, x, extras::NoHessianExtras ) - y, _ = DI.value_and_gradient!(f, backend, x, NoGradientExtras()) + y, _ = DI.value_and_gradient!(f, grad, backend, x, NoGradientExtras()) DI.hessian!(f, hess, backend, x, extras) return y, grad, hess end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index b534a4663..504171bbe 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -54,6 +54,7 @@ include("utils/printing.jl") include("utils/chunk.jl") include("utils/check.jl") include("utils/exceptions.jl") +include("utils/maybe.jl") include("first_order/pushforward.jl") include("first_order/pullback.jl") diff --git a/DifferentiationInterface/src/second_order/second_order.jl b/DifferentiationInterface/src/second_order/second_order.jl index 2e3816198..5c9cae11f 100644 --- a/DifferentiationInterface/src/second_order/second_order.jl +++ b/DifferentiationInterface/src/second_order/second_order.jl @@ -54,8 +54,3 @@ Return a possibly modified `backend` that can work while nested inside another d At the moment, this is only useful for Enzyme, which needs `autodiff_deferred` to be compatible with higher-order differentiation. """ nested(backend::AbstractADType) = backend - -maybe_inner(backend::SecondOrder) = inner(backend) -maybe_outer(backend::SecondOrder) = outer(backend) -maybe_inner(backend::AbstractADType) = backend -maybe_outer(backend::AbstractADType) = backend diff --git a/DifferentiationInterface/src/utils/maybe.jl b/DifferentiationInterface/src/utils/maybe.jl new file mode 100644 index 000000000..6b22961f9 --- /dev/null +++ b/DifferentiationInterface/src/utils/maybe.jl @@ -0,0 +1,7 @@ +maybe_inner(backend::SecondOrder) = inner(backend) +maybe_outer(backend::SecondOrder) = outer(backend) +maybe_inner(backend::AbstractADType) = backend +maybe_outer(backend::AbstractADType) = backend + +maybe_dense_ad(backend::AutoSparse) = dense_ad(backend) +maybe_dense_ad(backend::AbstractADType) = backend diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index f722f31e9..a8686ddc0 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -24,6 +24,8 @@ using DifferentiationInterface using DifferentiationInterface: backend_str, inner, + maybe_inner, + maybe_dense_ad, mode, outer, twoarg_support, diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index 68d7ab23b..637c50e28 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -794,10 +794,8 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_second_derivative(f, ba, mycopy_random(x)) - der1_true = if ref_backend isa SecondOrder - derivative(f, inner(ref_backend), x) - elseif ref_backend isa AbstractADType - derivative(f, ref_backend, x) + der1_true = if ref_backend isa AbstractADType + derivative(f, maybe_inner(ref_backend), x) else new_scen.first_order_ref(x) end @@ -839,10 +837,8 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_second_derivative(f, ba, mycopy_random(x)) - der1_true = if ref_backend isa SecondOrder - derivative(f, inner(ref_backend), x) - elseif ref_backend isa AbstractADType - derivative(f, ref_backend, x) + der1_true = if ref_backend isa AbstractADType + derivative(f, maybe_inner(ref_backend), x) else new_scen.first_order_ref(x) end @@ -972,10 +968,8 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_hessian(f, ba, mycopy_random(x)) - grad_true = if ref_backend isa SecondOrder - gradient(f, inner(ref_backend), x) - elseif ref_backend isa AbstractADType - gradient(f, ref_backend, x) + grad_true = if ref_backend isa AbstractADType + gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else new_scen.ref(x) end @@ -1017,10 +1011,8 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_hessian(f, ba, mycopy_random(x)) - grad_true = if ref_backend isa SecondOrder - gradient(f, inner(ref_backend), x) - elseif ref_backend isa AbstractADType - gradient(f, ref_backend, x) + grad_true = if ref_backend isa AbstractADType + gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else new_scen.ref(x) end From b6a190ad838f1a67262d4421f4370cfc6bd55c2d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:28:58 +0200 Subject: [PATCH 09/14] Fix --- DifferentiationInterface/src/second_order/hessian.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index cebfde88e..267492eb1 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -110,6 +110,6 @@ function value_gradient_and_hessian!( extras::HessianExtras=prepare_hessian(f, backend, x), ) where {F} y, _ = value_and_gradient!(f, grad, maybe_inner(backend), x, extras.gradient_extras) - hessian!(f, hess, backend, extras) + hessian!(f, hess, backend, x, extras) return y, grad, hess end From 7666245e9f1aff15bebe8462f07f243a9bfc55e4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:38:15 +0200 Subject: [PATCH 10/14] Fixes --- .../DifferentiationInterfaceFastDifferentiationExt.jl | 5 +++-- .../onearg.jl | 8 +++++--- .../ext/DifferentiationInterfaceReverseDiffExt/onearg.jl | 6 ++++-- .../DifferentiationInterfaceSymbolicsExt.jl | 5 +++-- .../ext/DifferentiationInterfaceSymbolicsExt/onearg.jl | 8 +++++--- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 2088495fc..404547242 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceFastDifferentiationExt -using ADTypes: ADTypes, AutoFastDifferentiation, AutoSparse, dense_ad +using ADTypes: ADTypes, AutoFastDifferentiation, AutoSparse import DifferentiationInterface as DI using DifferentiationInterface: DerivativeExtras, @@ -10,7 +10,8 @@ using DifferentiationInterface: JacobianExtras, PullbackExtras, PushforwardExtras, - SecondDerivativeExtras + SecondDerivativeExtras, + maybe_dense_ad using FastDifferentiation: derivative, hessian, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 2c63dd2fe..eeb845fad 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -436,7 +436,7 @@ function DI.prepare_hessian( hess_exe = make_function(hess_var, x_vec_var; in_place=false) hess_exe! = make_function(hess_var, x_vec_var; in_place=true) - gradient_extras = DI.prepare_gradient(f, dense_ad(backend), x) + gradient_extras = DI.prepare_gradient(f, maybe_dense_ad(backend), x) return FastDifferentiationHessianExtras(gradient_extras, hess_exe, hess_exe!) end @@ -466,7 +466,7 @@ function DI.value_gradient_and_hessian( x, extras::FastDifferentiationHessianExtras, ) - y, grad = DI.value_and_gradient(f, dense_ad(backend), x, extras.gradient_extras) + y, grad = DI.value_and_gradient(f, maybe_dense_ad(backend), x, extras.gradient_extras) hess = DI.hessian(f, backend, x, extras) return y, grad, hess end @@ -479,7 +479,9 @@ function DI.value_gradient_and_hessian!( x, extras::FastDifferentiationHessianExtras, ) - y, _ = DI.value_and_gradient!(f, grad, dense_ad(backend), x, extras.gradient_extras) + y, _ = DI.value_and_gradient!( + f, grad, maybe_dense_ad(backend), x, extras.gradient_extras + ) DI.hessian!(f, hess, backend, x, extras) return y, grad, hess end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index ad7ee8d10..89d091b51 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -172,7 +172,7 @@ function DI.value_gradient_and_hessian!( x::AbstractArray, extras::ReverseDiffHessianExtras, ) - result = MutableDiffResult(y, (grad, hess)) + result = MutableDiffResult(one(eltype(x)), (grad, hess)) result = hessian!(result, extras.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) @@ -182,7 +182,9 @@ end function DI.value_gradient_and_hessian( _f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras ) - result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) + result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) result = hessian!(result, extras.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 89019aed2..b99cd1d24 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceSymbolicsExt -using ADTypes: ADTypes, AutoSymbolics, AutoSparse, dense_ad +using ADTypes: ADTypes, AutoSymbolics, AutoSparse import DifferentiationInterface as DI using DifferentiationInterface: DerivativeExtras, @@ -10,7 +10,8 @@ using DifferentiationInterface: JacobianExtras, PullbackExtras, PushforwardExtras, - SecondDerivativeExtras + SecondDerivativeExtras, + maybe_dense_ad using FillArrays: Fill using LinearAlgebra: dot using Symbolics: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index f5712d4df..a816bfddc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -215,7 +215,7 @@ function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSym res = build_function(hess_var, vec(x_var); expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_extras = DI.prepare_gradient(f, dense_ad(backend), x) + gradient_extras = DI.prepare_gradient(f, maybe_dense_ad(backend), x) return SymbolicsOneArgHessianExtras(gradient_extras, hess_exe, hess_exe!) end @@ -245,7 +245,7 @@ function DI.value_gradient_and_hessian( x, extras::SymbolicsOneArgHessianExtras, ) - y, grad = DI.value_and_gradient(f, dense_ad(backend), x, extras.gradient_extras) + y, grad = DI.value_and_gradient(f, maybe_dense_ad(backend), x, extras.gradient_extras) hess = DI.hessian(f, backend, x, extras) return y, grad, hess end @@ -258,7 +258,9 @@ function DI.value_gradient_and_hessian!( x, extras::SymbolicsOneArgHessianExtras, ) - y, _ = DI.value_and_gradient!(f, grad, dense_ad(backend), x, extras.gradient_extras) + y, _ = DI.value_and_gradient!( + f, grad, maybe_dense_ad(backend), x, extras.gradient_extras + ) DI.hessian!(f, hess, backend, x, extras) return y, grad, hess end From 428235eb4995b19f36bce5dadc4abb1f1bdb7bf9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:41:09 +0200 Subject: [PATCH 11/14] Fix --- DifferentiationInterfaceTest/src/tests/correctness.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index 637c50e28..e83ebadc5 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -971,7 +971,7 @@ function test_correctness( grad_true = if ref_backend isa AbstractADType gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else - new_scen.ref(x) + new_scen.first_order_ref(x) end hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) @@ -1014,7 +1014,7 @@ function test_correctness( grad_true = if ref_backend isa AbstractADType gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else - new_scen.ref(x) + new_scen.first_order_ref(x) end hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) From 23f1f605e23a649a5c65a4ff94a18ac72de0176c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:54:17 +0200 Subject: [PATCH 12/14] Typos --- .../onearg.jl | 2 +- .../ext/DifferentiationInterfaceForwardDiffExt/onearg.jl | 6 ++++-- .../DifferentiationInterfaceZygoteExt.jl | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index eeb845fad..74b3ca8fb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -415,7 +415,7 @@ end ## Hessian struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras - grad_extras::G + gradient_extras::G hess_exe::E2 hess_exe!::E2! end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 34bc99ad9..09457f34c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -165,7 +165,7 @@ end function DI.value_gradient_and_hessian!( f::F, grad, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras ) where {F} - result = MutableDiffResult(y, (grad, hess)) + result = MutableDiffResult(one(eltype(x)), (grad, hess)) result = hessian!(result, f, x, extras.config) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) @@ -175,7 +175,9 @@ end function DI.value_gradient_and_hessian( f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras ) where {F} - result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) + result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) result = hessian!(result, f, x, extras.config) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index e393633a5..d7c62fe5c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -131,7 +131,7 @@ function DI.hessian!(f, hess, backend::AutoZygote, x, extras::NoHessianExtras) return copyto!(hess, DI.hessian(f, backend, x, extras)) end -function DI.value_gradient_and_hessian(f, ::AutoZygote, x, ::NoHessianExtras) +function DI.value_gradient_and_hessian(f, backend::AutoZygote, x, ::NoHessianExtras) y, grad = DI.value_and_gradient(f, backend, x, NoGradientExtras()) hess = DI.hessian(f, backend, x, extras) return y, grad, hess From a40eed6637a6f03a83dc14803f1daba9bf661ecc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:59:09 +0200 Subject: [PATCH 13/14] Typo --- .../ext/DifferentiationInterfaceForwardDiffExt/onearg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 09457f34c..ad4aa0c61 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -166,7 +166,7 @@ function DI.value_gradient_and_hessian!( f::F, grad, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras ) where {F} result = MutableDiffResult(one(eltype(x)), (grad, hess)) - result = hessian!(result, f, x, extras.config) + result = hessian!(result, f, x, extras.result_config) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) ) @@ -178,7 +178,7 @@ function DI.value_gradient_and_hessian( result = MutableDiffResult( one(eltype(x)), (similar(x), similar(x, length(x), length(x))) ) - result = hessian!(result, f, x, extras.config) + result = hessian!(result, f, x, extras.result_config) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) ) From de8cbb740dbec44280ed4e2591c23355252d47d7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 6 Jun 2024 08:08:35 +0200 Subject: [PATCH 14/14] Typo --- .../DifferentiationInterfaceZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index d7c62fe5c..2ce4ba6cb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -131,7 +131,7 @@ function DI.hessian!(f, hess, backend::AutoZygote, x, extras::NoHessianExtras) return copyto!(hess, DI.hessian(f, backend, x, extras)) end -function DI.value_gradient_and_hessian(f, backend::AutoZygote, x, ::NoHessianExtras) +function DI.value_gradient_and_hessian(f, backend::AutoZygote, x, extras::NoHessianExtras) y, grad = DI.value_and_gradient(f, backend, x, NoGradientExtras()) hess = DI.hessian(f, backend, x, extras) return y, grad, hess