1717
1818import warnings
1919
20- from functools import reduce
20+ from functools import partial , reduce
2121from typing import Optional
2222
2323import numpy as np
3030from pytensor .raise_op import Assert
3131from pytensor .sparse .basic import sp_sum
3232from pytensor .tensor import TensorConstant , gammaln , sigmoid
33- from pytensor .tensor .nlinalg import det , eigh , matrix_inverse , trace
33+ from pytensor .tensor .linalg import cholesky , det , eigh
34+ from pytensor .tensor .linalg import inv as matrix_inverse
35+ from pytensor .tensor .linalg import solve_triangular , trace
3436from pytensor .tensor .random .basic import dirichlet , multinomial , multivariate_normal
3537from pytensor .tensor .random .op import RandomVariable
3638from pytensor .tensor .random .utils import (
3739 broadcast_params ,
3840 supp_shape_from_ref_param_shape ,
3941)
40- from pytensor .tensor .slinalg import Cholesky , SolveTriangular
4142from pytensor .tensor .type import TensorType
42- from scipy import linalg , stats
43+ from scipy import stats
4344
4445import pymc as pm
4546
9394 "StickBreakingWeights" ,
9495]
9596
96- solve_lower = SolveTriangular ( lower = True )
97- solve_upper = SolveTriangular ( lower = False )
97+ solve_lower = partial ( solve_triangular , lower = True )
98+ solve_upper = partial ( solve_triangular , lower = False )
9899
99100
100101class SimplexContinuous (Continuous ):
@@ -110,7 +111,7 @@ def simplex_cont_transform(op, rv):
110111# moment. We work around that by using a cholesky op
111112# that returns a nan as first entry instead of raising
112113# an error.
113- cholesky = Cholesky ( lower = True , on_error = "nan" )
114+ nan_lower_cholesky = partial ( cholesky , lower = True , on_error = "nan" )
114115
115116
116117def quaddist_matrix (cov = None , chol = None , tau = None , lower = True , * args , ** kwargs ):
@@ -155,7 +156,7 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
155156 onedim = False
156157
157158 delta = value - mu
158- chol_cov = cholesky (cov )
159+ chol_cov = nan_lower_cholesky (cov )
159160 if mat_type != "tau" :
160161 dist , logdet , ok = quaddist_chol (delta , chol_cov )
161162 else :
@@ -847,9 +848,9 @@ def dist(cls, *args, **kwargs):
847848
848849def posdef (AA ):
849850 try :
850- linalg .cholesky (AA )
851+ scipy . linalg .cholesky (AA )
851852 return True
852- except linalg .LinAlgError :
853+ except scipy . linalg .LinAlgError :
853854 return False
854855
855856
@@ -1073,7 +1074,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
10731074 if initval is not None :
10741075 # Inverse transform
10751076 initval = np .dot (np .dot (np .linalg .inv (L ), initval ), np .linalg .inv (L .T ))
1076- initval = linalg .cholesky (initval , lower = True )
1077+ initval = scipy . linalg .cholesky (initval , lower = True )
10771078 diag_testval = initval [diag_idx ] ** 2
10781079 tril_testval = initval [tril_idx ]
10791080 else :
@@ -1785,7 +1786,7 @@ def dist(
17851786 * args ,
17861787 ** kwargs ,
17871788 ):
1788- cholesky = Cholesky ( lower = True , on_error = "raise" )
1789+ lower_cholesky = partial ( cholesky , lower = True , on_error = "raise" )
17891790
17901791 # Among-row matrices
17911792 if len ([i for i in [rowcov , rowchol ] if i is not None ]) != 1 :
@@ -1795,7 +1796,7 @@ def dist(
17951796 if rowcov is not None :
17961797 if rowcov .ndim != 2 :
17971798 raise ValueError ("rowcov must be two dimensional." )
1798- rowchol_cov = cholesky (rowcov )
1799+ rowchol_cov = lower_cholesky (rowcov )
17991800 else :
18001801 if rowchol .ndim != 2 :
18011802 raise ValueError ("rowchol must be two dimensional." )
@@ -1810,7 +1811,7 @@ def dist(
18101811 colcov = pt .as_tensor_variable (colcov )
18111812 if colcov .ndim != 2 :
18121813 raise ValueError ("colcov must be two dimensional." )
1813- colchol_cov = cholesky (colcov )
1814+ colchol_cov = lower_cholesky (colcov )
18141815 else :
18151816 if colchol .ndim != 2 :
18161817 raise ValueError ("colchol must be two dimensional." )
@@ -1851,10 +1852,10 @@ def logp(value, mu, rowchol, colchol):
18511852
18521853 # Find exponent piece by piece
18531854 right_quaddist = solve_lower (rowchol , delta )
1854- quaddist = pt .nlinalg .matrix_dot (right_quaddist .T , right_quaddist )
1855+ quaddist = pt .linalg .matrix_dot (right_quaddist .T , right_quaddist )
18551856 quaddist = solve_lower (colchol , quaddist )
18561857 quaddist = solve_upper (colchol .T , quaddist )
1857- trquaddist = pt .nlinalg .trace (quaddist )
1858+ trquaddist = pt .linalg .trace (quaddist )
18581859
18591860 coldiag = pt .diag (colchol )
18601861 rowdiag = pt .diag (rowchol )
@@ -1887,7 +1888,7 @@ def rng_fn(self, rng, mu, sigma, *covs, size=None):
18871888 size = size if size else covs [- 1 ]
18881889 covs = covs [:- 1 ] if covs [- 1 ] == size else covs
18891890
1890- cov = reduce (linalg .kron , covs )
1891+ cov = reduce (scipy . linalg .kron , covs )
18911892
18921893 if sigma :
18931894 cov = cov + sigma ** 2 * np .eye (cov .shape [0 ])
@@ -1930,7 +1931,7 @@ class KroneckerNormal(Continuous):
19301931 :math:`[(v_1, Q_1), (v_2, Q_2), ...]` such that
19311932 :math:`K_i = Q_i \text{diag}(v_i) Q_i'`. For example::
19321933
1933- v_i, Q_i = pt.nlinalg .eigh(K_i)
1934+ v_i, Q_i = pt.linalg .eigh(K_i)
19341935 sigma : scalar, optional
19351936 Standard deviation of the Gaussian white noise.
19361937
@@ -2228,7 +2229,7 @@ def logp(value, mu, W, alpha, tau):
22282229 D = W .sum (axis = 0 )
22292230 Dinv_sqrt = pt .diag (1 / pt .sqrt (D ))
22302231 DWD = pt .dot (pt .dot (Dinv_sqrt , W ), Dinv_sqrt )
2231- lam = pt .slinalg .eigvalsh (DWD , pt .eye (DWD .shape [0 ]))
2232+ lam = pt .linalg .eigvalsh (DWD , pt .eye (DWD .shape [0 ]))
22322233
22332234 d , _ = W .shape
22342235
0 commit comments