Skip to content

Commit 610b380

Browse files
Specify known data shape in kalman filter
1 parent 2999be1 commit 610b380

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def build_graph(
200200
self.n_endog = Z_shape[-2]
201201

202202
data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
203-
203+
data = pt.specify_shape(data, (data.type.shape[0], self.n_endog))
204204
sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
205205
params, PARAM_NAMES
206206
)
@@ -658,7 +658,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
658658
# Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
659659
# [0, L_pred]]
660660
# The Schur decomposition of this matrix will be B (upper triangular). We are
661-
# more insterested in B^T:
661+
# more interested in B^T:
662662
# Structure of B^T = [[chol(F), 0 ],
663663
# [K @ chol(F), chol(P_filtered)]
664664
zeros = pt.zeros((self.n_states, self.n_endog))

0 commit comments

Comments
 (0)