Open
Description
I'm attempting to compute Hessian vector products for use with RL algorithms like Natural Policy Gradient or TRPO, but have been entirely unsuccessful.
Following FluxML/Zygote.jl#115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and elsewhere I was able to compute HVPs for simple models parameterized by a single Array
, but the following appears to have issues inferring the type of Dual
.
Any help would be greatly appreciated! :)
# Zygote v0.4.1, Flux v0.10.0, ForwardDiff v0.10.7, DiffRules 0.1.0, ZygoteRules 0.2.0
# Julia Version 1.3.0
# Commit 46ce4d7933 (2019-11-26 06:09 UTC)
# Platform Info:
# OS: Linux (x86_64-pc-linux-gnu)
# CPU: Intel(R) Core(TM) i9-7960X CPU @ 2.80GHz
# WORD_SIZE: 64
# LIBM: libopenlibm
# LLVM: libLLVM-6.0.1 (ORCJIT, skylake)
using Flux, ForwardDiff, Zygote
using LinearAlgebra
# A Gaussian policy with diagonal covariance
struct DiagGaussianPolicy{M,L<:AbstractVector}
meanNN::M
logstd::L
end
Flux.@functor DiagGaussianPolicy
(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)
# log(pi_theta(a | s))
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
meanact = P(feature)
ll = -length(P.logstd) * log(2pi) / 2
for i = 1:length(action)
ll -= ((meanact[i] - action[i]) / exp(P.logstd[i]))^2 / 2
ll -= P.logstd[i]
end
ll
end
function flatgrad(f, ps)
gs = Zygote.gradient(f, ps)
vcat([vec(gs[p]) for p in ps]...)
end
Base.length(ps::Params) = 228 #sum(length, ps)
Base.size(ps::Params) = (228, ) #(length(ps), )
Base.eltype(ps::Params) = Float32
function hessian_vector_product(f,ps,v)
g = let f=f
ps -> flatgrad(f, ps)::Vector{Float32}
end
gvp = let g=g, v=v
ps -> (g(ps)⋅v)::Vector{Float32}
end
Zygote.forward_jacobian(gvp, ps)[2]
end
function test()
policy = Flux.paramtype(Float32, DiagGaussianPolicy(Flux.Chain(Dense(4, 32), Dense(32, 2)), zeros(2)))
ps = Flux.params(policy)
v = rand(Float32, sum(length, ps))
feat = rand(Float32, 4)
act = rand(Float32, 2)
f = let policy=policy, feat=feat, act=act
() -> loglikelihood(policy, feat, act)
end
hessian_vector_product(f, ps, v)
end
Calling test()
yields:
an_dual.
Stacktrace:
[1] throw_cannot_dual(::Type) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:36
[2] ForwardDiff.Dual{Nothing,Any,12}(::Array{Float32,2}, ::ForwardDiff.Partials{12,Any}) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:18
[3] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:55 [inlined]
[4] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:62 [inlined]
[5] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:68 [inlined]
[6] (::Zygote.var"#1565#1567"{12,Int64})(::Array{Float32,2}, ::Int64) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:8
[7] (::Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}})(::Tuple{Array{Float32,2},Int64}) at ./generator.jl:36
[8] iterate at ./generator.jl:47 [inlined]
[9] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Params,UnitRange{Int64}}},Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}}}) at ./array.jl:622
[10] map at ./abstractarray.jl:2155 [inlined]
[11] seed at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:7 [inlined] (repeats 2 times)
[12] forward_jacobian(::var"#340#342"{var"#339#341"{var"#343#344"{DiagGaussianPolicy{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1}},
Array{Float32,1},Array{Float32,1}}},Array{Float32,1}}, ::Params, ::Val{12}) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:23
[13] forward_jacobian(::Function, ::Params) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:40
[14] hessian_vector_product(::Function, ::Params, ::Array{Float32,1}) at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:52
[15] test() at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:64
[16] top-level scope at REPL[41]:1