99
1010from muograph .utils .device import DEVICE
1111from muograph .tracking .tracking import TrackingMST
12- from muograph .reconstruction .poca import POCA
12+ from muograph .reconstruction .poca import POCA , POCAParams
1313from muograph .volume .volume import Volume
1414from muograph .reconstruction .voxel_inferer import AbsVoxelInferer
1515
2424
2525
2626@dataclass
27- class BCAParams :
27+ class BCAParams ( POCAParams ) :
2828 n_max_per_vox : int = 50
2929 n_min_per_vox : int = 3
3030 score_method : partial = partial (torch .quantile , q = 0.5 )
3131 metric_method : partial = partial (torch .log )
3232 p_range : Tuple [float , float ] = (0.0 , 10000000 )
3333 dtheta_range : Tuple [float , float ] = (0.0 , math .pi / 3 )
34- use_p : bool = False
35- p_clamp : float = 0.999
36- dtheta_clamp : float = 0.999
3734
3835
39- class BCA (POCA , AbsVoxelInferer ):
40- _bca_params : BCAParams = BCAParams ()
36+ class BCA (POCA [ BCAParams ] , AbsVoxelInferer ):
37+ _params : BCAParams = BCAParams ()
4138
4239 _vars_to_save = ["xyz_voxel_pred" , "n_poca_per_vox" ]
4340
@@ -386,7 +383,7 @@ def get_xyz_voxel_pred(self) -> Tensor:
386383
387384 Compute voxel-wise scattering density predictions.
388385
389- Uses parameters stored in `_bca_params `. The algorithm calculates:
386+ Uses parameters stored in `_params `. The algorithm calculates:
390387 - Scattering density predictions (`pred`).
391388 - Number of POCA points used for each voxel's prediction (`hit_per_voxel`).
392389
@@ -405,41 +402,41 @@ def get_xyz_voxel_pred(self) -> Tensor:
405402 self .nhit ,
406403 self .nhit_rejected ,
407404 ) = self .compute_low_theta_events_voxel_wise_mask (
408- n_max_per_voxel = int (self .bca_params .n_max_per_vox ), # type: ignore
405+ n_max_per_voxel = int (self .params .n_max_per_vox ), # type: ignore
409406 voi = self .voi ,
410407 bca_indices = self .bca_indices ,
411408 dtheta = self .bca_tracks .dtheta ,
412409 )
413410 self ._filter_events (self .mask )
414411
415412 # momentum cut
416- if self .bca_params .use_p :
417- p_mask = (self .bca_tracks .E > self .bca_params .p_range [0 ]) & ( # type: ignore
418- self .bca_tracks .E < self .bca_params .p_range [1 ] # type: ignore
413+ if self .params .use_p :
414+ p_mask = (self .bca_tracks .E > self .params .p_range [0 ]) & ( # type: ignore
415+ self .bca_tracks .E < self .params .p_range [1 ] # type: ignore
419416 )
420417 else :
421418 p_mask = torch .ones_like (self .bca_tracks .dtheta , dtype = torch .bool , device = DEVICE )
422419
423420 # scattering angle cut
424- dtheta_mask = (self .bca_tracks .dtheta > self .bca_params .dtheta_range [0 ]) & ( # type: ignore
425- self .bca_tracks .dtheta < self .bca_params .dtheta_range [1 ] # type: ignore
421+ dtheta_mask = (self .bca_tracks .dtheta > self .params .dtheta_range [0 ]) & ( # type: ignore
422+ self .bca_tracks .dtheta < self .params .dtheta_range [1 ] # type: ignore
426423 )
427424
428425 # apply dtheta, p cuts
429426 self ._filter_events (mask = p_mask & dtheta_mask )
430427
431428 # prepare variables
432- p_max = torch .quantile (self .bca_tracks .p , q = self .bca_params .p_clamp )
429+ p_max = torch .quantile (self .bca_tracks .p , q = self .params .p_clamp )
433430 p = torch .clamp (self .bca_tracks .p , max = p_max )
434431
435- dtheta_max = torch .quantile (self .bca_tracks .dtheta , q = self .bca_params .dtheta_clamp )
432+ dtheta_max = torch .quantile (self .bca_tracks .dtheta , q = self .params .dtheta_clamp )
436433 dtheta = torch .clamp (self .bca_tracks .dtheta , max = dtheta_max )
437434
438435 # compute voxels distribution
439436 self .score_list = self .compute_voxels_distribution (
440- metric_method = self .bca_params .metric_method , # type: ignore
441- use_p = self .bca_params .use_p , # type: ignore
442- n_min_per_vox = self .bca_params .n_min_per_vox , # type: ignore
437+ metric_method = self .params .metric_method , # type: ignore
438+ use_p = self .params .use_p , # type: ignore
439+ n_min_per_vox = self .params .n_min_per_vox , # type: ignore
443440 voi = self .voi ,
444441 momentum = p ,
445442 bca_indices = self .bca_indices ,
@@ -448,7 +445,10 @@ def get_xyz_voxel_pred(self) -> Tensor:
448445 )
449446
450447 # compute fina scores
451- pred , self ._hit_per_voxel = self .compute_final_scores (score_list = self .score_list , score_method = self .bca_params .score_method ) # type: ignore
448+ pred , self ._hit_per_voxel = self .compute_final_scores (score_list = self .score_list , score_method = self .params .score_method ) # type: ignore
449+
450+ pred_max = torch .quantile (pred , q = self .params .preds_clamp )
451+ pred = torch .clamp (pred , max = pred_max )
452452
453453 self ._recompute_preds = False
454454
@@ -471,16 +471,17 @@ def get_partial_name_args(func: partial) -> str:
471471 func_name += "_{}={}" .format (arg , values [i ])
472472 return func_name
473473
474- method = "method_{}_" .format (get_partial_name_args (self .bca_params .score_method )) # type: ignore
475- metric = "metric_{}_" .format (get_partial_name_args (self .bca_params .metric_method )) # type: ignore
476- dtheta = "{:.1f}_{:.1f}_mrad_" .format (self .bca_params .dtheta_range [0 ] * 1000 , self .bca_params .dtheta_range [1 ] * 1000 ) # type: ignore
477- dp = "{:.0f}_{:.0f}_MeV_" .format (self .bca_params .p_range [0 ], self .bca_params .p_range [1 ]) # type: ignore
478- n_min_max = "n_min_max_{}_{}_" .format (self .bca_params .n_min_per_vox , self .bca_params .n_max_per_vox )
479- use_p = "use_p_{}" .format (self .bca_params .use_p )
480- p_clamp = "_pclamp_{:.3f}" .format (self .bca_params .p_clamp ) if self .bca_params .use_p else ""
481- dtheta_clamp = "_dthetaclamp_{:.3f}" .format (self .bca_params .dtheta_clamp )
482-
483- bca_name = method + metric + dtheta + dp + n_min_max + use_p + p_clamp + dtheta_clamp
474+ method = "method_{}_" .format (get_partial_name_args (self .params .score_method )) # type: ignore
475+ metric = "metric_{}_" .format (get_partial_name_args (self .params .metric_method )) # type: ignore
476+ dtheta = "{:.1f}_{:.1f}_mrad_" .format (self .params .dtheta_range [0 ] * 1000 , self .params .dtheta_range [1 ] * 1000 ) # type: ignore
477+ dp = "{:.0f}_{:.0f}_MeV_" .format (self .params .p_range [0 ], self .params .p_range [1 ]) # type: ignore
478+ n_min_max = "n_min_max_{}_{}_" .format (self .params .n_min_per_vox , self .params .n_max_per_vox )
479+ use_p = "use_p_{}" .format (self .params .use_p )
480+ p_clamp = "_pclamp_{:.3f}" .format (self .params .p_clamp ) if self .params .use_p else ""
481+ dtheta_clamp = "_dthetaclamp_{:.3f}" .format (self .params .dtheta_clamp )
482+ preds_clamp = "_preds_clamp_{:.3f}" .format (self .params .preds_clamp )
483+
484+ bca_name = method + metric + dtheta + dp + n_min_max + use_p + p_clamp + dtheta_clamp + preds_clamp
484485 bca_name = bca_name .replace ("." , "p" )
485486 return bca_name
486487
@@ -492,29 +493,29 @@ def name(self) -> str:
492493 return self .get_bca_name ()
493494
494495 @property
495- def bca_params (self ) -> BCAParams :
496+ def params (self ) -> BCAParams :
496497 r"""
497498 The parameters of the bca algorithm.
498499 """
499- return self ._bca_params
500+ return self ._params
500501
501- @bca_params .setter
502- def bca_params (self , value : BCAParams ) -> None :
502+ @params .setter
503+ def params (self , value : BCAParams ) -> None :
503504 r"""
504505 Sets the parameters of the bca algorithm.
505506 Args:
506507 - Dict containing the parameters name and value. Only parameters with
507508 valid name and non `None` values wil be updated.
508509 """
509510 if not isinstance (value , BCAParams ):
510- raise TypeError ("bca_params must be an instance of BCAParams" )
511+ raise TypeError ("params must be an instance of BCAParams" )
511512
512- if not hasattr (self , "_bca_params " ) or self ._bca_params is None :
513- self ._bca_params = BCAParams ()
513+ if not hasattr (self , "_params " ) or self ._params is None :
514+ self ._params = BCAParams ()
514515
515516 for key , val in value .__dict__ .items ():
516517 if val is not None :
517- setattr (self ._bca_params , key , val )
518+ setattr (self ._params , key , val )
518519
519520 self ._recompute_preds = True
520521
0 commit comments