@@ -82,8 +82,8 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
8282
8383
8484def sample (draws , step = None , init = 'advi' , n_init = 200000 , start = None ,
85- trace = None , chain = 0 , njobs = 1 , tune = None , progressbar = True ,
86- model = None , random_seed = - 1 ):
85+ trace = None , thin = 1 , burn = 0 , chain = 0 , njobs = 1 , tune = None ,
86+ progressbar = True , model = None , random_seed = - 1 ):
8787 """
8888 Draw a number of samples using the given step method.
8989 Multiple step methods supported via compound step method
@@ -120,6 +120,10 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
120120 Passing either "text" or "sqlite" is taken as a shortcut to set
121121 up the corresponding backend (with "mcmc" used as the base
122122 name).
123+ thin : int
124+ Only store every <thin>'th sample.
125+ burn : int
126+ Do not store <burn> number of first samples.
123127 chain : int
124128 Chain number used to store sample in backend. If `njobs` is
125129 greater than one, chain numbers will start here.
@@ -159,6 +163,8 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
159163 sample_args = {'draws' : draws ,
160164 'step' : step ,
161165 'start' : start ,
166+ 'thin' : thin ,
167+ 'burn' : burn ,
162168 'trace' : trace ,
163169 'chain' : chain ,
164170 'tune' : tune ,
@@ -175,12 +181,13 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
175181 return sample_func (** sample_args )
176182
177183
178- def _sample (draws , step = None , start = None , trace = None , chain = 0 , tune = None ,
179- progressbar = True , model = None , random_seed = - 1 ):
180- sampling = _iter_sample (draws , step , start , trace , chain ,
181- tune , model , random_seed )
184+ def _sample (draws , step = None , start = None , thin = 1 , burn = 0 , trace = None ,
185+ chain = 0 , tune = None , progressbar = True , model = None ,
186+ random_seed = - 1 ):
187+ sampling = _iter_sample (draws , step , start , thin , burn , trace ,
188+ chain , tune , model , random_seed )
182189 if progressbar :
183- sampling = tqdm (sampling , total = draws )
190+ sampling = tqdm (sampling , total = round (( draws - burn ) / thin ) )
184191 try :
185192 for strace in sampling :
186193 pass
@@ -189,8 +196,8 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
189196 return MultiTrace ([strace ])
190197
191198
192- def iter_sample (draws , step , start = None , trace = None , chain = 0 , tune = None ,
193- model = None , random_seed = - 1 ):
199+ def iter_sample (draws , step , start = None , thin = 1 , burn = 0 , trace = None ,
200+ chain = 0 , tune = None , model = None , random_seed = - 1 ):
194201 """
195202 Generator that returns a trace on each iteration using the given
196203 step method. Multiple step methods supported via compound step
@@ -204,6 +211,10 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
204211 The number of samples to draw
205212 step : function
206213 Step function
214+ thin : int
215+ Only store every <thin>'th sample.
216+ burn : int
217+ Do not store <burn> number of first samples.
207218 start : dict
208219 Starting point in parameter space (or partial point)
209220 Defaults to trace.point(-1)) if there is a trace provided and
@@ -228,14 +239,14 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
228239 for trace in iter_sample(500, step):
229240 ...
230241 """
231- sampling = _iter_sample (draws , step , start , trace , chain , tune ,
232- model , random_seed )
242+ sampling = _iter_sample (draws , step , start , thin , burn , trace ,
243+ chain , tune , model , random_seed )
233244 for i , strace in enumerate (sampling ):
234245 yield MultiTrace ([strace [:i + 1 ]])
235246
236247
237- def _iter_sample (draws , step , start = None , trace = None , chain = 0 , tune = None ,
238- model = None , random_seed = - 1 ):
248+ def _iter_sample (draws , step , start = None , thin = 1 , burn = 0 , trace = None ,
249+ chain = 0 , tune = None , model = None , random_seed = - 1 ):
239250 model = modelcontext (model )
240251 draws = int (draws )
241252 if random_seed != - 1 :
@@ -265,8 +276,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
265276 if i == tune :
266277 step = stop_tuning (step )
267278 point = step .step (point )
268- strace .record (point )
269- yield strace
279+ if (i % thin == 0 ) and (i >= burn ):
280+ strace .record (point )
281+ yield strace
270282 else :
271283 strace .close ()
272284
0 commit comments