77
88
99def traceplot (trace , vars = None , figsize = None ,
10- lines = None , combined = False , grid = True ):
10+ lines = None , combined = False , grid = True , ax = None ):
1111 """Plot samples histograms and values
1212
1313 Parameters
@@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
2727 (default), chains will be plotted separately.
2828 grid : bool
2929 Flag for adding gridlines to histogram. Defaults to True.
30+ ax : axes
31+ Matplotlib axes. Defaults to None.
3032
3133 Returns
3234 -------
3335
34- fig : figure object
36+ ax : matplotlib axes
3537
3638 """
3739 import matplotlib .pyplot as plt
@@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
4345 if figsize is None :
4446 figsize = (12 , n * 2 )
4547
46- fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
48+ if ax is None :
49+ fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
50+ elif ax .shape != (n ,2 ):
51+ print ('traceplot requires n*2 subplots' )
52+ return None
4753
4854 for i , v in enumerate (vars ):
4955 for d in trace .get_values (v , combine = combined , squeeze = False ):
@@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
6975 pass
7076
7177 plt .tight_layout ()
72- return fig
78+ return ax
7379
7480def histplot_op (ax , data ):
7581 for i in range (data .shape [1 ]):
@@ -128,23 +134,45 @@ def kde2plot_op(ax, x, y, grid=200):
128134 extent = [xmin , xmax , ymin , ymax ])
129135
130136
131- def kdeplot (data ):
132- f , ax = subplots (1 , 1 , squeeze = True )
137+ def kdeplot (data , ax = None ):
138+ if ax is None :
139+ f , ax = subplots (1 , 1 , squeeze = True )
133140 kdeplot_op (ax , data )
134- return f
141+ return ax
135142
136143
137- def kde2plot (x , y , grid = 200 ):
138- f , ax = subplots (1 , 1 , squeeze = True )
144+ def kde2plot (x , y , grid = 200 , ax = None ):
145+ if ax is None :
146+ f , ax = subplots (1 , 1 , squeeze = True )
139147 kde2plot_op (ax , x , y , grid )
140- return f
148+ return ax
141149
142150
143- def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ,burn = 0 , thin = 1 ):
144- """Bar plot of the autocorrelation function for a trace"""
151+ def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 , ax = None ):
152+ """Bar plot of the autocorrelation function for a trace
153+
154+ Parameters
155+ ----------
156+
157+ trace : result of MCMC run
158+ vars : list of variable names
159+ Variables to be plotted, if None all variable are plotted
160+ max_lag : int
161+ Maximum lag to calculate autocorrelation. Defaults to 100.
162+ burn : int
163+ Number of samples to discard from the beginning of the trace.
164+ Defaults to 0.
165+ ax : axes
166+ Matplotlib axes. Defaults to None.
167+
168+ Returns
169+ -------
170+
171+ ax : matplotlib axes
172+
173+ """
174+
145175 import matplotlib .pyplot as plt
146- if fontmap is None :
147- fontmap = {1 : 10 , 2 : 8 , 3 : 6 , 4 : 5 , 5 : 4 }
148176
149177 if vars is None :
150178 vars = trace .varnames
@@ -153,13 +181,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
153181
154182 chains = trace .nchains
155183
156- f , ax = plt .subplots (len (vars ), chains , squeeze = False )
184+ fig , ax = plt .subplots (len (vars ), chains , squeeze = False )
157185
158186 max_lag = min (len (trace ) - 1 , max_lag )
159187
160188 for i , v in enumerate (vars ):
161189 for j in range (chains ):
162- d = np .squeeze (trace .get_values (v , chains = [j ],burn = burn , thin = thin ))
190+ d = np .squeeze (trace .get_values (v , chains = [j ], burn = burn ))
163191
164192 ax [i , j ].acorr (d , detrend = plt .mlab .detrend_mean , maxlags = max_lag )
165193
@@ -169,13 +197,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
169197
170198 if chains > 1 :
171199 ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
172-
173- # Smaller tick labels
174- tlabels = plt .gca ().get_xticklabels ()
175- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
176-
177- tlabels = plt .gca ().get_yticklabels ()
178- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
200+
201+ return (fig , ax )
179202
180203
181204def var_str (name , shape ):
@@ -200,7 +223,7 @@ def var_str(name, shape):
200223
201224def forestplot (trace_obj , vars = None , alpha = 0.05 , quartiles = True , rhat = True ,
202225 main = None , xtitle = None , xrange = None , ylabels = None ,
203- chain_spacing = 0.05 , vline = 0 ):
226+ chain_spacing = 0.05 , vline = 0 , gs = None ):
204227 """ Forest plot (model summary plot)
205228
206229 Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
@@ -245,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
245268
246269 vline (optional): numeric
247270 Location of vertical reference line (defaults to 0).
271+
272+ gs : GridSpec
273+ Matplotlib GridSpec object. Defaults to None.
274+
275+ Returns
276+ -------
277+
278+ gs : matplotlib GridSpec
248279
249280 """
250281 import matplotlib .pyplot as plt
@@ -270,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
270301 # Number of chains
271302 chains = None
272303
273- # Gridspec
274- gs = None
275-
276304 # Subplots
277305 interval_plot = None
278306 rhat_plot = None
0 commit comments