@@ -191,6 +191,11 @@ def stop_tuning(self):
191191 self .tune = False
192192
193193
194+ def flat_statname (sampler_idx : int , sname : str ) -> str :
195+ """Get the flat-stats name for a samplers stat."""
196+ return f"sampler_{ sampler_idx } __{ sname } "
197+
198+
194199def get_stats_dtypes_shapes_from_steps (
195200 steps : Iterable [BlockedStep ],
196201) -> Dict [str , Tuple [StatDtype , StatShape ]]:
@@ -201,7 +206,7 @@ def get_stats_dtypes_shapes_from_steps(
201206 result = {}
202207 for s , step in enumerate (steps ):
203208 for sname , (dtype , shape ) in step .stats_dtypes_shapes .items ():
204- result [f"sampler_ { s } __ { sname } " ] = (dtype , shape )
209+ result [flat_statname ( s , sname ) ] = (dtype , shape )
205210 return result
206211
207212
@@ -262,10 +267,21 @@ class StatsBijection:
262267
263268 def __init__ (self , sampler_stats_dtypes : Sequence [Mapping [str , type ]]) -> None :
264269 # Keep a list of flat vs. original stat names
265- self ._stat_groups : List [List [Tuple [str , str ]]] = [
266- [(f"sampler_{ s } __{ statname } " , statname ) for statname , _ in names_dtypes .items ()]
267- for s , names_dtypes in enumerate (sampler_stats_dtypes )
268- ]
270+ stat_groups = []
271+ for s , names_dtypes in enumerate (sampler_stats_dtypes ):
272+ group = []
273+ for statname , dtype in names_dtypes .items ():
274+ flatname = flat_statname (s , statname )
275+ is_obj = np .dtype (dtype ) == np .dtype (object )
276+ group .append ((flatname , statname , is_obj ))
277+ stat_groups .append (group )
278+ self ._stat_groups : List [List [Tuple [str , str , bool ]]] = stat_groups
279+ self .object_stats = {
280+ fname : (s , sname )
281+ for s , group in enumerate (self ._stat_groups )
282+ for fname , sname , is_obj in group
283+ if is_obj
284+ }
269285
270286 @property
271287 def n_samplers (self ) -> int :
@@ -275,9 +291,10 @@ def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
275291 """Combine stats dicts of multiple samplers into one dict."""
276292 stats_dict = {}
277293 for s , sts in enumerate (stats_list ):
278- for statname , sval in sts .items ():
279- sname = f"sampler_{ s } __{ statname } "
280- stats_dict [sname ] = sval
294+ for fname , sname , is_obj in self ._stat_groups [s ]:
295+ if sname not in sts :
296+ continue
297+ stats_dict [fname ] = sts [sname ]
281298 return stats_dict
282299
283300 def rmap (self , stats_dict : Mapping [str , Any ]) -> StatsType :
@@ -286,7 +303,11 @@ def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
286303 The ``stats_dict`` can be a subset of all sampler stats.
287304 """
288305 stats_list = []
289- for namemap in self ._stat_groups :
290- d = {statname : stats_dict [sname ] for sname , statname in namemap if sname in stats_dict }
306+ for group in self ._stat_groups :
307+ d = {}
308+ for fname , sname , is_obj in group :
309+ if fname not in stats_dict :
310+ continue
311+ d [sname ] = stats_dict [fname ]
291312 stats_list .append (d )
292313 return stats_list
0 commit comments