@@ -90,14 +90,16 @@ def model_free_rv(rv, value, transform, *dims):
9090
9191
9292def toposort_replace (
93- fgraph : FunctionGraph , replacements : Sequence [Tuple [Variable , Variable ]]
93+ fgraph : FunctionGraph , replacements : Sequence [Tuple [Variable , Variable ]], reverse : bool = False
9494) -> None :
9595 """Replace multiple variables in topological order."""
9696 toposort = fgraph .toposort ()
9797 sorted_replacements = sorted (
98- replacements , key = lambda pair : toposort .index (pair [0 ].owner ) if pair [0 ].owner else - 1
98+ replacements ,
99+ key = lambda pair : toposort .index (pair [0 ].owner ) if pair [0 ].owner else - 1 ,
100+ reverse = reverse ,
99101 )
100- fgraph .replace_all (tuple ( sorted_replacements ) , import_missing = True )
102+ fgraph .replace_all (sorted_replacements , import_missing = True )
101103
102104
103105@node_rewriter ([Elemwise ])
@@ -109,11 +111,20 @@ def local_remove_identity(fgraph, node):
109111remove_identity_rewrite = out2in (local_remove_identity )
110112
111113
112- def fgraph_from_model (model : Model ) -> Tuple [FunctionGraph , Dict [Variable , Variable ]]:
114+ def fgraph_from_model (
115+ model : Model , inlined_views = False
116+ ) -> Tuple [FunctionGraph , Dict [Variable , Variable ]]:
113117 """Convert Model to FunctionGraph.
114118
115119 See: model_from_fgraph
116120
121+ Parameters
122+ ----------
123+ model: PyMC model
124+ inlined_views: bool, default False
125+ Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph,
126+ or show up as separate branches.
127+
117128 Returns
118129 -------
119130 fgraph: FunctionGraph
@@ -138,19 +149,36 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
138149 free_rvs = model .free_RVs
139150 observed_rvs = model .observed_RVs
140151 potentials = model .potentials
152+ named_vars = model .named_vars .values ()
141153 # We copy Deterministics (Identity Op) so that they don't show in between "main" variables
142154 # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
143155 old_deterministics = model .deterministics
144- deterministics = [det .copy (det .name ) for det in old_deterministics ]
145- # Other variables that are in model.named_vars but are not any of the categories above
156+ deterministics = [det if inlined_views else det .copy (det .name ) for det in old_deterministics ]
157+ # Value variables (we also have to decide whether to inline named ones)
158+ old_value_vars = list (rvs_to_values .values ())
159+ unnamed_value_vars = [val for val in old_value_vars if val not in named_vars ]
160+ named_value_vars = [
161+ val if inlined_views else val .copy (val .name ) for val in old_value_vars if val in named_vars
162+ ]
163+ value_vars = old_value_vars .copy ()
164+ if inlined_views :
165+ # In this case we want to use the named_value_vars as the value_vars in RVs
166+ for named_val in named_value_vars :
167+ idx = value_vars .index (named_val )
168+ value_vars [idx ] = named_val
169+ # Other variables that are in named_vars but are not any of the categories above
146170 # E.g., MutableData, ConstantData, _dim_lengths
147171 # We use the same trick as deterministics!
148- accounted_for = free_rvs + observed_rvs + potentials + old_deterministics
149- old_other_named_vars = [var for var in model .named_vars .values () if var not in accounted_for ]
150- other_named_vars = [var .copy (var .name ) for var in old_other_named_vars ]
151- value_vars = [val for val in rvs_to_values .values () if val not in old_other_named_vars ]
172+ accounted_for = set (free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars )
173+ other_named_vars = [
174+ var if inlined_views else var .copy (var .name )
175+ for var in named_vars
176+ if var not in accounted_for
177+ ]
152178
153- model_vars = rvs + potentials + deterministics + other_named_vars + value_vars
179+ model_vars = (
180+ rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars
181+ )
154182
155183 memo = {}
156184
@@ -176,13 +204,13 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
176204
177205 # Introduce dummy `ModelVar` Ops
178206 free_rvs_to_transforms = {memo [k ]: tr for k , tr in rvs_to_transforms .items ()}
179- free_rvs_to_values = {memo [k ]: memo [v ] for k , v in rvs_to_values . items ( ) if k in free_rvs }
207+ free_rvs_to_values = {memo [k ]: memo [v ] for k , v in zip ( rvs , value_vars ) if k in free_rvs }
180208 observed_rvs_to_values = {
181- memo [k ]: memo [v ] for k , v in rvs_to_values . items ( ) if k in observed_rvs
209+ memo [k ]: memo [v ] for k , v in zip ( rvs , value_vars ) if k in observed_rvs
182210 }
183211 potentials = [memo [k ] for k in potentials ]
184212 deterministics = [memo [k ] for k in deterministics ]
185- other_named_vars = [memo [k ] for k in other_named_vars ]
213+ named_vars = [memo [k ] for k in other_named_vars + named_value_vars ]
186214
187215 vars = fgraph .outputs
188216 new_vars = []
@@ -198,31 +226,31 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
198226 new_var = model_potential (var , * dims )
199227 elif var in deterministics :
200228 new_var = model_deterministic (var , * dims )
201- elif var in other_named_vars :
229+ elif var in named_vars :
202230 new_var = model_named (var , * dims )
203231 else :
204- # Value variables
232+ # Unnamed value variables
205233 new_var = var
206234 new_vars .append (new_var )
207235
208236 replacements = tuple (zip (vars , new_vars ))
209- toposort_replace (fgraph , replacements )
237+ toposort_replace (fgraph , replacements , reverse = True )
210238
211239 # Reference model vars in memo
212240 inverse_memo = {v : k for k , v in memo .items ()}
213241 for var , model_var in replacements :
214- if isinstance (
215- model_var .owner is not None and model_var .owner .op , (ModelDeterministic , ModelNamed )
242+ if not inlined_views and (
243+ model_var .owner and isinstance ( model_var .owner .op , (ModelDeterministic , ModelNamed ) )
216244 ):
217245 # Ignore extra identity that will be removed at the end
218246 var = var .owner .inputs [0 ]
219247 original_var = inverse_memo [var ]
220248 memo [original_var ] = model_var
221249
222- # Remove value variable as outputs, now that they are graph inputs
223- first_value_idx = len (fgraph .outputs ) - len (value_vars )
224- for _ in value_vars :
225- fgraph .remove_output (first_value_idx )
250+ # Remove the last outputs corresponding to unnamed value variables , now that they are graph inputs
251+ first_idx_to_remove = len (fgraph .outputs ) - len (unnamed_value_vars )
252+ for _ in unnamed_value_vars :
253+ fgraph .remove_output (first_idx_to_remove )
226254
227255 # Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
228256 remove_identity_rewrite .apply (fgraph )
0 commit comments