Extension for MarginalLogDensities.jl, take 2 #2664
Pull Request Test Coverage Report for Build 17257342637
Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch, which means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
💛 - Coveralls

Codecov Report
❌ Patch coverage is

@@            Coverage Diff             @@
##              main    #2664       +/- ##
===========================================
- Coverage    85.70%   58.22%   -27.49%
===========================================
  Files           22       24        +2
  Lines         1434     1441        +7
===========================================
- Hits          1229      839      -390
- Misses         205      602      +397
Thanks for this! Just some thoughts (no actual code).
ext/TuringMarginalLogDensitiesExt.jl
Outdated
f = Turing.Optimisation.OptimLogDensity(
    model,
    Turing.DynamicPPL.getlogjoint,
    # Turing.DynamicPPL.typed_varinfo(model)
    varinfo_linked
)
Just making a note here that I am pretty sure that there's a double negative sign somewhere here with OptimLogDensity and the FlipSign thing, so I think we will be able to simplify this. (Pre-DynamicPPL 0.37, it used to be that OptimLogDensity was the only way to get the log-joint without the Jacobian term, so I'm guessing that's why it got used here, but now we can do that more easily with getlogjoint.)
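Something like this minimal sketch is what I have in mind (untested; it assumes DynamicPPL >= 0.37, where LogDensityFunction accepts a getlogdensity argument, and reuses `model`, `varinfo_linked`, and `x` from the extension code above):

using DynamicPPL, LogDensityProblems

# getlogjoint returns the log-joint without the Jacobian term, and
# LogDensityFunction does not negate it, so no sign flip would be needed.
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, varinfo_linked)
LogDensityProblems.logdensity(ldf, x)  # log p(x, data), not its negative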
I think it's still needed; getlogjoint appears to return a negative log-density:
using Turing, DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
end

model = demo()
f = Turing.Optimisation.OptimLogDensity(
    model,
    DynamicPPL.getlogjoint,
    DynamicPPL.typed_varinfo(model)
)
f([0])              # 0.9189385332046728
logpdf(Normal(), 0) # -0.9189385332046728
Ah, but that's the fault of OptimLogDensity, not getlogjoint :)
Not asking you to fix it, btw; happy to do it when I'm back to writing code!
src/extensions.jl
Outdated
function marginalize(model, varnames, method)
    error("This function is available after importing MarginalLogDensities.")
end
I looked through MLD's and Turing's deps, and the only new dep would be HCubature, which is a pretty small one. So I might be in favour of just making this part of the source code itself instead of an extension. (I think Julia extensions are a really good thing in general, but there are a few downsides to them, like having to do this kind of 'define in main package and extend in extension' stuff; see the sketch below.)
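For reference, the 'define in main package and extend in extension' pattern looks roughly like this (an abbreviated sketch of this PR's structure; the extension body is elided):

# src/extensions.jl: a stub that errors until the extension is loaded
function marginalize(model, varnames, method)
    error("This function is available after importing MarginalLogDensities.")
end

# ext/TuringMarginalLogDensitiesExt.jl: the real method, compiled only
# when both Turing and MarginalLogDensities are loaded
module TuringMarginalLogDensitiesExt
using Turing, MarginalLogDensities
function Turing.marginalize(model, varnames, method)
    # ... actual implementation ...
end
end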
Sure, this one's up to you guys. I don't mind either way.
ext/TuringMarginalLogDensitiesExt.jl
Outdated
    mdl = MarginalLogDensities.MarginalLogDensity(
        Drop2ndArgAndFlipSign(f),  # adapt f to MLD's (u, data) signature and flip the sign
        varinfo_linked[:],         # full (linked) parameter vector
        varindices,                # indices of the variables to marginalise out
        (),                        # no extra data
        method                     # e.g. LaplaceApprox() or Cubature()
    )
    return mdl
end
From reading the MLD readme, I gather that you can do a couple of things with this:
- Evaluate the marginal log density by calling it
- Perform optimisation to get point estimates
Would it also be useful to be able to perform MCMC sampling with this? In principle it's not too difficult: a function that takes a vector of parameters and returns a float is pretty much what NUTS needs (a rough sketch follows this comment).
To do it properly will probably take a bit of time: I think we would have to either change the interface of DynamicPPL.LogDensityFunction or (perhaps easier) write a wrapper around it that keeps track of which variables are marginalised out. So it wouldn't have to be part of this PR. But if this is worth doing then I could have a think about how to do it.
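A hedged sketch of what that could look like with a gradient-free sampler, assuming the object returned by marginalize (call it `mld`) implements the LogDensityProblems interface; the comments below note that this is not quite true yet, so treat this as aspirational:

using AbstractMCMC, AdvancedMH, Distributions, LinearAlgebra, LogDensityProblems

# Hypothetical: `mld` is the object returned by `marginalize`, assumed to
# implement the LogDensityProblems interface. Random-walk MH needs no
# gradients, so MLD's current lack of AD support would not block this.
target = AbstractMCMC.LogDensityModel(mld)
proposal = MvNormal(zeros(LogDensityProblems.dimension(mld)), I)
chain = sample(target, RWMH(proposal), 1_000)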
This should already work in theory, since MLD objects implement the LogDensityProblems interface (here). Would be good to add a test for it though.
Never mind, I thought it was "just working" but it isn't. Shouldn't be too hard to do with some kind of simple wrapper though.
Yuppp, the samplers right now only work with DynamicPPL.Model. We should broaden those types and make an interface but it's quite a large undertaking.
Do you know if AD would be able to differentiate the marginal log-density?
Not yet, it's the last major feature I need to add to MLD. It's a bit tricky since you need to differentiate through the Hessian of an optimization result, but it should be possible.
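(To spell out the difficulty: the Laplace approximation MLD computes has roughly the form below, so its gradient in θ involves differentiating through both the inner optimum û(θ), typically via the implicit function theorem, and the log-determinant of a Hessian. Notation here is mine, not MLD's.)

$$\log p(\theta) \approx \log p\big(\theta, \hat{u}(\theta)\big) + \tfrac{n}{2}\log 2\pi - \tfrac{1}{2}\log\det H(\theta), \qquad \hat{u}(\theta) = \operatorname*{arg\,max}_u \log p(\theta, u),$$

where $H(\theta)$ is the negative Hessian of $\log p(\theta, \cdot)$ in $u$ at $\hat{u}(\theta)$, and $n$ is the number of marginalised variables.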
Opening a new version of #2421, as discussed in #2662. Adds an extension for MarginalLogDensities.jl, along with a new exported marginalize function.
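A hedged usage sketch based on the stub signature above (whether varnames are passed as VarNames, and how the returned object is parameterised, are assumptions that may differ from the final API):

using Turing, MarginalLogDensities

@model function demo()
    μ ~ Normal()
    x ~ Normal(μ)
end

m = demo()
# Marginalise x out of the joint, here via MLD's Laplace approximation
mld = marginalize(m, [@varname(x)], LaplaceApprox())
mld([0.5])  # marginal log-density at μ = 0.5, with x integrated out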