 from collections import defaultdict
 from copy import copy
 from typing import (
+    Any,
+    Callable,
     Dict,
     Iterable,
     Iterator,
@@ -811,12 +813,16 @@ def _sample(

     trace = copy(trace)

-    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
+    sampling_gen = _iter_sample(
+        draws, step, start, trace, chain, tune, model, random_seed, callback
+    )
     _pbar_data = {"chain": chain, "divergences": 0}
     _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
     if progressbar:
-        sampling = progress_bar(sampling, total=draws, display=progressbar)
+        sampling = progress_bar(sampling_gen, total=draws, display=progressbar)
         sampling.comment = _desc.format(**_pbar_data)
+    else:
+        sampling = sampling_gen
     try:
         strace = None
         for it, (strace, diverging) in enumerate(sampling):
@@ -826,6 +832,8 @@ def _sample(
                     sampling.comment = _desc.format(**_pbar_data)
     except KeyboardInterrupt:
         pass
+    if strace is None:
+        raise Exception("KeyboardInterrupt happened before the base trace was created.")
     return strace


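The two hunks above bind the raw generator to its own name, rebind `sampling` only after the optional progress-bar wrapping, and fail loudly if a KeyboardInterrupt arrives before the first draw. A hedged sketch of the same pattern with stand-in names (make_items, wrap_with_progress are illustrative, not PyMC's API):

from typing import Iterable, Iterator, Optional


def make_items(n: int) -> Iterator[int]:
    yield from range(n)


def wrap_with_progress(it: Iterator[int]) -> Iterable[int]:
    return list(it)  # stand-in for fastprogress.progress_bar


def consume(n: int, show_progress: bool) -> int:
    items_gen = make_items(n)
    # Keeping the raw generator under its own name lets `items` take a
    # different (wrapped) type without reassigning one name across two types.
    items: Iterable[int] = wrap_with_progress(items_gen) if show_progress else items_gen
    last: Optional[int] = None
    for last in items:
        pass
    if last is None:
        # Mirrors the new guard: raise if the loop never produced a value.
        raise Exception("no items were produced")
    return last


print(consume(5, show_progress=False))  # 4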
@@ -1494,10 +1502,12 @@ def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTra
     idxs = np.argsort(lengths)
     l_sort = np.array(lengths)[idxs]

-    use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])
+    use_until = cast(int, np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1]))
     final_length = l_sort[use_until]

-    return [traces[idx] for idx in idxs[use_until:]], final_length + tune
+    take_idx = cast(Sequence[int], idxs[use_until:])
+    sliced_traces = [traces[idx] for idx in take_idx]
+    return sliced_traces, final_length + tune


 def stop_tuning(step):
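The `_choose_chains` hunk relies on `typing.cast`, which is a no-op at runtime and only tells the type checker what to assume; `np.argmax` returns `np.intp`, which mypy does not accept where a plain `int` is expected. A minimal sketch with toy chain lengths (not PyMC data):

from typing import cast

import numpy as np

lengths = [300, 1000, 1000]
l_sort = np.array(sorted(lengths))
# np.argmax returns np.intp; cast() merely relabels it as int for the checker.
use_until = cast(int, np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1]))
print(use_until, l_sort[use_until])  # 1 1000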
@@ -1590,30 +1600,30 @@ def sample_posterior_predictive(
     """

     _trace: Union[MultiTrace, PointList]
+    nchain: int
     if isinstance(trace, InferenceData):
         _trace = dataset_to_point_list(trace.posterior)
+        nchain, len_trace = chains_and_samples(trace)
     elif isinstance(trace, xarray.Dataset):
         _trace = dataset_to_point_list(trace)
-    else:
+        nchain, len_trace = chains_and_samples(trace)
+    elif isinstance(trace, MultiTrace):
         _trace = trace
+        nchain = _trace.nchains
+        len_trace = len(_trace)
+    elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
+        _trace = trace
+        nchain = 1
+        len_trace = len(_trace)
+    else:
+        raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")

     if keep_size is None:
         # This will allow users to set return_inferencedata=False and
         # automatically get the old behaviour instead of needing to
         # set both return_inferencedata and keep_size to False
         keep_size = return_inferencedata

-    nchain: int
-    len_trace: int
-    if isinstance(trace, (InferenceData, xarray.Dataset)):
-        nchain, len_trace = chains_and_samples(trace)
-    else:
-        len_trace = len(_trace)
-        try:
-            nchain = _trace.nchains
-        except AttributeError:
-            nchain = 1
-
     if keep_size and samples is not None:
         raise IncorrectArgumentsError(
             "Should not specify both keep_size and samples arguments. "
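The restructured branch computes `nchain` and `len_trace` alongside `_trace` for each supported input type and rejects anything else with a TypeError, instead of the earlier hasattr/AttributeError probing. A small sketch of that dispatch idea using stand-in inputs (not PyMC's trace classes):

from typing import Tuple


def describe(trace: object) -> Tuple[int, int]:
    # One isinstance branch per supported input, each computing both values;
    # unsupported inputs fail immediately with a clear message.
    if isinstance(trace, dict):
        return 1, 1
    elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
        return 1, len(trace)
    else:
        raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")


print(describe([{"x": 0.1}, {"x": 0.2}]))  # (1, 2)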
@@ -1625,7 +1635,7 @@ def sample_posterior_predictive(
     if samples is None:
         if isinstance(_trace, MultiTrace):
             samples = sum(len(v) for v in _trace._straces.values())
-        elif isinstance(_trace, list) and all(isinstance(x, dict) for x in _trace):
+        elif isinstance(_trace, list):
            # this is a list of points
            samples = len(_trace)
        else:
@@ -1693,6 +1703,7 @@ def sample_posterior_predictive(
         else:
             inputs, input_names = [], []
     else:
+        assert isinstance(_trace, MultiTrace)
         output_names = [v.name for v in vars_to_sample if v.name is not None]
         input_names = [
             n
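The added `assert isinstance(_trace, MultiTrace)` serves the type checker as much as runtime safety: asserting on the concrete class narrows the declared Union inside the branch, so later attribute access type-checks. A toy sketch of the same narrowing:

from typing import List, Union


def first_point(trace: Union[List[dict], dict]) -> dict:
    if isinstance(trace, dict):
        return trace
    # Narrows the Union to List[dict] for the checker and documents the invariant.
    assert isinstance(trace, list)
    return trace[0]


print(first_point([{"mu": 0.0}]))  # {'mu': 0.0}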
@@ -1715,7 +1726,7 @@ def sample_posterior_predictive(

     ppc_trace_t = _DefaultTrace(samples)
     try:
-        if hasattr(_trace, "_straces"):
+        if isinstance(_trace, MultiTrace):
             # trace dict is unordered, but we want to return ppc samples in
             # a predictable ordering, so sort the chain indices
             chain_idx_mapping = sorted(_trace._straces.keys())
@@ -1750,7 +1761,7 @@ def sample_posterior_predictive(

     if not return_inferencedata:
         return ppc_trace
-    ikwargs = dict(model=model)
+    ikwargs: Dict[str, Any] = dict(model=model)
     if idata_kwargs:
         ikwargs.update(idata_kwargs)
     if predictions:
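Without the annotation, mypy pins the value type of `ikwargs` to whatever the first assignment suggests and then rejects `update()` calls that add values of other types; `Dict[str, Any]` keeps the dict open. A minimal sketch with placeholder values (not the real kwargs):

from typing import Any, Dict

# A bare `ikwargs = dict(model=None)` would fix the value type from the first
# assignment; the explicit annotation allows heterogeneous later updates.
ikwargs: Dict[str, Any] = dict(model=None)
ikwargs.update({"coords": {"obs_id": [0, 1, 2]}})
print(sorted(ikwargs))  # ['coords', 'model']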
@@ -1881,8 +1892,8 @@ def sample_posterior_predictive_w(
         indices = np.random.randint(0, nchain * len_trace, j)
         if nchain > 1:
             chain_idx, point_idx = np.divmod(indices, len_trace)
-            for idx in zip(chain_idx, point_idx):
-                trace.append(tr._straces[idx[0]].point(idx[1]))
+            for cidx, pidx in zip(chain_idx, point_idx):
+                trace.append(tr._straces[cidx].point(pidx))
         else:
             for idx in indices:
                 trace.append(tr[idx])
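The loop now unpacks the chain/draw pair directly; `np.divmod` maps a flat sample index to those coordinates when every chain holds `len_trace` draws. A short sketch with toy numbers (not from the PR):

import numpy as np

len_trace = 4  # draws per chain (toy value)
indices = np.array([0, 5, 7])
chain_idx, point_idx = np.divmod(indices, len_trace)
for cidx, pidx in zip(chain_idx, point_idx):
    print(cidx, pidx)  # 0 0, then 1 1, then 1 3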
@@ -1892,12 +1903,12 @@ def sample_posterior_predictive_w(

     lengths = list({np.atleast_1d(observed).shape for observed in obs})

+    size: List[Optional[Tuple[int, ...]]] = []
     if len(lengths) == 1:
-        size = [None for i in variables]
+        size = [None] * len(variables)
     elif len(lengths) > 2:
         raise ValueError("Observed variables could not be broadcast together")
     else:
-        size = []
         x = np.zeros(shape=lengths[0])
         y = np.zeros(shape=lengths[1])
         b = np.broadcast(x, y)
@@ -1919,7 +1930,7 @@ def sample_posterior_predictive_w(
         indices = progress_bar(indices, total=samples, display=progressbar)

     try:
-        ppc = defaultdict(list)
+        ppcl: Dict[str, list] = defaultdict(list)
         for idx in indices:
             param = trace[idx]
             var = variables[idx]
@@ -1932,13 +1943,13 @@ def sample_posterior_predictive_w(
     except KeyboardInterrupt:
         pass
     else:
-        ppc = {k: np.asarray(v) for k, v in ppc.items()}
+        ppcd = {k: np.asarray(v) for k, v in ppcl.items()}
         if not return_inferencedata:
-            return ppc
-        ikwargs = dict(model=models)
+            return ppcd
+        ikwargs: Dict[str, Any] = dict(model=models)
         if idata_kwargs:
             ikwargs.update(idata_kwargs)
-        return pm.to_inference_data(posterior_predictive=ppc, **ikwargs)
+        return pm.to_inference_data(posterior_predictive=ppcd, **ikwargs)


 def sample_prior_predictive(
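Reusing one name (`ppc`) first for a `defaultdict(list)` and then for a dict of arrays makes mypy flag the rebinding; giving each stage its own name, and annotating the accumulator, keeps both types precise. A minimal sketch with placeholder data:

from collections import defaultdict
from typing import Dict

import numpy as np

ppcl: Dict[str, list] = defaultdict(list)  # per-variable lists of draws
ppcl["y"].append(1.0)
ppcl["y"].append(2.0)
# A second name for the second type instead of rebinding ppcl.
ppcd = {k: np.asarray(v) for k, v in ppcl.items()}
print(ppcd["y"].shape)  # (2,)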
@@ -2044,7 +2055,7 @@ def sample_prior_predictive(

     if not return_inferencedata:
         return prior
-    ikwargs = dict(model=model)
+    ikwargs: Dict[str, Any] = dict(model=model)
     if idata_kwargs:
         ikwargs.update(idata_kwargs)
     return pm.to_inference_data(prior=prior, **ikwargs)
@@ -2106,10 +2117,11 @@ def draw(

     # Single variable output
     if not isinstance(vars, (list, tuple)):
-        drawn_values = (draw_fn() for _ in range(draws))
-        return np.stack(drawn_values)
+        cast(Callable[[], np.ndarray], draw_fn)
+        return np.stack([draw_fn() for _ in range(draws)])

     # Multiple variable output
+    cast(Callable[[], List[np.ndarray]], draw_fn)
     drawn_values = zip(*(draw_fn() for _ in range(draws)))
     return [np.stack(v) for v in drawn_values]

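`np.stack` needs a real sequence; newer NumPy versions reject a bare generator, so the single-variable branch now materializes the draws in a list (the `cast` calls are again runtime no-ops that only inform the type checker). A minimal sketch with made-up draws:

import numpy as np

rng = np.random.default_rng(0)
draws = [rng.normal(size=3) for _ in range(5)]  # a list, not a generator
print(np.stack(draws).shape)  # (5, 3)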
@@ -2120,7 +2132,7 @@ def _init_jitter(
     seeds: Sequence[int],
     jitter: bool,
     jitter_max_retries: int,
-) -> PointType:
+) -> List[PointType]:
     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.

     ``model.check_start_vals`` is used to test whether the jittered starting
@@ -2144,7 +2156,7 @@ def _init_jitter(
     ipfns = make_initial_point_fns_per_chain(
         model=model,
         overrides=initvals,
-        jitter_rvs=set(model.free_RVs) if jitter else {},
+        jitter_rvs=set(model.free_RVs) if jitter else set(),
         chains=len(seeds),
     )

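`{}` is an empty dict literal, not an empty set, so the old fallback gave `jitter_rvs` two incompatible types across the conditional; `set()` keeps both arms as sets. A one-line illustration with toy names:

jitter = False
free_rvs = ["a", "b"]
jitter_rvs = set(free_rvs) if jitter else set()  # {} here would be a dict
print(type(jitter_rvs).__name__, jitter_rvs)  # set set()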
@@ -2282,6 +2294,7 @@ def init_nuts(

     apoints = [DictToArrayBijection.map(point) for point in initial_points]
     apoints_data = [apoint.data for apoint in apoints]
+    potential: quadpotential.QuadPotential

     if init == "adapt_diag":
         mean = np.mean(apoints_data, axis=0)
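Declaring `potential: quadpotential.QuadPotential` before the `init == ...` branches gives the type checker one declared type that every branch's assignment is checked against, instead of locking onto whichever subclass the first branch assigns. A sketch with stand-in classes (not PyMC's quadpotential module):

class QuadPotential: ...
class QuadPotentialDiagAdapt(QuadPotential): ...
class QuadPotentialFullAdapt(QuadPotential): ...


init = "adapt_diag"
potential: QuadPotential  # declared once, satisfied by any subclass below
if init == "adapt_diag":
    potential = QuadPotentialDiagAdapt()
else:
    potential = QuadPotentialFullAdapt()
print(type(potential).__name__)  # QuadPotentialDiagAdapt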