11"""Text file trace backend
22
3- After sampling with NDArray backend, save results as text files.
3+ Store sampling values as CSV files.
44
5- As this other backends, this can be used by passing the backend instance
6- to `sample`.
5+ File format
6+ -----------
77
8- >>> import pymc3 as pm
9- >>> db = pm.backends.Text('test')
10- >>> trace = pm.sample(..., trace=db)
8+ Sampling values for each chain are saved in a separate file (under a
9+ directory specified by the `name` argument). The rows correspond to
10+ sampling iterations. The column names consist of variable names and
11+ index labels. For example, the heading
1112
12- Or sampling can be performed with the default NDArray backend and then
13- dumped to text files after.
13+ x,y__0_0,y__0_1,y__1_0,y__1_1,y__2_0,y__2_1
1414
15- >>> from pymc3.backends import text
16- >>> trace = pm.sample(...)
17- >>> text.dump('test', trace)
18-
19- Database format
20- ---------------
21-
22- For each chain, a directory named `chain-N` is created. In this
23- directory, one file per variable is created containing the values of the
24- object. To deal with multidimensional variables, the array is reshaped
25- to one dimension before saving with `numpy.savetxt`. The shape and dtype
26- information is saved in a json file in the same directory and is used to
27- load the database back again using `numpy.loadtxt`.
15+ represents two variables, x and y, where x is a scalar and y has a
16+ shape of (3, 2).
2817"""
29- import os
30- import glob
31- import json
18+ from glob import glob
3219import numpy as np
20+ import os
21+ import pandas as pd
22+ import warnings
3323
3424from ..backends import base
35- from ..backends .ndarray import NDArray
3625
3726
38- class Text (NDArray ):
39- """Text storage
27+ class Text (base . BaseTrace ):
28+ """Text trace object
4029
4130 Parameters
4231 ----------
@@ -53,102 +42,207 @@ def __init__(self, name, model=None, vars=None):
5342 os .mkdir (name )
5443 super (Text , self ).__init__ (name , model , vars )
5544
56- def close (self ):
57- super (Text , self ).close ()
58- _dump_trace (self .name , self )
59-
60-
61- def dump (name , trace , chains = None ):
62- """Store NDArray trace as text database.
63-
64- Parameters
65- ----------
66- name : str
67- Name of directory to store text files
68- trace : MultiTrace of NDArray traces
69- Result of MCMC run with default NDArray backend
70- chains : list
71- Chains to dump. If None, all chains are dumped.
72- """
73- if not os .path .exists (name ):
74- os .mkdir (name )
75- if chains is None :
76- chains = trace .chains
77- for chain in chains :
78- _dump_trace (name , trace ._traces [chain ])
79-
45+ self .flat_names = {v : _create_flat_names (v , shape )
46+ for v , shape in self .var_shapes .items ()}
47+
48+ self .filename = None
49+ self ._fh = None
50+ self .df = None
51+
52+ ## Sampling methods
53+
54+ def setup (self , draws , chain ):
55+ """Perform chain-specific setup.
56+
57+ Parameters
58+ ----------
59+ draws : int
60+ Expected number of draws
61+ chain : int
62+ Chain number
63+ """
64+ self .chain = chain
65+ self .filename = os .path .join (self .name , 'chain-{}.csv' .format (chain ))
66+
67+ cnames = [fv for v in self .varnames for fv in self .flat_names [v ]]
68+
69+ if os .path .exists (self .filename ):
70+ with open (self .filename ) as fh :
71+ prev_cnames = next (fh ).strip ().split (',' )
72+ if prev_cnames != cnames :
73+ raise base .BackendError (
74+ "Previous file '{}' has different variables names "
75+ "than current model." .format (self .filename ))
76+ self ._fh = open (self .filename , 'a' )
77+ else :
78+ self ._fh = open (self .filename , 'w' )
79+ self ._fh .write (',' .join (cnames ) + '\n ' )
80+
81+ def record (self , point ):
82+ """Record results of a sampling iteration.
83+
84+ Parameters
85+ ----------
86+ point : dict
87+ Values mapped to variable names
88+ """
89+ vals = {}
90+ for varname , value in zip (self .varnames , self .fn (point )):
91+ vals [varname ] = value .ravel ()
92+ columns = [str (val ) for var in self .varnames for val in vals [var ]]
93+ self ._fh .write (',' .join (columns ) + '\n ' )
8094
81- def _dump_trace (name , trace ):
82- """Dump a single-chain trace.
95+ def close (self ):
96+ self ._fh .close ()
97+ self ._fh = None # Avoid serialization issue.
98+
99+ ## Selection methods
100+
101+ def _load_df (self ):
102+ if self .df is None :
103+ self .df = pd .read_csv (self .filename )
104+
105+ def __len__ (self ):
106+ if self .filename is None :
107+ return 0
108+ self ._load_df ()
109+ return self .df .shape [0 ]
110+
111+ def get_values (self , varname , burn = 0 , thin = 1 ):
112+ """Get values from trace.
113+
114+ Parameters
115+ ----------
116+ varname : str
117+ burn : int
118+ thin : int
119+
120+ Returns
121+ -------
122+ A NumPy array
123+ """
124+ self ._load_df ()
125+ var_df = self .df [self .flat_names [varname ]]
126+ shape = (self .df .shape [0 ],) + self .var_shapes [varname ]
127+ vals = var_df .values .ravel ().reshape (shape )
128+ return vals [burn ::thin ]
129+
130+ def _slice (self , idx ):
131+ warnings .warn ('Slice for Text backend has no effect.' )
132+
133+ def point (self , idx ):
134+ """Return dictionary of point values at `idx` for current chain
135+ with variables names as keys.
136+ """
137+ idx = int (idx )
138+ self ._load_df ()
139+ pt = {}
140+ for varname in self .varnames :
141+ vals = self .df [self .flat_names [varname ]].iloc [idx ]
142+ pt [varname ] = vals .reshape (self .var_shapes [varname ])
143+ return pt
144+
145+
146+ def _create_flat_names (varname , shape ):
147+ """Return flat variable names for `varname` of `shape`.
148+
149+ Examples
150+ --------
151+ >>> _create_flat_names('x', (5,))
152+ ['x__0', 'x__1', 'x__2', 'x__3', 'x__4']
153+
154+ >>> _create_flat_names('x', (2, 2))
155+ ['x__0_0', 'x__0_1', 'x__1_0', 'x__1_1']
83156 """
84- chain_name = 'chain-{}' .format (trace .chain )
85- chain_dir = os .path .join (name , chain_name )
86- os .mkdir (chain_dir )
157+ if not shape :
158+ return [varname ]
159+ labels = (np .ravel (xs ).tolist () for xs in np .indices (shape ))
160+ labels = (map (str , xs ) for xs in labels )
161+ return ['{}__{}' .format (varname , '_' .join (idxs )) for idxs in zip (* labels )]
87162
88- info = {}
89- for varname in trace .varnames :
90- data = trace .get_values (varname )
91163
92- if np . issubdtype ( data . dtype , np . int ):
93- fmt = '%i'
94- is_int = True
95- else :
96- fmt = '%g'
97- is_int = False
98- info [ varname ] = { 'shape' : data . shape , 'is_int' : is_int }
164+ def _create_shape ( flat_names ):
165+ "Determine shape from `_create_flat_names` output."
166+ try :
167+ _ , shape_str = flat_names [ - 1 ]. rsplit ( '__' , 1 )
168+ except ValueError :
169+ return ()
170+ return tuple ( int ( i ) + 1 for i in shape_str . split ( '_' ))
99171
100- var_file = os .path .join (chain_dir , varname + '.txt' )
101- np .savetxt (var_file , data .reshape (- 1 , data .size ), fmt = fmt )
102- ## Store shape and dtype information for reloading.
103- info_file = os .path .join (chain_dir , 'info.json' )
104- with open (info_file , 'w' ) as sfh :
105- json .dump (info , sfh )
106172
107-
108- def load (name , chains = None , model = None ):
109- """Load text database.
173+ def load (name , model = None ):
174+ """Load Text database.
110175
111176 Parameters
112177 ----------
113178 name : str
114- Path to root directory for text database
115- chains : list
116- Chains to load. If None, all chains are loaded.
179+ Name of directory with files (one per chain)
117180 model : Model
118181 If None, the model is taken from the `with` context.
119182
120183 Returns
121184 -------
122- ndarray.Trace instance
185+ A MultiTrace instance
123186 """
124- chain_dirs = _get_chain_dirs (name )
125- if chains is None :
126- chains = list (chain_dirs .keys ())
187+ files = glob (os .path .join (name , 'chain-*.csv' ))
127188
128189 traces = []
129- for chain in chains :
130- chain_dir = chain_dirs [chain ]
131- info_file = os .path .join (chain_dir , 'info.json' )
132- with open (info_file , 'r' ) as sfh :
133- info = json .load (sfh )
134- samples = {}
135- for varname , info in info .items ():
136- var_file = os .path .join (chain_dir , varname + '.txt' )
137- dtype = int if info ['is_int' ] else float
138- flat_data = np .loadtxt (var_file , dtype = dtype )
139- samples [varname ] = flat_data .reshape (info ['shape' ])
140- trace = NDArray (model = model )
141- trace .samples = samples
190+ for f in files :
191+ chain = int (os .path .splitext (f )[0 ].rsplit ('-' , 1 )[1 ])
192+ trace = Text (name , model = model )
142193 trace .chain = chain
194+ trace .filename = f
143195 traces .append (trace )
144196 return base .MultiTrace (traces )
145197
146198
147- def _get_chain_dirs (name ):
148- """Return mapping of chain number to directory."""
149- return {_chain_dir_to_chain (chain_dir ): chain_dir
150- for chain_dir in glob .glob (os .path .join (name , 'chain-*' ))}
199+ def dump (name , trace , chains = None ):
200+ """Store values from NDArray trace as CSV files.
201+
202+ Parameters
203+ ----------
204+ name : str
205+ Name of directory to store CSV files in
206+ trace : MultiTrace of NDArray traces
207+ Result of MCMC run with default NDArray backend
208+ chains : list
209+ Chains to dump. If None, all chains are dumped.
210+ """
211+ if not os .path .exists (name ):
212+ os .mkdir (name )
213+ if chains is None :
214+ chains = trace .chains
215+
216+ var_shapes = trace ._traces [chains [0 ]].var_shapes
217+ flat_names = {v : _create_flat_names (v , shape )
218+ for v , shape in var_shapes .items ()}
219+
220+ for chain in chains :
221+ filename = os .path .join (name , 'chain-{}.csv' .format (chain ))
222+ df = _trace_to_df (trace ._traces [chain ], flat_names )
223+ df .to_csv (filename , index = False )
224+
151225
226+ def _trace_to_df (trace , flat_names = None ):
227+ """Convert single-chain trace to Pandas DataFrame.
152228
153- def _chain_dir_to_chain (chain_dir ):
154- return int (os .path .basename (chain_dir ).split ('-' )[1 ])
229+ Parameters
230+ ----------
231+ trace : NDarray trace
232+ flat_names : dict or None
233+ A dictionary that maps each variable name in `trace` to a list
234+ of flat variable names (e.g., ['x__0', 'x__1', ...])
235+ """
236+ if flat_names is None :
237+ flat_names = {v : _create_flat_names (v , shape )
238+ for v , shape in trace .var_shapes .items ()}
239+
240+ var_dfs = []
241+ for varname , shape in trace .var_shapes .items ():
242+ vals = trace [varname ]
243+ if len (shape ) == 1 :
244+ flat_vals = vals
245+ else :
246+ flat_vals = vals .reshape (len (trace ), np .prod (shape ))
247+ var_dfs .append (pd .DataFrame (flat_vals , columns = flat_names [varname ]))
248+ return pd .concat (var_dfs , axis = 1 )
0 commit comments