77
88
99def traceplot (trace , vars = None , figsize = None ,
10- lines = None , combined = False , grid = True ):
10+ lines = None , combined = False , grid = True , ax = None ):
1111 """Plot samples histograms and values
1212
1313 Parameters
@@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
2727 (default), chains will be plotted separately.
2828 grid : bool
2929 Flag for adding gridlines to histogram. Defaults to True.
30+ ax : axes
31+ Matplotlib axes. Defaults to None.
3032
3133 Returns
3234 -------
3335
34- fig, ax : tuple of matplotlib figure and axes
36+ ax : matplotlib axes
3537
3638 """
3739 import matplotlib .pyplot as plt
@@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
4345 if figsize is None :
4446 figsize = (12 , n * 2 )
4547
46- fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
48+ if ax is None :
49+ fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
50+ elif ax .shape != (n ,2 ):
51+ print ('traceplot requires n*2 subplots' )
52+ return None
4753
4854 for i , v in enumerate (vars ):
4955 for d in trace .get_values (v , combine = combined , squeeze = False ):
@@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
6975 pass
7076
7177 plt .tight_layout ()
72- return ( fig , ax )
78+ return ax
7379
7480def histplot_op (ax , data ):
7581 for i in range (data .shape [1 ]):
@@ -128,19 +134,21 @@ def kde2plot_op(ax, x, y, grid=200):
128134 extent = [xmin , xmax , ymin , ymax ])
129135
130136
131- def kdeplot (data ):
132- f , ax = subplots (1 , 1 , squeeze = True )
137+ def kdeplot (data , ax = None ):
138+ if ax is None :
139+ f , ax = subplots (1 , 1 , squeeze = True )
133140 kdeplot_op (ax , data )
134- return f , ax
141+ return ax
135142
136143
137- def kde2plot (x , y , grid = 200 ):
138- f , ax = subplots (1 , 1 , squeeze = True )
144+ def kde2plot (x , y , grid = 200 , ax = None ):
145+ if ax is None :
146+ f , ax = subplots (1 , 1 , squeeze = True )
139147 kde2plot_op (ax , x , y , grid )
140- return f , ax
148+ return ax
141149
142150
143- def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 ):
151+ def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 , ax = None ):
144152 """Bar plot of the autocorrelation function for a trace
145153
146154 Parameters
@@ -149,16 +157,18 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
149157 trace : result of MCMC run
150158 vars : list of variable names
151159 Variables to be plotted, if None all variable are plotted
152- max_lag: int
160+ max_lag : int
153161 Maximum lag to calculate autocorrelation. Defaults to 100.
154- burn: int
162+ burn : int
155163 Number of samples to discard from the beginning of the trace.
156164 Defaults to 0.
157-
165+ ax : axes
166+ Matplotlib axes. Defaults to None.
167+
158168 Returns
159169 -------
160170
161- fig, ax : tuple of matplotlib figure and axes
171+ ax : matplotlib axes
162172
163173 """
164174
@@ -213,7 +223,7 @@ def var_str(name, shape):
213223
214224def forestplot (trace_obj , vars = None , alpha = 0.05 , quartiles = True , rhat = True ,
215225 main = None , xtitle = None , xrange = None , ylabels = None ,
216- chain_spacing = 0.05 , vline = 0 ):
226+ chain_spacing = 0.05 , vline = 0 , gs = None ):
217227 """ Forest plot (model summary plot)
218228
219229 Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
@@ -258,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
258268
259269 vline (optional): numeric
260270 Location of vertical reference line (defaults to 0).
271+
272+ gs : GridSpec
273+ Matplotlib GridSpec object. Defaults to None.
274+
275+ Returns
276+ -------
277+
278+ gs : matplotlib GridSpec
261279
262280 """
263281 import matplotlib .pyplot as plt
@@ -283,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
283301 # Number of chains
284302 chains = None
285303
286- # Gridspec
287- gs = None
288-
289304 # Subplots
290305 interval_plot = None
291306 rhat_plot = None
0 commit comments