Skip to content

Commit d8fea6e

Browse files
committed
Add background sampling handle and docs for pm.sample
1 parent 05c9332 commit d8fea6e

File tree

2 files changed

+125
-54
lines changed

2 files changed

+125
-54
lines changed

pymc/sampling/mcmc.py

Lines changed: 92 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import sys
2121
import time
2222
import warnings
23+
import threading
2324

2425
from collections.abc import Callable, Iterator, Mapping, Sequence
2526
from typing import (
@@ -91,6 +92,44 @@
9192
Step: TypeAlias = BlockedStep | CompoundStep
9293

9394

95+
class BackgroundSampleHandle:
96+
def __init__(self, target, args=None, kwargs=None):
97+
self._done = threading.Event()
98+
self._result = None
99+
self._exception = None
100+
args = args or ()
101+
kwargs = kwargs or {}
102+
103+
def runner():
104+
try:
105+
self._result = target(*args, **kwargs)
106+
except Exception as exc: # noqa: BLE001
107+
self._exception = exc
108+
finally:
109+
self._done.set()
110+
111+
self._thread = threading.Thread(target=runner, daemon=True)
112+
113+
def start(self):
114+
self._thread.start()
115+
return self
116+
117+
def done(self):
118+
return self._done.is_set()
119+
120+
def result(self, timeout=None):
121+
self._thread.join(timeout=timeout)
122+
if not self._done.is_set():
123+
raise TimeoutError("Background sampling not finished yet")
124+
if self._exception:
125+
raise self._exception
126+
return self._result
127+
128+
def exception(self, timeout=None):
129+
self._thread.join(timeout=timeout)
130+
return self._exception
131+
132+
94133
class SamplingIteratorCallback(Protocol):
95134
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
96135

@@ -439,6 +478,7 @@ def sample(
439478
mp_ctx=None,
440479
blas_cores: int | None | Literal["auto"] = "auto",
441480
compile_kwargs: dict | None = None,
481+
background: bool = False,
442482
**kwargs,
443483
) -> InferenceData: ...
444484

@@ -472,6 +512,7 @@ def sample(
472512
model: Model | None = None,
473513
blas_cores: int | None | Literal["auto"] = "auto",
474514
compile_kwargs: dict | None = None,
515+
background: bool = False,
475516
**kwargs,
476517
) -> MultiTrace: ...
477518

@@ -504,6 +545,8 @@ def sample(
504545
blas_cores: int | None | Literal["auto"] = "auto",
505546
model: Model | None = None,
506547
compile_kwargs: dict | None = None,
548+
background: bool = False,
549+
_background_internal: bool = False,
507550
**kwargs,
508551
) -> InferenceData | MultiTrace | ZarrTrace:
509552
r"""Draw samples from the posterior using the given step methods.
@@ -540,7 +583,7 @@ def sample(
540583
- "combined": A single progress bar that displays the total progress across all chains. Only timing
541584
information is shown.
542585
- "split": A separate progress bar for each chain. Only timing information is shown.
543-
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
586+
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across
544587
chains. Aggregate sample statistics are also displayed.
545588
- "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
546589
are also displayed.
@@ -618,7 +661,14 @@ def sample(
618661
Model to sample from. The model needs to have free random variables.
619662
compile_kwargs: dict, optional
620663
Dictionary with keyword argument to pass to the functions compiled by the step methods.
664+
You can find a full list of arguments in the docstring of the step methods.
621665
666+
Background mode
667+
----------------
668+
- Set ``background=True`` to run sampling in a background thread; this returns a handle.
669+
- The handle supports ``done()``, ``result()``, and ``exception()``.
670+
- Progress bars are suppressed in background mode.
671+
- Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``.
622672
623673
Returns
624674
-------
@@ -629,65 +679,53 @@ def sample(
629679
``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
630680
the benefits this backend provides.
631681
632-
Notes
633-
-----
634-
Optional keyword arguments can be passed to ``sample`` to be delivered to the
635-
``step_method``\ s used during sampling.
636-
637-
For example:
638-
639-
1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
640-
2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
641-
642-
Note that available step names are:
643-
644-
``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
645-
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
646-
``DEMetropolis``, ``DEMetropolisZ``, ``slice``
647-
648-
The NUTS step method has several options including:
649-
650-
* target_accept : float in [0, 1]. The step size is tuned such that we
651-
approximate this acceptance rate. Higher values like 0.9 or 0.95 often
652-
work better for problematic posteriors. This argument can be passed directly to sample.
653-
* max_treedepth : The maximum depth of the trajectory tree
654-
* step_scale : float, default 0.25
655-
The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
656-
where n is the dimensionality of the parameter space
657-
658-
Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
659-
kwarg, then you can address the ``step_method`` kwargs directly.
660-
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
661-
you could send ::
662-
663-
step = [
664-
pm.NUTS([freeRV1, freeRV2], target_accept=0.9),
665-
pm.BinaryGibbsMetropolis([freeRV3], transit_p=0.7),
666-
]
667-
668-
You can find a full list of arguments in the docstring of the step methods.
669-
670682
Examples
671683
--------
672684
.. code-block:: ipython
685+
"""
686+
if background and not _background_internal:
687+
if nuts_sampler != "pymc":
688+
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
689+
progressbar = False
673690

674-
In [1]: import pymc as pm
675-
...: n = 100
676-
...: h = 61
677-
...: alpha = 2
678-
...: beta = 2
691+
# Resolve the model now so the background thread has a concrete model object.
692+
resolved_model = modelcontext(model)
679693

680-
In [2]: with pm.Model() as model: # context management
681-
...: p = pm.Beta("p", alpha=alpha, beta=beta)
682-
...: y = pm.Binomial("y", n=n, p=p, observed=h)
683-
...: idata = pm.sample()
694+
def _run():
695+
return sample(
696+
draws=draws,
697+
tune=tune,
698+
chains=chains,
699+
cores=cores,
700+
random_seed=random_seed,
701+
progressbar=progressbar,
702+
progressbar_theme=progressbar_theme,
703+
step=step,
704+
var_names=var_names,
705+
nuts_sampler=nuts_sampler,
706+
initvals=initvals,
707+
init=init,
708+
jitter_max_retries=jitter_max_retries,
709+
n_init=n_init,
710+
trace=trace,
711+
discard_tuned_samples=discard_tuned_samples,
712+
compute_convergence_checks=compute_convergence_checks,
713+
keep_warning_stat=keep_warning_stat,
714+
return_inferencedata=return_inferencedata,
715+
idata_kwargs=idata_kwargs,
716+
nuts_sampler_kwargs=nuts_sampler_kwargs,
717+
callback=callback,
718+
mp_ctx=mp_ctx,
719+
blas_cores=blas_cores,
720+
model=resolved_model,
721+
compile_kwargs=compile_kwargs,
722+
background=False,
723+
_background_internal=True,
724+
**kwargs,
725+
)
684726

685-
In [3]: az.summary(idata, kind="stats")
727+
return BackgroundSampleHandle(target=_run).start()
686728

687-
Out[3]:
688-
mean sd hdi_3% hdi_97%
689-
p 0.609 0.047 0.528 0.699
690-
"""
691729
if "start" in kwargs:
692730
if initvals is not None:
693731
raise ValueError("Passing both `start` and `initvals` is not supported.")
@@ -1735,4 +1773,4 @@ def model_logp_fn(ip: PointType) -> np.ndarray:
17351773
for initial_point in initial_points
17361774
]
17371775

1738-
return initial_points, step
1776+
return initial_points, step
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pymc as pm
2+
import pytest
3+
4+
5+
def test_background_sampling_happy_path():
6+
with pm.Model():
7+
pm.Normal("x", 0, 1)
8+
handle = pm.sample(
9+
draws=20,
10+
tune=10,
11+
chains=1,
12+
cores=1,
13+
background=True,
14+
progressbar=False,
15+
)
16+
idata = handle.result()
17+
assert hasattr(idata, "posterior")
18+
assert idata.posterior.sizes["chain"] >= 1
19+
20+
21+
def test_background_sampling_raises():
22+
with pm.Model():
23+
pm.Normal("x", 0, sigma=-1)
24+
handle = pm.sample(
25+
draws=10,
26+
tune=5,
27+
chains=1,
28+
cores=1,
29+
background=True,
30+
progressbar=False,
31+
)
32+
with pytest.raises(Exception):
33+
handle.result()

0 commit comments

Comments
 (0)