77__all__ = ['trace_to_dataframe' ]
88
99
10- def trace_to_dataframe (trace , chains = None , flat_names = None , hide_transformed_vars = True ):
10+ def trace_to_dataframe (trace , chains = None , varnames = None , hide_transformed_vars = True ):
1111 """Convert trace to Pandas DataFrame.
1212
1313 Parameters
@@ -16,21 +16,29 @@ def trace_to_dataframe(trace, chains=None, flat_names=None, hide_transformed_var
1616 chains : int or list of ints
1717 Chains to include. If None, all chains are used. A single
1818 chain value can also be given.
19- flat_names : dict or None
20- A dictionary that maps each variable name in `trace` to a list
19+ varnames : list of variable names
20+ Variables to be included in the DataFrame, if None all variable are
21+ included.
22+ hide_transformed_vars: boolean
23+ If true transformed variables will not be included in the resulting
24+ DataFrame.
2125 """
2226 var_shapes = trace ._straces [0 ].var_shapes
23- if flat_names is None :
24- flat_names = {v : create_flat_names (v , shape )
25- for v , shape in var_shapes .items ()
26- if not (hide_transformed_vars and v .endswith ('_' ))}
27+
28+ if varnames is None :
29+ varnames = var_shapes .keys ()
2730
31+ flat_names = {v : create_flat_names (v , shape )
32+ for v , shape in var_shapes .items ()
33+ if not (hide_transformed_vars and v .endswith ('_' ))}
34+
2835 var_dfs = []
29- for varname , shape in var_shapes .items ():
30- if not hide_transformed_vars or not varname .endswith ('_' ):
31- vals = trace .get_values (varname , combine = True , chains = chains )
32- flat_vals = vals .reshape (vals .shape [0 ], - 1 )
33- var_dfs .append (pd .DataFrame (flat_vals , columns = flat_names [varname ]))
36+ for v , shape in var_shapes .items ():
37+ if v in varnames :
38+ if not hide_transformed_vars or not v .endswith ('_' ):
39+ vals = trace .get_values (v , combine = True , chains = chains )
40+ flat_vals = vals .reshape (vals .shape [0 ], - 1 )
41+ var_dfs .append (pd .DataFrame (flat_vals , columns = flat_names [v ]))
3442 return pd .concat (var_dfs , axis = 1 )
3543
3644
0 commit comments