33import itertools
44import threading
55import warnings
6- from typing import Optional
6+ from typing import Optional , TypeVar , Type , List , Union , TYPE_CHECKING , Any , cast
7+ from sys import modules
78
89import numpy as np
910from pandas import Series
@@ -55,10 +56,10 @@ def __call__(self, *args, **kwargs):
5556 return getattr (self .obj , self .method_name )(* args , ** kwargs )
5657
5758
58- def incorporate_methods (source , destination , methods , default = None ,
59+ def incorporate_methods (source , destination , methods ,
5960 wrapper = None , override = False ):
6061 """
61- Add attributes to a destination object which points to
62+ Add attributes to a destination object which point to
6263 methods from from a source object.
6364
6465 Parameters
@@ -69,8 +70,6 @@ def incorporate_methods(source, destination, methods, default=None,
6970 The destination object for the methods.
7071 methods : list of str
7172 Names of methods to incorporate.
72- default : object
73- The value used if the source does not have one of the listed methods.
7473 wrapper : function
7574 An optional function to allow the source method to be
7675 wrapped. Should take the form my_wrapper(source, method_name)
@@ -162,49 +161,131 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
162161 node_children .update (temp_tree )
163162 return leaf_nodes , node_parents , node_children
164163
164+ T = TypeVar ('T' , bound = 'ContextMeta' )
165165
166- class Context :
166+
167+ class ContextMeta (type ):
167168 """Functionality for objects that put themselves in a context using
168169 the `with` statement.
169170 """
170- contexts = threading .local ()
171-
172- def __enter__ (self ):
173- type (self ).get_contexts ().append (self )
174- # self._theano_config is set in Model.__new__
175- if hasattr (self , '_theano_config' ):
176- self ._old_theano_config = set_theano_conf (self ._theano_config )
177- return self
178-
179- def __exit__ (self , typ , value , traceback ):
180- type (self ).get_contexts ().pop ()
181- # self._theano_config is set in Model.__new__
182- if hasattr (self , '_old_theano_config' ):
183- set_theano_conf (self ._old_theano_config )
184171
185- @classmethod
186- def get_contexts (cls ):
187- # no race-condition here, cls.contexts is a thread-local object
172+ def __new__ (cls , name , bases , dct , ** kargs ): # pylint: disable=unused-argument
173+ "Add __enter__ and __exit__ methods to the class."
174+ def __enter__ (self ):
175+ self .__class__ .context_class .get_contexts ().append (self )
176+ # self._theano_config is set in Model.__new__
177+ if hasattr (self , '_theano_config' ):
178+ self ._old_theano_config = set_theano_conf (self ._theano_config )
179+ return self
180+
181+ def __exit__ (self , typ , value , traceback ): # pylint: disable=unused-argument
182+ self .__class__ .context_class .get_contexts ().pop ()
183+ # self._theano_config is set in Model.__new__
184+ if hasattr (self , '_old_theano_config' ):
185+ set_theano_conf (self ._old_theano_config )
186+
187+ dct [__enter__ .__name__ ] = __enter__
188+ dct [__exit__ .__name__ ] = __exit__
189+
190+ # We strip off keyword args, per the warning from
191+ # StackExchange:
192+ # DO NOT send "**kargs" to "type.__new__". It won't catch them and
193+ # you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
194+ return super ().__new__ (cls , name , bases , dct )
195+
196+ # FIXME: is there a more elegant way to automatically add methods to the class that
197+ # are instance methods instead of class methods?
198+ def __init__ (cls , name , bases , nmspc , context_class : Optional [Type ]= None , ** kwargs ): # pylint: disable=unused-argument
199+ """Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
200+ if context_class is not None :
201+ cls ._context_class = context_class
202+ super ().__init__ (name , bases , nmspc )
203+
204+
205+
206+ def get_context (cls , error_if_none = True ) -> Optional [T ]:
207+ """Return the most recently pushed context object of type ``cls``
208+ on the stack, or ``None``. If ``error_if_none`` is True (default),
209+ raise a ``TypeError`` instead of returning ``None``."""
210+ idx = - 1
211+ while True :
212+ try :
213+ candidate = cls .get_contexts ()[idx ] # type: Optional[T]
214+ except IndexError as e :
215+ # Calling code expects to get a TypeError if the entity
216+ # is unfound, and there's too much to fix.
217+ if error_if_none :
218+ raise TypeError ("No %s on context stack" % str (cls ))
219+ return None
220+ return candidate
221+ idx = idx - 1
222+
223+ def get_contexts (cls ) -> List [T ]:
224+ """Return a stack of context instances for the ``context_class``
225+ of ``cls``."""
226+ # This lazily creates the context class's contexts
227+ # thread-local object, as needed. This seems inelegant to me,
228+ # but since the context class is not guaranteed to exist when
229+ # the metaclass is being instantiated, I couldn't figure out a
230+ # better way. [2019/10/11:rpg]
231+
232+ # no race-condition here, contexts is a thread-local object
188233 # be sure not to override contexts in a subclass however!
189- if not hasattr (cls .contexts , 'stack' ):
190- cls .contexts .stack = []
191- return cls .contexts .stack
192-
193- @classmethod
194- def get_context (cls ):
195- """Return the deepest context on the stack."""
196- try :
197- return cls .get_contexts ()[- 1 ]
198- except IndexError :
199- raise TypeError ("No context on context stack" )
234+ context_class = cls .context_class
235+ assert isinstance (context_class , type ), \
236+ "Name of context class, %s was not resolvable to a class" % context_class
237+ if not hasattr (context_class , 'contexts' ):
238+ context_class .contexts = threading .local ()
239+
240+ contexts = context_class .contexts
241+
242+ if not hasattr (contexts , 'stack' ):
243+ contexts .stack = []
244+ return contexts .stack
245+
246+ # the following complex property accessor is necessary because the
247+ # context_class may not have been created at the point it is
248+ # specified, so the context_class may be a class *name* rather
249+ # than a class.
250+ @property
251+ def context_class (cls ) -> Type :
252+ def resolve_type (c : Union [Type , str ]) -> Type :
253+ if isinstance (c , str ):
254+ c = getattr (modules [cls .__module__ ], c )
255+ if isinstance (c , type ):
256+ return c
257+ raise ValueError ("Cannot resolve context class %s" % c )
258+ assert cls is not None
259+ if isinstance (cls ._context_class , str ):
260+ cls ._context_class = resolve_type (cls ._context_class )
261+ if not isinstance (cls ._context_class , (str , type )):
262+ raise ValueError ("Context class for %s, %s, is not of the right type" % \
263+ (cls .__name__ , cls ._context_class ))
264+ return cls ._context_class
265+
266+ # Inherit context class from parent
267+ def __init_subclass__ (cls , ** kwargs ):
268+ super ().__init_subclass__ (** kwargs )
269+ cls .context_class = super ().context_class
270+
271+ # Initialize object in its own context...
272+ # Merged from InitContextMeta in the original.
273+ def __call__ (cls , * args , ** kwargs ):
274+ instance = cls .__new__ (cls , * args , ** kwargs )
275+ with instance : # appends context
276+ instance .__init__ (* args , ** kwargs )
277+ return instance
200278
201279
202280def modelcontext (model : Optional ['Model' ]) -> 'Model' :
203- """return the given model or try to find it in the context if there was
204- none supplied.
281+ """
282+ Return the given model or, if none was supplied, try to find one in
283+ the context stack.
205284 """
206285 if model is None :
207- return Model .get_context ()
286+ model = Model .get_context (error_if_none = False )
287+ if model is None :
288+ raise ValueError ("No model on context stack." )
208289 return model
209290
210291
@@ -292,15 +373,6 @@ def logp_nojact(self):
292373 return logp
293374
294375
295- class InitContextMeta (type ):
296- """Metaclass that executes `__init__` of instance in it's context"""
297- def __call__ (cls , * args , ** kwargs ):
298- instance = cls .__new__ (cls , * args , ** kwargs )
299- with instance : # appends context
300- instance .__init__ (* args , ** kwargs )
301- return instance
302-
303-
304376def withparent (meth ):
305377 """Helper wrapper that passes calls to parent's instance"""
306378 def wrapped (self , * args , ** kwargs ):
@@ -346,11 +418,18 @@ def __setitem__(self, key, value):
346418 ' able to determine '
347419 'appropriate logic for it' )
348420
349- def __imul__ (self , other ):
421+ # Added this because mypy didn't like having __imul__ without __mul__
422+ # This is my best guess about what this should do. I might be happier
423+ # to kill both of these if they are not used.
424+ def __mul__ (self , other ) -> 'treelist' :
425+ return cast ('treelist' , list .__mul__ (self , other ))
426+
427+ def __imul__ (self , other ) -> 'treelist' :
350428 t0 = len (self )
351429 list .__imul__ (self , other )
352430 if self .parent is not None :
353431 self .parent .extend (self [t0 :])
432+ return self # python spec says should return the result.
354433
355434
356435class treedict (dict ):
@@ -555,7 +634,7 @@ def _build_joined(self, cost, args, vmap):
555634 return args_joined , theano .clone (cost , replace = replace )
556635
557636
558- class Model (Context , Factor , WithMemoization , metaclass = InitContextMeta ):
637+ class Model (Factor , WithMemoization , metaclass = ContextMeta , context_class = 'Model' ):
559638 """Encapsulates the variables and likelihood factors of a model.
560639
561640 Model class can be used for creating class based models. To create
@@ -643,15 +722,18 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
643722 CustomModel(mean=1, name='first')
644723 CustomModel(mean=2, name='second')
645724 """
725+
726+ if TYPE_CHECKING :
727+ def __enter__ (self : 'Model' ) -> 'Model' : ...
728+ def __exit__ (self : 'Model' , * exc : Any ) -> bool : ...
729+
646730 def __new__ (cls , * args , ** kwargs ):
647731 # resolves the parent instance
648732 instance = super ().__new__ (cls )
649733 if kwargs .get ('model' ) is not None :
650734 instance ._parent = kwargs .get ('model' )
651- elif cls .get_contexts ():
652- instance ._parent = cls .get_contexts ()[- 1 ]
653735 else :
654- instance ._parent = None
736+ instance ._parent = cls . get_context ( error_if_none = False )
655737 theano_config = kwargs .get ('theano_config' , None )
656738 if theano_config is None or 'compute_test_value' not in theano_config :
657739 theano_config = {'compute_test_value' : 'raise' }
@@ -694,7 +776,7 @@ def root(self):
694776 def isroot (self ):
695777 return self .parent is None
696778
697- @property
779+ @property # type: ignore -- mypy can't handle decorated types.
698780 @memoize (bound = True )
699781 def bijection (self ):
700782 vars = inputvars (self .vars )
0 commit comments