[Feature] Adding blackjax sampling functionality#367
[Feature] Adding blackjax sampling functionality#367vschutze-alt merged 31 commits intoHEP-PBSP:mainfrom
Conversation
40e4f8f to
d370250
Compare
4f963d1 to
2786469
Compare
…rior is now returning a dictionary
cf7e3b4 to
10fe4ee
Compare
pyproject.toml
Outdated
| nnpdf = { git = "https://github.com/NNPDF/nnpdf" } | ||
| anesthetic = "^2.10.2" | ||
| tfp-nightly = { extras = ["jax"], version = "*" } | ||
| blackjax = {git = "https://github.com/handley-lab/blackjax", rev = "nested_sampling"} |
There was a problem hiding this comment.
As we are preparing this I would suggest revisiting if we can switch to the tagged version:
https://github.com/handley-lab/blackjax/releases/tag/nested-sampling-beta.1
There are some changes incoming to this main PR branch as we finalise it that may be breaking, so it would be better if possible to point to a stable release so we can manage the (hopefully) finished merged PR integration with colibri at a later date. I think in theory it should be a like for like swap but I think there were some conda env problems when we last looked at this
There was a problem hiding this comment.
Hi @yallup , thanks for this comment. I have tried again to use that tagged version, but still have the same error when trying to install:
blackjax = {git = "https://github.com/handley-lab/blackjax", tag = "nested-sampling-beta.1" }
assert version is not None ^^^^^^^^^^^^^^^^^^^ AssertionError
I think it has to do with the name of the tag, but I am not sure. Do you have any updates on this? Thanks!
There was a problem hiding this comment.
PS. I have noticed that there is another tagged version in the blackjax repository. I have switched to that one and the installation works fine, so if that version is ok I think we should be ready to merge on the colibri side of things
LucaMantani
left a comment
There was a problem hiding this comment.
I skimmed through it and looks good overall, I have a suggestion on the implementation of the new bayesian prior mostly
Co-authored-by: Luca Mantani <[email protected]>
Adds a JAX-native nested sampling fitter using the
blackjaxlibrary, enabling end-to-end JAX-based Bayesian inference.Summary of current changes:
blackjax_fit.py, which implements the nested sampling loop usingblackjax.nss.bayesian_priorfunction is refactored. It now returns a dictionary containingprior_transform(for UltraNest), andlog_probandsamplefunctions (for BlackJAX).blackjax,tensorflow-probability(for prior distributions), andanesthetic(for results handling). The last two of these are optional but are included into the pyproject.toml for ease of useblackjax_settingsin runcards.lh_fit_closure_test_blackjax.yaml, is included to demonstrate usage.Todo