1212import numpy as np
1313from functools import wraps
1414
15- __all__ = ['Model' , 'compilef' , 'gradient' , 'hessian' , 'withmodel ' , 'Point' ]
15+ __all__ = ['Model' , 'compilef' , 'gradient' , 'hessian' , 'modelcontext ' , 'Point' ]
1616
1717
1818
@@ -38,36 +38,10 @@ def get_context(cls):
3838 except IndexError :
3939 raise TypeError ("No context on context stack" )
4040
41- def withcontext (contexttype , argname ):
42- """
43- Returns a decorator for wrapping functions so they look for an argument in a specific argument slot.
44- If not found, the decorated function searches the for a context and inserts it in that slot.
45-
46- Parameters
47- ----------
48- contexttype : type
49- The type of context to search for
50- argname : string
51- The name of the argument slot where the context should go
52-
53- Returns
54- -------
55- decorator function
56-
57- """
58- def decorator (fn ):
59- n = list (fn .func_code .co_varnames ).index (argname )
60-
61- @wraps (fn )
62- def nfn (* args , ** kwargs ):
63- if not (len (args ) > n and isinstance (args [n ], contexttype )):
64- context = contexttype .get_context ()
65- args = args [:n ] + (context ,) + args [n :]
66- return fn (* args ,** kwargs )
67-
68- return nfn
69- return decorator
70-
41+ def modelcontext (model ):
42+ if model is None :
43+ return Model .get_context ()
44+ return model
7145
7246class Model (Context ):
7347 """
@@ -142,20 +116,22 @@ def TransformedVar(model, name, dist, trans):
142116 def AddPotential (model , potential ):
143117 model .factors .append (potential )
144118
145- withmodel = withcontext (Model , 'model' )
146119
147- @withmodel
148- def Point (model , * args ,** kwargs ):
120+ def Point (* args ,** kwargs ):
149121 """
150122 Build a point. Uses same args as dict() does.
151123 Filters out variables not in the model. All keys are strings.
152124
153125 Parameters
154126 ----------
155- model : Model (in context)
156127 *args, **kwargs
157128 arguments to build a dict
158129 """
130+ if 'model' in kwargs :
131+ model = kwargs ['model' ]
132+ del kwargs ['model' ]
133+ else :
134+ model = Model .get_context ()
159135
160136 d = dict (* args , ** kwargs )
161137 varnames = map (str , model .vars )
0 commit comments