@@ -82,8 +82,8 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
8282
8383
8484def sample (draws , step = None , init = 'advi' , n_init = 200000 , start = None ,
85- trace = None , thin = 1 , burn = 0 , chain = 0 , njobs = 1 , tune = None ,
86- progressbar = True , model = None , random_seed = - 1 ):
85+ trace = None , chain = 0 , njobs = 1 , tune = None , progressbar = True ,
86+ model = None , random_seed = - 1 ):
8787 """
8888 Draw a number of samples using the given step method.
8989 Multiple step methods supported via compound step method
@@ -120,10 +120,6 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
120120 Passing either "text" or "sqlite" is taken as a shortcut to set
121121 up the corresponding backend (with "mcmc" used as the base
122122 name).
123- thin : int
124- Only store every <thin>'th sample.
125- burn : int
126- Do not store <burn> number of first samples.
127123 chain : int
128124 Chain number used to store sample in backend. If `njobs` is
129125 greater than one, chain numbers will start here.
@@ -163,8 +159,6 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
163159 sample_args = {'draws' : draws ,
164160 'step' : step ,
165161 'start' : start ,
166- 'thin' : thin ,
167- 'burn' : burn ,
168162 'trace' : trace ,
169163 'chain' : chain ,
170164 'tune' : tune ,
@@ -181,13 +175,12 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
181175 return sample_func (** sample_args )
182176
183177
184- def _sample (draws , step = None , start = None , thin = 1 , burn = 0 , trace = None ,
185- chain = 0 , tune = None , progressbar = True , model = None ,
186- random_seed = - 1 ):
187- sampling = _iter_sample (draws , step , start , thin , burn , trace ,
188- chain , tune , model , random_seed )
178+ def _sample (draws , step = None , start = None , trace = None , chain = 0 , tune = None ,
179+ progressbar = True , model = None , random_seed = - 1 ):
180+ sampling = _iter_sample (draws , step , start , trace , chain ,
181+ tune , model , random_seed )
189182 if progressbar :
190- sampling = tqdm (sampling , total = round (( draws - burn ) / thin ) )
183+ sampling = tqdm (sampling , total = draws )
191184 try :
192185 for strace in sampling :
193186 pass
@@ -196,8 +189,8 @@ def _sample(draws, step=None, start=None, thin=1, burn=0, trace=None,
196189 return MultiTrace ([strace ])
197190
198191
199- def iter_sample (draws , step , start = None , thin = 1 , burn = 0 , trace = None ,
200- chain = 0 , tune = None , model = None , random_seed = - 1 ):
192+ def iter_sample (draws , step , start = None , trace = None , chain = 0 , tune = None ,
193+ model = None , random_seed = - 1 ):
201194 """
202195 Generator that returns a trace on each iteration using the given
203196 step method. Multiple step methods supported via compound step
@@ -211,10 +204,6 @@ def iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
211204 The number of samples to draw
212205 step : function
213206 Step function
214- thin : int
215- Only store every <thin>'th sample.
216- burn : int
217- Do not store <burn> number of first samples.
218207 start : dict
219208 Starting point in parameter space (or partial point)
220209 Defaults to trace.point(-1)) if there is a trace provided and
@@ -239,14 +228,14 @@ def iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
239228 for trace in iter_sample(500, step):
240229 ...
241230 """
242- sampling = _iter_sample (draws , step , start , thin , burn , trace ,
243- chain , tune , model , random_seed )
231+ sampling = _iter_sample (draws , step , start , trace , chain , tune ,
232+ model , random_seed )
244233 for i , strace in enumerate (sampling ):
245234 yield MultiTrace ([strace [:i + 1 ]])
246235
247236
248- def _iter_sample (draws , step , start = None , thin = 1 , burn = 0 , trace = None ,
249- chain = 0 , tune = None , model = None , random_seed = - 1 ):
237+ def _iter_sample (draws , step , start = None , trace = None , chain = 0 , tune = None ,
238+ model = None , random_seed = - 1 ):
250239 model = modelcontext (model )
251240 draws = int (draws )
252241 if random_seed != - 1 :
@@ -276,9 +265,8 @@ def _iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
276265 if i == tune :
277266 step = stop_tuning (step )
278267 point = step .step (point )
279- if (i % thin == 0 ) and (i >= burn ):
280- strace .record (point )
281- yield strace
268+ strace .record (point )
269+ yield strace
282270 else :
283271 strace .close ()
284272
0 commit comments