@@ -71,10 +71,10 @@ def __init__(self, w, comp_dists, *args, **kwargs):
7171
7272 super (Mixture , self ).__init__ (shape , dtype , defaults = defaults ,
7373 * args , ** kwargs )
74-
74+
7575 def _comp_logp (self , value ):
7676 comp_dists = self .comp_dists
77-
77+
7878 try :
7979 value_ = value if value .ndim > 1 else tt .shape_padright (value )
8080
@@ -85,14 +85,14 @@ def _comp_logp(self, value):
8585
def _comp_means(self):
    """Return the mixture components' means as one tensor.

    If ``self.comp_dists`` is a single (batched) distribution, its
    ``mean`` attribute is wrapped with ``tt.as_tensor_variable`` so the
    caller always receives a Theano tensor.  If it is instead a list of
    distributions (no ``mean`` attribute on the container), the
    per-component means are stacked along a new trailing axis — one
    column per mixture component.
    """
    comp_dists = self.comp_dists
    try:
        mean = comp_dists.mean
    except AttributeError:
        # comp_dists is a sequence of distributions: stack column-wise.
        return tt.stack([dist.mean for dist in comp_dists], axis=1)
    return tt.as_tensor_variable(mean)
def _comp_modes(self):
    """Return the mixture components' modes as one tensor.

    Mirrors ``_comp_means``: a single (batched) component distribution
    contributes its ``mode`` wrapped via ``tt.as_tensor_variable``; a
    list of component distributions has the individual modes stacked
    along a new trailing axis (one column per component).
    """
    comp_dists = self.comp_dists
    try:
        mode = comp_dists.mode
    except AttributeError:
        # comp_dists is a sequence of distributions: stack column-wise.
        return tt.stack([dist.mode for dist in comp_dists], axis=1)
    return tt.as_tensor_variable(mode)
@@ -137,7 +137,7 @@ def random_choice(*args, **kwargs):
137137 else :
138138 return np .squeeze (comp_samples [w_samples ])
139139
140-
140+
141141class NormalMixture (Mixture ):
142142 R"""
143143 Normal mixture log-likelihood
@@ -164,6 +164,6 @@ class NormalMixture(Mixture):
def __init__(self, w, mu, *args, **kwargs):
    """Build a mixture of Normal components.

    Parameters visible here: ``w`` (mixture weights) and ``mu``
    (component means).  Either ``tau`` or ``sd`` may be passed as a
    keyword; both are popped from ``kwargs`` and normalised through
    ``get_tau_sd`` — presumably returning ``(tau, sd)``, of which only
    the standard deviation is used to parameterise the components
    (TODO confirm against get_tau_sd's definition).
    """
    tau = kwargs.pop('tau', None)
    sd = kwargs.pop('sd', None)
    _, sd = get_tau_sd(tau=tau, sd=sd)

    # Delegate to Mixture with the Normal component family constructed
    # from the resolved standard deviation.
    comp_dists = Normal.dist(mu, sd=sd)
    super(NormalMixture, self).__init__(w, comp_dists, *args, **kwargs)
0 commit comments