1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import numpy as np
15+ import pytest
16+
17+ from pytensor .compile import SharedVariable
1418from pytensor .graph import Constant
1519
20+ from pymc import Deterministic
1621from pymc .data import Data
1722from pymc .distributions import HalfNormal , Normal
1823from pymc .model import Model
1924from pymc .model .transform .optimization import freeze_dims_and_data
2025
2126
22- def test_freeze_existing_rv_dims_and_data ():
27+ def test_freeze_dims_and_data ():
2328 with Model (coords = {"test_dim" : range (5 )}) as m :
24- std = Data ("std " , [1 ])
29+ std = Data ("test_data " , [1 ])
2530 x = HalfNormal ("x" , std , dims = ("test_dim" ,))
2631 y = Normal ("y" , shape = x .shape [0 ] + 1 )
2732
@@ -34,18 +39,96 @@ def test_freeze_existing_rv_dims_and_data():
3439 assert y_logp .type .shape == (None ,)
3540
3641 frozen_m = freeze_dims_and_data (m )
37- std , x , y = frozen_m ["std " ], frozen_m ["x" ], frozen_m ["y" ]
42+ data , x , y = frozen_m ["test_data " ], frozen_m ["x" ], frozen_m ["y" ]
3843 x_logp , y_logp = frozen_m .logp (sum = False )
39- assert isinstance (std , Constant )
44+ assert isinstance (data , Constant )
4045 assert x .type .shape == (5 ,)
4146 assert y .type .shape == (6 ,)
4247 assert x_logp .type .shape == (5 ,)
4348 assert y_logp .type .shape == (6 ,)
4449
50+ # Test trying to update a frozen data or dim raises an informative error
51+ with frozen_m :
52+ with pytest .raises (TypeError , match = "The variable `test_data` must be a `SharedVariable`" ):
53+ frozen_m .set_data ("test_data" , values = [2 ])
54+ with pytest .raises (
55+ TypeError , match = "The dim_length of `test_dim` must be a `SharedVariable`"
56+ ):
57+ frozen_m .set_dim ("test_dim" , new_length = 6 , coord_values = range (6 ))
58+
59+ # Test we can still update original model
60+ with m :
61+ m .set_data ("test_data" , values = [2 ])
62+ m .set_dim ("test_dim" , new_length = 6 , coord_values = range (6 ))
63+ assert m ["test_data" ].get_value () == [2 ]
64+ assert m .dim_lengths ["test_dim" ].get_value () == 6
4565
46- def test_freeze_rv_dims_nothing_to_change ():
66+
67+ def test_freeze_dims_nothing_to_change ():
4768 with Model (coords = {"test_dim" : range (5 )}) as m :
4869 x = HalfNormal ("x" , shape = (5 ,))
4970 y = Normal ("y" , shape = x .shape [0 ] + 1 )
5071
5172 assert m .point_logps () == freeze_dims_and_data (m ).point_logps ()
73+
74+
75+ def test_freeze_dims_and_data_subset ():
76+ with Model (coords = {"dim1" : range (3 ), "dim2" : range (5 )}) as m :
77+ data1 = Data ("data1" , [1 , 2 , 3 ], dims = "dim1" )
78+ data2 = Data ("data2" , [1 , 2 , 3 , 4 , 5 ], dims = "dim2" )
79+ var1 = Normal ("var1" , dims = "dim1" )
80+ var2 = Normal ("var2" , dims = "dim2" )
81+ x = data1 * var1
82+ y = data2 * var2
83+ det = Deterministic ("det" , x [:, None ] + y [None , :])
84+
85+ assert det .type .shape == (None , None )
86+
87+ new_m = freeze_dims_and_data (m , dims = ["dim1" ], data = [])
88+ assert new_m ["det" ].type .shape == (3 , None )
89+ assert isinstance (new_m .dim_lengths ["dim1" ], Constant ) and new_m .dim_lengths ["dim1" ].data == 3
90+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
91+ assert isinstance (new_m ["data1" ], SharedVariable )
92+ assert isinstance (new_m ["data2" ], SharedVariable )
93+
94+ new_m = freeze_dims_and_data (m , dims = ["dim2" ], data = [])
95+ assert new_m ["det" ].type .shape == (None , 5 )
96+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
97+ assert isinstance (new_m .dim_lengths ["dim2" ], Constant ) and new_m .dim_lengths ["dim2" ].data == 5
98+ assert isinstance (new_m ["data1" ], SharedVariable )
99+ assert isinstance (new_m ["data2" ], SharedVariable )
100+
101+ new_m = freeze_dims_and_data (m , dims = ["dim1" , "dim2" ], data = [])
102+ assert new_m ["det" ].type .shape == (3 , 5 )
103+ assert isinstance (new_m .dim_lengths ["dim1" ], Constant ) and new_m .dim_lengths ["dim1" ].data == 3
104+ assert isinstance (new_m .dim_lengths ["dim2" ], Constant ) and new_m .dim_lengths ["dim2" ].data == 5
105+ assert isinstance (new_m ["data1" ], SharedVariable )
106+ assert isinstance (new_m ["data2" ], SharedVariable )
107+
108+ new_m = freeze_dims_and_data (m , dims = [], data = ["data1" ])
109+ assert new_m ["det" ].type .shape == (3 , None )
110+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
111+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
112+ assert isinstance (new_m ["data1" ], Constant ) and np .all (new_m ["data1" ].data == [1 , 2 , 3 ])
113+ assert isinstance (new_m ["data2" ], SharedVariable )
114+
115+ new_m = freeze_dims_and_data (m , dims = [], data = ["data2" ])
116+ assert new_m ["det" ].type .shape == (None , 5 )
117+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
118+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
119+ assert isinstance (new_m ["data1" ], SharedVariable )
120+ assert isinstance (new_m ["data2" ], Constant ) and np .all (new_m ["data2" ].data == [1 , 2 , 3 , 4 , 5 ])
121+
122+ new_m = freeze_dims_and_data (m , dims = [], data = ["data1" , "data2" ])
123+ assert new_m ["det" ].type .shape == (3 , 5 )
124+ assert isinstance (new_m .dim_lengths ["dim1" ], SharedVariable )
125+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
126+ assert isinstance (new_m ["data1" ], Constant ) and np .all (new_m ["data1" ].data == [1 , 2 , 3 ])
127+ assert isinstance (new_m ["data2" ], Constant ) and np .all (new_m ["data2" ].data == [1 , 2 , 3 , 4 , 5 ])
128+
129+ new_m = freeze_dims_and_data (m , dims = ["dim1" ], data = ["data2" ])
130+ assert new_m ["det" ].type .shape == (3 , 5 )
131+ assert isinstance (new_m .dim_lengths ["dim1" ], Constant ) and new_m .dim_lengths ["dim1" ].data == 3
132+ assert isinstance (new_m .dim_lengths ["dim2" ], SharedVariable )
133+ assert isinstance (new_m ["data1" ], SharedVariable )
134+ assert isinstance (new_m ["data2" ], Constant ) and np .all (new_m ["data2" ].data == [1 , 2 , 3 , 4 , 5 ])
0 commit comments