@@ -468,6 +468,7 @@ def predict(
468468 self ,
469469 X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
470470 extend_idata : bool = True ,
471+ ** kwargs ,
471472 ) -> np .ndarray :
472473 """
473474 Uses model to predict on unseen data and return point prediction of all the samples. The point prediction
@@ -479,6 +480,7 @@ def predict(
479480 The input data used for prediction.
480481 extend_idata : Boolean determining whether the predictions should be added to inference data object.
481482 Defaults to True.
483+ **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
482484
483485 Returns
484486 -------
@@ -495,7 +497,7 @@ def predict(
495497 """
496498
497499 posterior_predictive_samples = self .sample_posterior_predictive (
498- X_pred , extend_idata , combined = False
500+ X_pred , extend_idata , combined = False , ** kwargs
499501 )
500502
501503 if self .output_var not in posterior_predictive_samples :
@@ -514,6 +516,7 @@ def sample_prior_predictive(
514516 samples : Optional [int ] = None ,
515517 extend_idata : bool = False ,
516518 combined : bool = True ,
519+ ** kwargs ,
517520 ):
518521 """
519522 Sample from the model's prior predictive distribution.
@@ -529,6 +532,7 @@ def sample_prior_predictive(
529532 Defaults to False.
530533 combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
531534 Defaults to True.
535+ **kwargs: Additional arguments to pass to pymc.sample_prior_predictive
532536
533537 Returns
534538 -------
@@ -544,7 +548,7 @@ def sample_prior_predictive(
544548 self ._data_setter (X_pred )
545549 if self .model is not None :
546550 with self .model : # sample with new input data
547- prior_pred : az .InferenceData = pm .sample_prior_predictive (samples )
551+ prior_pred : az .InferenceData = pm .sample_prior_predictive (samples , ** kwargs )
548552 self .set_idata_attrs (prior_pred )
549553 if extend_idata :
550554 if self .idata is not None :
@@ -556,7 +560,7 @@ def sample_prior_predictive(
556560
557561 return prior_predictive_samples
558562
559- def sample_posterior_predictive (self , X_pred , extend_idata , combined ):
563+ def sample_posterior_predictive (self , X_pred , extend_idata , combined , ** kwargs ):
560564 """
561565 Sample from the model's posterior predictive distribution.
562566
@@ -568,6 +572,7 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined):
568572 Defaults to False.
569573 combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
570574 Defaults to True.
575+ **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
571576
572577 Returns
573578 -------
@@ -577,7 +582,7 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined):
577582 self ._data_setter (X_pred )
578583
579584 with self .model : # sample with new input data
580- post_pred = pm .sample_posterior_predictive (self .idata )
585+ post_pred = pm .sample_posterior_predictive (self .idata , ** kwargs )
581586 if extend_idata :
582587 self .idata .extend (post_pred )
583588
@@ -621,15 +626,17 @@ def predict_proba(
621626 X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
622627 extend_idata : bool = True ,
623628 combined : bool = False ,
629+ ** kwargs ,
624630 ) -> xr .DataArray :
625631 """Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators."""
626- return self .predict_posterior (X_pred , extend_idata , combined )
632+ return self .predict_posterior (X_pred , extend_idata , combined , ** kwargs )
627633
628634 def predict_posterior (
629635 self ,
630636 X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
631637 extend_idata : bool = True ,
632638 combined : bool = True ,
639+ ** kwargs ,
633640 ) -> xr .DataArray :
634641 """
635642 Generate posterior predictive samples on unseen data.
@@ -642,6 +649,7 @@ def predict_posterior(
642649 Defaults to True.
643650 combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
644651 Defaults to True.
652+ **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
645653
646654 Returns
647655 -------
@@ -651,7 +659,7 @@ def predict_posterior(
651659
652660 X_pred = self ._validate_data (X_pred )
653661 posterior_predictive_samples = self .sample_posterior_predictive (
654- X_pred , extend_idata , combined
662+ X_pred , extend_idata , combined , ** kwargs
655663 )
656664
657665 if self .output_var not in posterior_predictive_samples :
0 commit comments