Skip to content

Commit e8fe984

Browse files
unified algorithms params
1 parent dbd26bc commit e8fe984

File tree

9 files changed

+226
-202
lines changed

9 files changed

+226
-202
lines changed

muograph/reconstruction/asr.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ASR(AbsSave, AbsVoxelInferer):
4242
_n_mu_per_vox: Optional[Tensor] = None # (Nx, Ny, Nz)
4343
_recompute_preds = True
4444

45-
_asr_params: ASRParams = ASRParams()
45+
_params: ASRParams = ASRParams()
4646

4747
_vars_to_save = ["triggered_voxels"]
4848

@@ -345,7 +345,7 @@ def _find_triggered_voxels(
345345

346346
@staticmethod
347347
def get_name_from_params(
348-
asr_params: ASRParams,
348+
params: ASRParams,
349349
) -> str:
350350
r"""
351351
Returns a string representing the ASR configuration based on its parameters.
@@ -361,12 +361,12 @@ def get_partial_name_args(func: partial) -> str:
361361
func_name += "_{}={}".format(arg, values[i])
362362
return func_name
363363

364-
method = "method_{}_".format(get_partial_name_args(asr_params.score_method)) # type: ignore
365-
dtheta = "{:.2f}_{:.2f}_rad_".format(asr_params.dtheta_range[0], asr_params.dtheta_range[1]) # type: ignore
366-
dp = "{:.0f}_{:.0f}_MeV_".format(asr_params.p_range[0], asr_params.p_range[1]) # type: ignore
367-
use_p = "use_p_{}".format(asr_params.use_p)
368-
p_clamp = "_pclamp_{:.3f}".format(asr_params.p_clamp) if asr_params.use_p else ""
369-
dtheta_clamp = "_dthetaclamp_{:.3f}".format(asr_params.dtheta_clamp)
364+
method = "method_{}_".format(get_partial_name_args(params.score_method)) # type: ignore
365+
dtheta = "{:.2f}_{:.2f}_rad_".format(params.dtheta_range[0], params.dtheta_range[1]) # type: ignore
366+
dp = "{:.0f}_{:.0f}_MeV_".format(params.p_range[0], params.p_range[1]) # type: ignore
367+
use_p = "use_p_{}".format(params.use_p)
368+
p_clamp = "_pclamp_{:.3f}".format(params.p_clamp) if params.use_p else ""
369+
dtheta_clamp = "_dthetaclamp_{:.3f}".format(params.dtheta_clamp)
370370

371371
asr_name = method + dtheta + dp + use_p + p_clamp + dtheta_clamp
372372
asr_name = asr_name.replace(".", "p")
@@ -432,11 +432,11 @@ def get_xyz_voxel_pred(self) -> Tensor:
432432
[[[] for _ in range(self.voi.n_vox_xyz[2])] for _ in range(self.voi.n_vox_xyz[1])] for _ in range(self.voi.n_vox_xyz[0])
433433
]
434434

435-
dtheta_max = torch.quantile(self.tracks.dtheta, q=self.asr_params.dtheta_clamp)
435+
dtheta_max = torch.quantile(self.tracks.dtheta, q=self.params.dtheta_clamp)
436436
dtheta = torch.clamp(self.tracks.dtheta, max=dtheta_max)
437437

438-
if self._asr_params.use_p:
439-
p_max = torch.quantile(self.tracks.p, q=self.asr_params.p_clamp)
438+
if self._params.use_p:
439+
p_max = torch.quantile(self.tracks.p, q=self.params.p_clamp)
440440
p = torch.clamp(self.tracks.p, max=p_max)
441441
score = np.log(0.0000001 + dtheta.detach().cpu().numpy() * p.detach().cpu().numpy())
442442
else:
@@ -445,14 +445,14 @@ def get_xyz_voxel_pred(self) -> Tensor:
445445
# score = self.theta_xy_in[0].detach().cpu().numpy()
446446
# score = self.tracks.theta_in.detach().cpu().numpy()
447447

448-
mask_E = (self.tracks.E >= self.asr_params.p_range[0]) & ( # type: ignore
449-
self.tracks.E <= self.asr_params.p_range[1] # type: ignore
448+
mask_E = (self.tracks.E >= self.params.p_range[0]) & ( # type: ignore
449+
self.tracks.E <= self.params.p_range[1] # type: ignore
450450
)
451-
mask_theta = (self.tracks.dtheta >= self.asr_params.dtheta_range[0]) & ( # type: ignore
452-
self.tracks.dtheta <= self.asr_params.dtheta_range[1] # type: ignore
451+
mask_theta = (self.tracks.dtheta >= self.params.dtheta_range[0]) & ( # type: ignore
452+
self.tracks.dtheta <= self.params.dtheta_range[1] # type: ignore
453453
)
454454

455-
if self.asr_params.use_p: # type: ignore
455+
if self.params.use_p: # type: ignore
456456
mask = mask_E & mask_theta
457457
else:
458458
mask = mask_E
@@ -472,14 +472,14 @@ def get_xyz_voxel_pred(self) -> Tensor:
472472
for j in range(self.voi.n_vox_xyz[1]):
473473
for k in range(self.voi.n_vox_xyz[2]):
474474
if score_list[i][j][k] != []:
475-
vox_density_preds[i, j, k] = self.asr_params.score_method(score_list[i][j][k]) # type: ignore
475+
vox_density_preds[i, j, k] = self.params.score_method(score_list[i][j][k]) # type: ignore
476476
self.n_mu_per_vox_test[i, j, k] = len(score_list[i][j][k])
477477
if vox_density_preds.isnan().any():
478478
raise ValueError("Prediction contains NaN values")
479479
self.score_list = score_list
480480
self._recompute_preds = False
481481

482-
if self.asr_params.use_p:
482+
if self.params.use_p:
483483
return torch.exp(vox_density_preds)
484484
else:
485485
return vox_density_preds
@@ -599,29 +599,29 @@ def theta_xy_out(self) -> Tuple[Tensor, Tensor]:
599599
return (self.tracks.theta_xy_out[0], self.tracks.theta_xy_out[1])
600600

601601
@property
602-
def asr_params(self) -> ASRParams:
602+
def params(self) -> ASRParams:
603603
r"""
604604
The parameters of the ASR algorithm.
605605
"""
606-
return self._asr_params
606+
return self._params
607607

608-
@asr_params.setter
609-
def asr_params(self, value: ASRParams) -> None:
608+
@params.setter
609+
def params(self, value: ASRParams) -> None:
610610
r"""
611611
Sets the parameters of the ASR algorithm.
612612
Args:
613613
- Dict containing the parameters name and value. Only parameters with
614614
valid name and non `None` values will be updated.
615615
"""
616616
if not isinstance(value, ASRParams):
617-
raise TypeError("asr_params must be an instance of ASRParams")
617+
raise TypeError("params must be an instance of ASRParams")
618618

619-
if not hasattr(self, "_asr_params") or self._asr_params is None:
620-
self._asr_params = ASRParams()
619+
if not hasattr(self, "_params") or self._params is None:
620+
self._params = ASRParams()
621621

622622
for key, val in value.__dict__.items():
623623
if val is not None:
624-
setattr(self._asr_params, key, val)
624+
setattr(self._params, key, val)
625625

626626
self._recompute_preds = True
627627

@@ -659,4 +659,4 @@ def name(self) -> str:
659659
r"""
660660
The name of the ASR configuration based on its parameters.
661661
"""
662-
return self.get_name_from_params(self.asr_params)
662+
return self.get_name_from_params(self.params)

muograph/reconstruction/binned_clustered.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from muograph.utils.device import DEVICE
1111
from muograph.tracking.tracking import TrackingMST
12-
from muograph.reconstruction.poca import POCA
12+
from muograph.reconstruction.poca import POCA, POCAParams
1313
from muograph.volume.volume import Volume
1414
from muograph.reconstruction.voxel_inferer import AbsVoxelInferer
1515

@@ -24,20 +24,17 @@
2424

2525

2626
@dataclass
27-
class BCAParams:
27+
class BCAParams(POCAParams):
2828
n_max_per_vox: int = 50
2929
n_min_per_vox: int = 3
3030
score_method: partial = partial(torch.quantile, q=0.5)
3131
metric_method: partial = partial(torch.log)
3232
p_range: Tuple[float, float] = (0.0, 10000000)
3333
dtheta_range: Tuple[float, float] = (0.0, math.pi / 3)
34-
use_p: bool = False
35-
p_clamp: float = 0.999
36-
dtheta_clamp: float = 0.999
3734

3835

39-
class BCA(POCA, AbsVoxelInferer):
40-
_bca_params: BCAParams = BCAParams()
36+
class BCA(POCA[BCAParams], AbsVoxelInferer):
37+
_params: BCAParams = BCAParams()
4138

4239
_vars_to_save = ["xyz_voxel_pred", "n_poca_per_vox"]
4340

@@ -386,7 +383,7 @@ def get_xyz_voxel_pred(self) -> Tensor:
386383
387384
Compute voxel-wise scattering density predictions.
388385
389-
Uses parameters stored in `_bca_params`. The algorithm calculates:
386+
Uses parameters stored in `_params`. The algorithm calculates:
390387
- Scattering density predictions (`pred`).
391388
- Number of POCA points used for each voxel's prediction (`hit_per_voxel`).
392389
@@ -405,41 +402,41 @@ def get_xyz_voxel_pred(self) -> Tensor:
405402
self.nhit,
406403
self.nhit_rejected,
407404
) = self.compute_low_theta_events_voxel_wise_mask(
408-
n_max_per_voxel=int(self.bca_params.n_max_per_vox), # type: ignore
405+
n_max_per_voxel=int(self.params.n_max_per_vox), # type: ignore
409406
voi=self.voi,
410407
bca_indices=self.bca_indices,
411408
dtheta=self.bca_tracks.dtheta,
412409
)
413410
self._filter_events(self.mask)
414411

415412
# momentum cut
416-
if self.bca_params.use_p:
417-
p_mask = (self.bca_tracks.E > self.bca_params.p_range[0]) & ( # type: ignore
418-
self.bca_tracks.E < self.bca_params.p_range[1] # type: ignore
413+
if self.params.use_p:
414+
p_mask = (self.bca_tracks.E > self.params.p_range[0]) & ( # type: ignore
415+
self.bca_tracks.E < self.params.p_range[1] # type: ignore
419416
)
420417
else:
421418
p_mask = torch.ones_like(self.bca_tracks.dtheta, dtype=torch.bool, device=DEVICE)
422419

423420
# scattering angle cut
424-
dtheta_mask = (self.bca_tracks.dtheta > self.bca_params.dtheta_range[0]) & ( # type: ignore
425-
self.bca_tracks.dtheta < self.bca_params.dtheta_range[1] # type: ignore
421+
dtheta_mask = (self.bca_tracks.dtheta > self.params.dtheta_range[0]) & ( # type: ignore
422+
self.bca_tracks.dtheta < self.params.dtheta_range[1] # type: ignore
426423
)
427424

428425
# apply dtheta, p cuts
429426
self._filter_events(mask=p_mask & dtheta_mask)
430427

431428
# prepare variables
432-
p_max = torch.quantile(self.bca_tracks.p, q=self.bca_params.p_clamp)
429+
p_max = torch.quantile(self.bca_tracks.p, q=self.params.p_clamp)
433430
p = torch.clamp(self.bca_tracks.p, max=p_max)
434431

435-
dtheta_max = torch.quantile(self.bca_tracks.dtheta, q=self.bca_params.dtheta_clamp)
432+
dtheta_max = torch.quantile(self.bca_tracks.dtheta, q=self.params.dtheta_clamp)
436433
dtheta = torch.clamp(self.bca_tracks.dtheta, max=dtheta_max)
437434

438435
# compute voxels distribution
439436
self.score_list = self.compute_voxels_distribution(
440-
metric_method=self.bca_params.metric_method, # type: ignore
441-
use_p=self.bca_params.use_p, # type: ignore
442-
n_min_per_vox=self.bca_params.n_min_per_vox, # type: ignore
437+
metric_method=self.params.metric_method, # type: ignore
438+
use_p=self.params.use_p, # type: ignore
439+
n_min_per_vox=self.params.n_min_per_vox, # type: ignore
443440
voi=self.voi,
444441
momentum=p,
445442
bca_indices=self.bca_indices,
@@ -448,7 +445,10 @@ def get_xyz_voxel_pred(self) -> Tensor:
448445
)
449446

450447
# compute fina scores
451-
pred, self._hit_per_voxel = self.compute_final_scores(score_list=self.score_list, score_method=self.bca_params.score_method) # type: ignore
448+
pred, self._hit_per_voxel = self.compute_final_scores(score_list=self.score_list, score_method=self.params.score_method) # type: ignore
449+
450+
pred_max = torch.quantile(pred, q=self.params.preds_clamp)
451+
pred = torch.clamp(pred, max=pred_max)
452452

453453
self._recompute_preds = False
454454

@@ -471,16 +471,17 @@ def get_partial_name_args(func: partial) -> str:
471471
func_name += "_{}={}".format(arg, values[i])
472472
return func_name
473473

474-
method = "method_{}_".format(get_partial_name_args(self.bca_params.score_method)) # type: ignore
475-
metric = "metric_{}_".format(get_partial_name_args(self.bca_params.metric_method)) # type: ignore
476-
dtheta = "{:.1f}_{:.1f}_mrad_".format(self.bca_params.dtheta_range[0] * 1000, self.bca_params.dtheta_range[1] * 1000) # type: ignore
477-
dp = "{:.0f}_{:.0f}_MeV_".format(self.bca_params.p_range[0], self.bca_params.p_range[1]) # type: ignore
478-
n_min_max = "n_min_max_{}_{}_".format(self.bca_params.n_min_per_vox, self.bca_params.n_max_per_vox)
479-
use_p = "use_p_{}".format(self.bca_params.use_p)
480-
p_clamp = "_pclamp_{:.3f}".format(self.bca_params.p_clamp) if self.bca_params.use_p else ""
481-
dtheta_clamp = "_dthetaclamp_{:.3f}".format(self.bca_params.dtheta_clamp)
482-
483-
bca_name = method + metric + dtheta + dp + n_min_max + use_p + p_clamp + dtheta_clamp
474+
method = "method_{}_".format(get_partial_name_args(self.params.score_method)) # type: ignore
475+
metric = "metric_{}_".format(get_partial_name_args(self.params.metric_method)) # type: ignore
476+
dtheta = "{:.1f}_{:.1f}_mrad_".format(self.params.dtheta_range[0] * 1000, self.params.dtheta_range[1] * 1000) # type: ignore
477+
dp = "{:.0f}_{:.0f}_MeV_".format(self.params.p_range[0], self.params.p_range[1]) # type: ignore
478+
n_min_max = "n_min_max_{}_{}_".format(self.params.n_min_per_vox, self.params.n_max_per_vox)
479+
use_p = "use_p_{}".format(self.params.use_p)
480+
p_clamp = "_pclamp_{:.3f}".format(self.params.p_clamp) if self.params.use_p else ""
481+
dtheta_clamp = "_dthetaclamp_{:.3f}".format(self.params.dtheta_clamp)
482+
preds_clamp = "_preds_clamp_{:.3f}".format(self.params.preds_clamp)
483+
484+
bca_name = method + metric + dtheta + dp + n_min_max + use_p + p_clamp + dtheta_clamp + preds_clamp
484485
bca_name = bca_name.replace(".", "p")
485486
return bca_name
486487

@@ -492,29 +493,29 @@ def name(self) -> str:
492493
return self.get_bca_name()
493494

494495
@property
495-
def bca_params(self) -> BCAParams:
496+
def params(self) -> BCAParams:
496497
r"""
497498
The parameters of the bca algorithm.
498499
"""
499-
return self._bca_params
500+
return self._params
500501

501-
@bca_params.setter
502-
def bca_params(self, value: BCAParams) -> None:
502+
@params.setter
503+
def params(self, value: BCAParams) -> None:
503504
r"""
504505
Sets the parameters of the bca algorithm.
505506
Args:
506507
- Dict containing the parameters name and value. Only parameters with
507508
valid name and non `None` values wil be updated.
508509
"""
509510
if not isinstance(value, BCAParams):
510-
raise TypeError("bca_params must be an instance of BCAParams")
511+
raise TypeError("params must be an instance of BCAParams")
511512

512-
if not hasattr(self, "_bca_params") or self._bca_params is None:
513-
self._bca_params = BCAParams()
513+
if not hasattr(self, "_params") or self._params is None:
514+
self._params = BCAParams()
514515

515516
for key, val in value.__dict__.items():
516517
if val is not None:
517-
setattr(self._bca_params, key, val)
518+
setattr(self._params, key, val)
518519

519520
self._recompute_preds = True
520521

0 commit comments

Comments
 (0)