@@ -143,16 +143,16 @@ def pandas_to_array(data):
143143
144144
145145def change_rv_size (
146- rv_var : TensorVariable ,
146+ rv : TensorVariable ,
147147 new_size : PotentialShapeType ,
148148 expand : Optional [bool ] = False ,
149149) -> TensorVariable :
150150 """Change or expand the size of a `RandomVariable`.
151151
152152 Parameters
153153 ==========
154- rv_var
155- The `RandomVariable` output.
154+ rv
155+ The old `RandomVariable` output.
156156 new_size
157157 The new size.
158158 expand:
@@ -167,32 +167,32 @@ def change_rv_size(
167167 new_size = (new_size ,)
168168
169169 # Extract the RV node that is to be resized, together with its inputs, name and tag
170- if isinstance (rv_var .owner .op , SpecifyShape ):
171- rv_var = rv_var .owner .inputs [0 ]
172- rv_node = rv_var .owner
170+ if isinstance (rv .owner .op , SpecifyShape ):
171+ rv = rv .owner .inputs [0 ]
172+ rv_node = rv .owner
173173 rng , size , dtype , * dist_params = rv_node .inputs
174- name = rv_var .name
175- tag = rv_var .tag
174+ name = rv .name
175+ tag = rv .tag
176176
177177 if expand :
178- old_shape = tuple (rv_node .op ._infer_shape (size , dist_params ))
179- old_size = old_shape [: len (old_shape ) - rv_node .op .ndim_supp ]
180- new_size = tuple (new_size ) + tuple (old_size )
178+ shape = tuple (rv_node .op ._infer_shape (size , dist_params ))
179+ size = shape [: len (shape ) - rv_node .op .ndim_supp ]
180+ new_size = tuple (new_size ) + tuple (size )
181181
182182 # Make sure the new size is a tensor. This dtype-aware conversion helps
183183 # to not unnecessarily pick up a `Cast` in some cases (see #4652).
184184 new_size = at .as_tensor (new_size , ndim = 1 , dtype = "int64" )
185185
186186 new_rv_node = rv_node .op .make_node (rng , new_size , dtype , * dist_params )
187- rv_var = new_rv_node .outputs [- 1 ]
188- rv_var .name = name
187+ new_rv = new_rv_node .outputs [- 1 ]
188+ new_rv .name = name
189189 for k , v in tag .__dict__ .items ():
190- rv_var .tag .__dict__ .setdefault (k , v )
190+ new_rv .tag .__dict__ .setdefault (k , v )
191191
192192 if config .compute_test_value != "off" :
193193 compute_test_value (new_rv_node )
194194
195- return rv_var
195+ return new_rv
196196
197197
198198def extract_rv_and_value_vars (
0 commit comments