Skip to content

[Feature] Adding blackjax sampling functionality#367

Merged
vschutze-alt merged 31 commits intoHEP-PBSP:mainfrom
yallup:blackjax
Mar 9, 2026
Merged

[Feature] Adding blackjax sampling functionality#367
vschutze-alt merged 31 commits intoHEP-PBSP:mainfrom
yallup:blackjax

Conversation

@yallup
Copy link
Collaborator

@yallup yallup commented Jun 23, 2025

Adds a JAX-native nested sampling fitter using the blackjax library, enabling end-to-end JAX-based Bayesian inference.

Summary of current changes:

  • New Fitter: Introduces blackjax_fit.py, which implements the nested sampling loop using blackjax.nss.
  • Generalized Prior: The bayesian_prior function is refactored. It now returns a dictionary containing prior_transform (for UltraNest), and log_prob and sample functions (for BlackJAX).
  • New Dependencies: Adds blackjax, tensorflow-probability (for prior distributions), and anesthetic (for results handling). The last two of these are optional but are included into the pyproject.toml for ease of use
  • Configuration: The new fitter is integrated into the configuration system, accepting blackjax_settings in runcards.
  • Example Runcard: A new runcard, lh_fit_closure_test_blackjax.yaml, is included to demonstrate usage.

Todo

  • Correct the returned BayesianFit results to be more compatible with colibri standards
  • Test Vectorization of the likelihood, particularly on a GPU
  • Add tests

pyproject.toml Outdated
nnpdf = { git = "https://github.com/NNPDF/nnpdf" }
anesthetic = "^2.10.2"
tfp-nightly = { extras = ["jax"], version = "*" }
blackjax = {git = "https://github.com/handley-lab/blackjax", rev = "nested_sampling"}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we are preparing this I would suggest revisiting if we can switch to the tagged version:
https://github.com/handley-lab/blackjax/releases/tag/nested-sampling-beta.1

There are some changes incoming to this main PR branch as we finalise it that may be breaking, so it would be better if possible to point to a stable release so we can manage the (hopefully) finished merged PR integration with colibri at a later date. I think in theory it should be a like for like swap but I think there were some conda env problems when we last looked at this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @yallup , thanks for this comment. I have tried again to use that tagged version, but still have the same error when trying to install:

blackjax = {git = "https://github.com/handley-lab/blackjax", tag = "nested-sampling-beta.1" }

assert version is not None ^^^^^^^^^^^^^^^^^^^ AssertionError

I think it has to do with the name of the tag, but I am not sure. Do you have any updates on this? Thanks!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS. I have noticed that there is another tagged version in the blackjax repository. I have switched to that one and the installation works fine, so if that version is ok I think we should be ready to merge on the colibri side of things

Copy link
Member

@LucaMantani LucaMantani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I skimmed through it and looks good overall, I have a suggestion on the implementation of the new bayesian prior mostly

Copy link
Member

@LucaMantani LucaMantani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to me!

@vschutze-alt vschutze-alt merged commit 74b9b9f into HEP-PBSP:main Mar 9, 2026
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants