@@ -1880,35 +1880,30 @@ def logp(value, mu, rowchol, colchol):
18801880 return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
18811881
18821882
1883- class KroneckerNormalRV (RandomVariable ):
1884- name = "kroneckernormal"
1883+ class KroneckerNormalRV (SymbolicRandomVariable ):
18851884 ndim_supp = 1
1886- ndims_params = [1 , 0 , 2 ]
1887- dtype = "floatX"
18881885 _print_name = ("KroneckerNormal" , "\\ operatorname{KroneckerNormal}" )
18891886
1890- def _supp_shape_from_params (self , dist_params , param_shapes = None ):
1891- return supp_shape_from_ref_param_shape (
1892- ndim_supp = self .ndim_supp ,
1893- dist_params = dist_params ,
1894- param_shapes = param_shapes ,
1895- ref_param_idx = 0 ,
1896- )
1897-
1898- def rng_fn (self , rng , mu , sigma , * covs , size = None ):
1899- size = size if size else covs [- 1 ]
1900- covs = covs [:- 1 ] if covs [- 1 ] == size else covs
1901-
1902- cov = reduce (scipy .linalg .kron , covs )
1903-
1904- if sigma :
1905- cov = cov + sigma ** 2 * np .eye (cov .shape [0 ])
1887+ @classmethod
1888+ def rv_op (cls , mu , sigma , * covs , size = None , rng = None ):
1889+ mu = pt .as_tensor (mu )
1890+ sigma = pt .as_tensor (sigma )
1891+ covs = [pt .as_tensor (cov ) for cov in covs ]
1892+ rng = normalize_rng_param (rng )
1893+ size = normalize_size_param (size )
19061894
1907- x = multivariate_normal .rng_fn (rng = rng , mean = mu , cov = cov , size = size )
1908- return x
1895+ cov = reduce (pt .linalg .kron , covs )
1896+ cov = cov + sigma ** 2 * pt .eye (cov .shape [- 2 ])
1897+ next_rng , draws = multivariate_normal (mean = mu , cov = cov , size = size , rng = rng ).owner .outputs
19091898
1899+ covs_sig = "," .join (f"(a{ i } ,b{ i } )" for i in range (len (covs )))
1900+ signature = f"[rng],[size],(m),(),{ covs_sig } ->[rng],(m)"
19101901
1911- kroneckernormal = KroneckerNormalRV ()
1902+ return KroneckerNormalRV (
1903+ inputs = [rng , size , mu , sigma , * covs ],
1904+ outputs = [next_rng , draws ],
1905+ signature = signature ,
1906+ )(rng , size , mu , sigma , * covs )
19121907
19131908
19141909class KroneckerNormal (Continuous ):
@@ -1999,7 +1994,8 @@ class KroneckerNormal(Continuous):
19991994 .. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
20001995 """
20011996
2002- rv_op = kroneckernormal
1997+ rv_type = KroneckerNormalRV
1998+ rv_op = KroneckerNormalRV .rv_op
20031999
20042000 @classmethod
20052001 def dist (cls , mu , covs = None , chols = None , evds = None , sigma = None , * args , ** kwargs ):
@@ -2024,14 +2020,10 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
20242020
20252021 return super ().dist ([mu , sigma , * covs ], ** kwargs )
20262022
2027- def support_point (rv , size , mu , covs , chols , evds ):
2028- mean = mu
2029- if not rv_size_is_none (size ):
2030- support_point_size = pt .concatenate ([size , mu .shape ])
2031- mean = pt .full (support_point_size , mu )
2032- return mean
2023+ def support_point (rv , rng , size , mu , sigma , * covs ):
2024+ return pt .full_like (rv , mu )
20332025
2034- def logp (value , mu , sigma , * covs ):
2026+ def logp (value , rng , size , mu , sigma , * covs ):
20352027 """
20362028 Calculate log-probability of Multivariate Normal distribution
20372029 with Kronecker-structured covariance at specified value.
0 commit comments