feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel#2124
feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel#2124juanitorduz wants to merge 11 commits into
Conversation
Add MCLMC inference algorithm as a new MCMCKernel that wraps blackjax's MCLMC implementation. This provides an alternative gradient-based MCMC method to NUTS/HMC. Features: - MCLMC kernel with automatic step size and trajectory length tuning - Optional blackjax dependency with informative error message - postprocess_fn for constrained/unconstrained transformations - Diagnostics string for progress bar - Comprehensive test suite References: - Microcanonical Hamiltonian Monte Carlo (arXiv:2212.08549)
|
Hey @reubenharry I tried using your branch, but I made a git mess 🙈 so I decided to open another one to get feedback (I will add you as a coauthor). Could you please check this one out and see if the implementation (and the tests) are as expected? |
fehiepsi
left a comment
There was a problem hiding this comment.
I worry that we might introduce technical debt by depending on other libraries. Could we turn this into an example/tutorial instead? It's not clear to me the benefit of using numpyro here.
|
Maybe adding a section for https://num.pyro.ai/en/stable/tutorials/other_samplers.html would be better? edit: #2035 has good discussion about the above points |
Sure! This was a first attempt at trying to see how it would fit. I agree this "optional" dependencies can be hard to maintain. So what about adding another section with this code in https://num.pyro.ai/en/stable/tutorials/other_samplers.html ? Or do we want an additional notebook with just a brief explanation of MCLMC? |
|
on the other hand the optimal outcome would be that its very easy for numpyro users to use the sampler. if it's hidden in some tutorial... |
True 😄 . I do not have any strong opinion. I just wanted to bring this one alive. That being said, I have seen blackjax making breaking changes and this would be a pain to maintain indeed (in the notebook we would put a disclaimer about these potential changes) |
|
how entangled is the blackjax implementation with the rest of blackjax? can the sampler be ripped out with minimal changes? |
|
Awesome, thanks for doing this! It looks good to me, although for peace of mind, we should take a non-trivial example and check that the efficiency is good (e.g. Stochastic Volatility or something). Perhaps I can do that by running the Numpyro version in my https://github.com/reubenharry/sampler-benchmarks repo and checking the results, when I have more time. @martinjankowiak Re entanglement, I have previously ported the implementation outside Blackjax before (even out of Jax), and it isn't insanely hard to do so (blackjax itself isn't really a very complex codebase). More a question of time. I agree that the main draw of adding the code to Numpyro is discoverability for users. (I have been a bit busy this month, but will try to be more responsive going forward) |
|
This implementation works for me. I'd say that if it was totally necessary, the relevant parts of blackjax could be pulled out and the implementation could be self-contained. But I think it would be more natural to keep it in - I think changes that break backwards compatibility of this sampler are unlikely to be frequent. |
|
@fehiepsi @martinjankowiak @reubenharry Suggestion: Keep this code but put it into numpyro/numpyro/contrib/nested_sampling.py Lines 10 to 29 in d5598e7 |
|
Hi @juanitorduz, I planned to remove jaxns wrapper at some point to avoid maintenance overhead. So I hope we can avoid adding wrappers for other libraries. |
|
it seems to me like the best option is to pull out the implementation from blackjax. this is something claude could presumably do in < 10 minutes. the only downside is that you don't automatically inherit any improvements made to the blackjax implementation, but that's not a big deal and can always be addressed with future PRs if there are significant improvements in the blackjax implementation. |
|
ok! Fair enough! Then I will proceed to pull out the implementation from blackjax :) 👍 |
Resolve modify/delete conflict on setup.py: accept master's deletion (migrated to pyproject.toml) and port blackjax>=1.3 test dependency to pyproject.toml. Made-with: Cursor
|
@reubenharry I have tried to remove the blackjax dependency with Claude's help! I ensured the results match the code from #2039 but it would be great if you could give a proper review @fehiepsi @martinjankowiak Open for feedback :) |
| @@ -0,0 +1,859 @@ | |||
| # Copyright Contributors to the Pyro project. | |||
There was a problem hiding this comment.
presumably many of the private methods in this file have corresponding tests in blackjax that you can pull into this PR?
hey folks! any thoughts on this one? This raw implementation is self-contained and about 859 lines of code (I am fine with it) I wanna also want to bring up the fact that blackjax is quite active now https://github.com/blackjax-devs/blackjax/releases, so my gut tells me it's less risky to have it as an optional dependency (much safer than jaxns). |
|
Sorry for the slow reply. I had one minor comment (see above), but since this follows the BlackJax quite closely, the rest looks fine. I think either the self contained version, or using Blackjax directly would be fine (but if the latter, I'd suggest that NumPyro makes a larger scale decision to be open to using Blackjax across the board). |
Trying to support #2039