@@ -31,7 +31,7 @@ def traceplot(trace, vars=None, figsize=None,
3131 Returns
3232 -------
3333
34- fig, ax : tuple of matplotlib figure and axes
34+ fig : figure object
3535
3636 """
3737 import matplotlib .pyplot as plt
@@ -69,7 +69,7 @@ def traceplot(trace, vars=None, figsize=None,
6969 pass
7070
7171 plt .tight_layout ()
72- return ( fig , ax )
72+ return fig
7373
7474def histplot_op (ax , data ):
7575 for i in range (data .shape [1 ]):
@@ -131,38 +131,20 @@ def kde2plot_op(ax, x, y, grid=200):
131131def kdeplot (data ):
132132 f , ax = subplots (1 , 1 , squeeze = True )
133133 kdeplot_op (ax , data )
134- return f , ax
134+ return f
135135
136136
137137def kde2plot (x , y , grid = 200 ):
138138 f , ax = subplots (1 , 1 , squeeze = True )
139139 kde2plot_op (ax , x , y , grid )
140- return f , ax
140+ return f
141141
142142
143- def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 ):
144- """Bar plot of the autocorrelation function for a trace
145-
146- Parameters
147- ----------
148-
149- trace : result of MCMC run
150- vars : list of variable names
151- Variables to be plotted, if None all variable are plotted
152- max_lag: int
153- Maximum lag to calculate autocorrelation. Defaults to 100.
154- burn: int
155- Number of samples to discard from the beginning of the trace.
156- Defaults to 0.
157-
158- Returns
159- -------
160-
161- fig, ax : tuple of matplotlib figure and axes
162-
163- """
164-
143+ def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ,burn = 0 , thin = 1 ):
144+ """Bar plot of the autocorrelation function for a trace"""
165145 import matplotlib .pyplot as plt
146+ if fontmap is None :
147+ fontmap = {1 : 10 , 2 : 8 , 3 : 6 , 4 : 5 , 5 : 4 }
166148
167149 if vars is None :
168150 vars = trace .varnames
@@ -171,13 +153,13 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
171153
172154 chains = trace .nchains
173155
174- fig , ax = plt .subplots (len (vars ), chains , squeeze = False )
156+ f , ax = plt .subplots (len (vars ), chains , squeeze = False )
175157
176158 max_lag = min (len (trace ) - 1 , max_lag )
177159
178160 for i , v in enumerate (vars ):
179161 for j in range (chains ):
180- d = np .squeeze (trace .get_values (v , chains = [j ], burn = burn ))
162+ d = np .squeeze (trace .get_values (v , chains = [j ],burn = burn , thin = thin ))
181163
182164 ax [i , j ].acorr (d , detrend = plt .mlab .detrend_mean , maxlags = max_lag )
183165
@@ -187,8 +169,13 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
187169
188170 if chains > 1 :
189171 ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
190-
191- return (fig , ax )
172+
173+ # Smaller tick labels
174+ tlabels = plt .gca ().get_xticklabels ()
175+ plt .setp (tlabels , 'fontsize' , fontmap [1 ])
176+
177+ tlabels = plt .gca ().get_yticklabels ()
178+ plt .setp (tlabels , 'fontsize' , fontmap [1 ])
192179
193180
194181def var_str (name , shape ):
0 commit comments