@@ -31,7 +31,7 @@ def traceplot(trace, vars=None, figsize=None,
3131 Returns
3232 -------
3333
34- fig : figure object
34+ fig, ax : tuple of matplotlib figure and axes
3535
3636 """
3737 import matplotlib .pyplot as plt
@@ -69,7 +69,7 @@ def traceplot(trace, vars=None, figsize=None,
6969 pass
7070
7171 plt .tight_layout ()
72- return fig
72+ return ( fig , ax )
7373
7474def histplot_op (ax , data ):
7575 for i in range (data .shape [1 ]):
@@ -131,20 +131,38 @@ def kde2plot_op(ax, x, y, grid=200):
131131def kdeplot (data ):
132132 f , ax = subplots (1 , 1 , squeeze = True )
133133 kdeplot_op (ax , data )
134- return f
134+ return f , ax
135135
136136
137137def kde2plot (x , y , grid = 200 ):
138138 f , ax = subplots (1 , 1 , squeeze = True )
139139 kde2plot_op (ax , x , y , grid )
140- return f
140+ return f , ax
141141
142142
143- def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ,burn = 0 , thin = 1 ):
144- """Bar plot of the autocorrelation function for a trace"""
143+ def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 ):
144+ """Bar plot of the autocorrelation function for a trace
145+
146+ Parameters
147+ ----------
148+
149+ trace : result of MCMC run
150+ vars : list of variable names
151+ Variables to be plotted, if None all variable are plotted
152+ max_lag: int
153+ Maximum lag to calculate autocorrelation. Defaults to 100.
154+ burn: int
155+ Number of samples to discard from the beginning of the trace.
156+ Defaults to 0.
157+
158+ Returns
159+ -------
160+
161+ fig, ax : tuple of matplotlib figure and axes
162+
163+ """
164+
145165 import matplotlib .pyplot as plt
146- if fontmap is None :
147- fontmap = {1 : 10 , 2 : 8 , 3 : 6 , 4 : 5 , 5 : 4 }
148166
149167 if vars is None :
150168 vars = trace .varnames
@@ -153,13 +171,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
153171
154172 chains = trace .nchains
155173
156- f , ax = plt .subplots (len (vars ), chains , squeeze = False )
174+ fig , ax = plt .subplots (len (vars ), chains , squeeze = False )
157175
158176 max_lag = min (len (trace ) - 1 , max_lag )
159177
160178 for i , v in enumerate (vars ):
161179 for j in range (chains ):
162- d = np .squeeze (trace .get_values (v , chains = [j ],burn = burn , thin = thin ))
180+ d = np .squeeze (trace .get_values (v , chains = [j ], burn = burn ))
163181
164182 ax [i , j ].acorr (d , detrend = plt .mlab .detrend_mean , maxlags = max_lag )
165183
@@ -169,13 +187,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
169187
170188 if chains > 1 :
171189 ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
172-
173- # Smaller tick labels
174- tlabels = plt .gca ().get_xticklabels ()
175- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
176-
177- tlabels = plt .gca ().get_yticklabels ()
178- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
190+
191+ return (fig , ax )
179192
180193
181194def var_str (name , shape ):
0 commit comments