diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 1d46fa91cd..84bab0d7ea 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 3d0fbcf819..e7284f77b1 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base>=0.7.0 - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index c47b53946b..c0eb0e7b8c 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..94b5bf69aa 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -28,8 +28,11 @@ import numpy as np import xarray -from arviz import InferenceData, concat, rcParams -from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires +from arviz import InferenceData, concat +from arviz.data.base import CoordSpec, DimSpec +from arviz_base import dict_to_dataset +from arviz_base.base import requires +from arviz_base import rcParams from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -211,7 +214,7 @@ def __init__( save_warmup: bool | None = None, include_transformed: bool = False, ): - self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup self.include_transformed = include_transformed 
self.trace = trace @@ -305,14 +308,14 @@ def posterior_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, @@ -347,14 +350,14 @@ def sample_stats_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, @@ -367,7 +370,11 @@ def posterior_predictive_to_xarray(self): data = self.posterior_predictive dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) @requires(["predictions"]) @@ -376,7 +383,11 @@ def predictions_to_xarray(self): data = self.predictions dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) def priors_to_xarray(self): @@ -399,7 +410,7 @@ def priors_to_xarray(self): if var_names is None else dict_to_dataset_drop_incompatible_coords( {k: np.expand_dims(self.prior[k], 0) for k in var_names}, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, ) @@ -414,10 +425,10 @@ def observed_data_to_xarray(self): return None return dict_to_dataset( self.observations, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) @requires("model") @@ -429,10 +440,10 @@ def 
constant_data_to_xarray(self): xarray_dataset = dict_to_dataset( constant_data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) # provisional handling of scalars in constant @@ -707,9 +718,9 @@ def apply_function_over_dataset( return dict_to_dataset( out_trace, - library=pymc, + inference_library=pymc, dims=dims, coords=coords, - default_dims=list(sample_dims), + sample_dims=list(sample_dims), skip_event_dims=True, ) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 249f6c5253..08d383db15 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -267,7 +267,9 @@ def _save_sample_stats( sample_stats = dict_to_dataset( sample_stats_dict, attrs=sample_settings_dict, - library=pymc, + inference_library=pymc, + sample_dims=["chain"], + check_conventions=False, ) ikwargs: dict[str, Any] = {"model": model} diff --git a/requirements-dev.txt b/requirements-dev.txt index 0c8818d531..007b692c6c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. +arviz-base>=0.7.0 arviz>=0.13.0 cachetools>=4.2.1 cloudpickle