25 | 25 | from pytensor.graph.basic import Apply, Variable |
26 | 26 | from pytensor.graph.op import Op |
27 | 27 | from pytensor.raise_op import Assert |
28 | | -from pytensor.sparse.basic import DenseFromSparse, sp_sum |
| 28 | +from pytensor.sparse.basic import DenseFromSparse |
| 29 | +from pytensor.sparse.math import sp_sum |
29 | 30 | from pytensor.tensor import ( |
30 | 31 | TensorConstant, |
31 | 32 | TensorVariable, |
@@ -2263,10 +2264,12 @@ class CAR(Continuous): |
2263 | 2264 | def dist(cls, mu, W, alpha, tau, *args, **kwargs): |
2264 | 2265 | # This variable has an expensive validation check that we want to constant-fold if possible
2265 | 2266 | # So it's passed as an explicit input |
2266 | | - W = pytensor.sparse.as_sparse_or_tensor_variable(W) |
| 2267 | + from pytensor.sparse import as_sparse_or_tensor_variable, structured_sign |
| 2268 | + |
| 2269 | + W = as_sparse_or_tensor_variable(W) |
2267 | 2270 | if isinstance(W.type, pytensor.sparse.SparseTensorType): |
2268 | | - abs_diff = pytensor.sparse.basic.mul(pytensor.sparse.sign(W - W.T), W - W.T) |
2269 | | - W_is_valid = pt.isclose(pytensor.sparse.sp_sum(abs_diff), 0) |
| 2271 | + abs_diff = structured_sign(W - W.T) * (W - W.T) |
| 2272 | + W_is_valid = pt.isclose(abs_diff.sum(), 0) |
2270 | 2273 | else: |
2271 | 2274 | W_is_valid = pt.allclose(W, W.T) |
2272 | 2275 |
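A note on the replacement check in this hunk: for a sparse matrix X, structured_sign(X) * X is the elementwise absolute value and stays sparse, so abs_diff.sum() is zero exactly when W is symmetric. A minimal SciPy analogue of the same idea (an illustration with made-up data, not the pytensor graph above):

import numpy as np
import scipy.sparse as sp

# Hypothetical symmetric adjacency matrix
W = sp.csr_matrix([[0.0, 2.0, 0.0],
                   [2.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])
diff = W - W.T                         # sparse difference; all zeros when W is symmetric
abs_diff = diff.sign().multiply(diff)  # elementwise |W - W.T|, still sparse
assert np.isclose(abs_diff.sum(), 0)   # holds iff W is symmetric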
@@ -2307,7 +2310,7 @@ def logp(value, mu, W, alpha, tau, W_is_valid): |
2307 | 2310 | if W.owner and isinstance(W.owner.op, DenseFromSparse): |
2308 | 2311 | W = W.owner.inputs[0] |
2309 | 2312 |
2310 | | - sparse = isinstance(W, pytensor.sparse.SparseVariable) |
| 2313 | + sparse = isinstance(W, pytensor.sparse.variable.SparseVariable) |
2311 | 2314 | if sparse: |
2312 | 2315 | D = sp_sum(W, axis=0) |
2313 | 2316 | Dinv_sqrt = pt.diag(1 / pt.sqrt(D)) |
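For context on the sparse branch shown here: D = sp_sum(W, axis=0) is the degree (column sum) of each node, and Dinv_sqrt rescales the adjacency symmetrically, the standard ingredient of the CAR log-density's determinant term. A dense NumPy sketch of those two steps (illustrative only, with made-up data; the eigenvalue remark is a general property of valid CAR adjacencies, not something stated in the diff):

import numpy as np

# Hypothetical symmetric adjacency matrix with positive degrees
W = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
D = W.sum(axis=0)                    # node degrees (equal to row sums since W is symmetric)
Dinv_sqrt = np.diag(1 / np.sqrt(D))  # dense analogue of the diagonal D^{-1/2}
DWD = Dinv_sqrt @ W @ Dinv_sqrt      # symmetrically scaled adjacency D^{-1/2} W D^{-1/2}
lam = np.linalg.eigvalsh(DWD)        # its eigenvalues lie in [-1, 1] for a valid CAR graph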