2020import sys
2121import time
2222import warnings
23+ import threading
2324
2425from collections .abc import Callable , Iterator , Mapping , Sequence
2526from typing import (
9192Step : TypeAlias = BlockedStep | CompoundStep
9293
9394
95+ class BackgroundSampleHandle :
96+ def __init__ (self , target , args = None , kwargs = None ):
97+ self ._done = threading .Event ()
98+ self ._result = None
99+ self ._exception = None
100+ args = args or ()
101+ kwargs = kwargs or {}
102+
103+ def runner ():
104+ try :
105+ self ._result = target (* args , ** kwargs )
106+ except Exception as exc : # noqa: BLE001
107+ self ._exception = exc
108+ finally :
109+ self ._done .set ()
110+
111+ self ._thread = threading .Thread (target = runner , daemon = True )
112+
113+ def start (self ):
114+ self ._thread .start ()
115+ return self
116+
117+ def done (self ):
118+ return self ._done .is_set ()
119+
120+ def result (self , timeout = None ):
121+ self ._thread .join (timeout = timeout )
122+ if not self ._done .is_set ():
123+ raise TimeoutError ("Background sampling not finished yet" )
124+ if self ._exception :
125+ raise self ._exception
126+ return self ._result
127+
128+ def exception (self , timeout = None ):
129+ self ._thread .join (timeout = timeout )
130+ return self ._exception
131+
132+
94133class SamplingIteratorCallback (Protocol ):
95134 """Signature of the callable that may be passed to `pm.sample(callable=...)`."""
96135
@@ -439,6 +478,7 @@ def sample(
439478 mp_ctx = None ,
440479 blas_cores : int | None | Literal ["auto" ] = "auto" ,
441480 compile_kwargs : dict | None = None ,
481+ background : bool = False ,
442482 ** kwargs ,
443483) -> InferenceData : ...
444484
@@ -472,6 +512,7 @@ def sample(
472512 model : Model | None = None ,
473513 blas_cores : int | None | Literal ["auto" ] = "auto" ,
474514 compile_kwargs : dict | None = None ,
515+ background : bool = False ,
475516 ** kwargs ,
476517) -> MultiTrace : ...
477518
@@ -504,6 +545,8 @@ def sample(
504545 blas_cores : int | None | Literal ["auto" ] = "auto" ,
505546 model : Model | None = None ,
506547 compile_kwargs : dict | None = None ,
548+ background : bool = False ,
549+ _background_internal : bool = False ,
507550 ** kwargs ,
508551) -> InferenceData | MultiTrace | ZarrTrace :
509552 r"""Draw samples from the posterior using the given step methods.
@@ -540,7 +583,7 @@ def sample(
540583 - "combined": A single progress bar that displays the total progress across all chains. Only timing
541584 information is shown.
542585 - "split": A separate progress bar for each chain. Only timing information is shown.
543- - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
586+ - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across
544587 chains. Aggregate sample statistics are also displayed.
545588 - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
546589 are also displayed.
@@ -618,7 +661,14 @@ def sample(
618661 Model to sample from. The model needs to have free random variables.
619662 compile_kwargs: dict, optional
620663 Dictionary with keyword argument to pass to the functions compiled by the step methods.
664+ You can find a full list of arguments in the docstring of the step methods.
621665
666+ Background mode
667+ ----------------
668+ - Set ``background=True`` to run sampling in a background thread; this returns a handle.
669+ - The handle supports ``done()``, ``result()``, and ``exception()``.
670+ - Progress bars are suppressed in background mode.
671+ - Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``.
622672
623673 Returns
624674 -------
@@ -629,65 +679,53 @@ def sample(
629679 ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
630680 the benefits this backend provides.
631681
632- Notes
633- -----
634- Optional keyword arguments can be passed to ``sample`` to be delivered to the
635- ``step_method``\ s used during sampling.
636-
637- For example:
638-
639- 1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
640- 2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
641-
642- Note that available step names are:
643-
644- ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
645- ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
646- ``DEMetropolis``, ``DEMetropolisZ``, ``slice``
647-
648- The NUTS step method has several options including:
649-
650- * target_accept : float in [0, 1]. The step size is tuned such that we
651- approximate this acceptance rate. Higher values like 0.9 or 0.95 often
652- work better for problematic posteriors. This argument can be passed directly to sample.
653- * max_treedepth : The maximum depth of the trajectory tree
654- * step_scale : float, default 0.25
655- The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
656- where n is the dimensionality of the parameter space
657-
658- Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
659- kwarg, then you can address the ``step_method`` kwargs directly.
660- e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
661- you could send ::
662-
663- step = [
664- pm.NUTS([freeRV1, freeRV2], target_accept=0.9),
665- pm.BinaryGibbsMetropolis([freeRV3], transit_p=0.7),
666- ]
667-
668- You can find a full list of arguments in the docstring of the step methods.
669-
670682 Examples
671683 --------
672684 .. code-block:: ipython
685+ """
686+ if background and not _background_internal :
687+ if nuts_sampler != "pymc" :
688+ raise NotImplementedError ("background=True currently supports nuts_sampler='pymc' only" )
689+ progressbar = False
673690
674- In [1]: import pymc as pm
675- ...: n = 100
676- ...: h = 61
677- ...: alpha = 2
678- ...: beta = 2
691+ # Resolve the model now so the background thread has a concrete model object.
692+ resolved_model = modelcontext (model )
679693
680- In [2]: with pm.Model() as model: # context management
681- ...: p = pm.Beta("p", alpha=alpha, beta=beta)
682- ...: y = pm.Binomial("y", n=n, p=p, observed=h)
683- ...: idata = pm.sample()
694+ def _run ():
695+ return sample (
696+ draws = draws ,
697+ tune = tune ,
698+ chains = chains ,
699+ cores = cores ,
700+ random_seed = random_seed ,
701+ progressbar = progressbar ,
702+ progressbar_theme = progressbar_theme ,
703+ step = step ,
704+ var_names = var_names ,
705+ nuts_sampler = nuts_sampler ,
706+ initvals = initvals ,
707+ init = init ,
708+ jitter_max_retries = jitter_max_retries ,
709+ n_init = n_init ,
710+ trace = trace ,
711+ discard_tuned_samples = discard_tuned_samples ,
712+ compute_convergence_checks = compute_convergence_checks ,
713+ keep_warning_stat = keep_warning_stat ,
714+ return_inferencedata = return_inferencedata ,
715+ idata_kwargs = idata_kwargs ,
716+ nuts_sampler_kwargs = nuts_sampler_kwargs ,
717+ callback = callback ,
718+ mp_ctx = mp_ctx ,
719+ blas_cores = blas_cores ,
720+ model = resolved_model ,
721+ compile_kwargs = compile_kwargs ,
722+ background = False ,
723+ _background_internal = True ,
724+ ** kwargs ,
725+ )
684726
685- In [3]: az.summary(idata, kind="stats" )
727+ return BackgroundSampleHandle ( target = _run ). start ( )
686728
687- Out[3]:
688- mean sd hdi_3% hdi_97%
689- p 0.609 0.047 0.528 0.699
690- """
691729 if "start" in kwargs :
692730 if initvals is not None :
693731 raise ValueError ("Passing both `start` and `initvals` is not supported." )
@@ -1735,4 +1773,4 @@ def model_logp_fn(ip: PointType) -> np.ndarray:
17351773 for initial_point in initial_points
17361774 ]
17371775
1738- return initial_points , step
1776+ return initial_points , step
0 commit comments