Skip to content

Commit 7afe692

Browse files
authored
Merge branch 'dev' into dev
2 parents 3777601 + c434607 commit 7afe692

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

monai/networks/layers/filtering.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ def __init__(self, spatial_sigma, color_sigma):
221221
self.len_spatial_sigma = 3
222222
else:
223223
raise ValueError(
224-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
224+
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)"
225+
f"or be a single float value ({spatial_sigma=})."
225226
)
226227

227228
# Register sigmas as trainable parameters.
@@ -231,6 +232,10 @@ def __init__(self, spatial_sigma, color_sigma):
231232
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))
232233

233234
def forward(self, input_tensor):
235+
if len(input_tensor.shape) < 3:
236+
raise ValueError(
237+
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
238+
)
234239
if input_tensor.shape[1] != 1:
235240
raise ValueError(
236241
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
@@ -239,24 +244,27 @@ def forward(self, input_tensor):
239244
)
240245

241246
len_input = len(input_tensor.shape)
247+
spatial_dims = len_input - 2
242248

243249
# C++ extension so far only supports 5-dim inputs.
244-
if len_input == 3:
250+
if spatial_dims == 1:
245251
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
246-
elif len_input == 4:
252+
elif spatial_dims == 2:
247253
input_tensor = input_tensor.unsqueeze(4)
248254

249-
if self.len_spatial_sigma != len_input:
250-
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
255+
if self.len_spatial_sigma != spatial_dims:
256+
raise ValueError(
257+
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
258+
)
251259

252260
prediction = TrainableBilateralFilterFunction.apply(
253261
input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
254262
)
255263

256264
# Make sure to return tensor of the same shape as the input.
257-
if len_input == 3:
265+
if spatial_dims == 1:
258266
prediction = prediction.squeeze(4).squeeze(3)
259-
elif len_input == 4:
267+
elif spatial_dims == 2:
260268
prediction = prediction.squeeze(4)
261269

262270
return prediction
@@ -389,7 +397,8 @@ def __init__(self, spatial_sigma, color_sigma):
389397
self.len_spatial_sigma = 3
390398
else:
391399
raise ValueError(
392-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
400+
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)\n"
401+
f"or be a single float value ({spatial_sigma=})."
393402
)
394403

395404
# Register sigmas as trainable parameters.
@@ -399,39 +408,45 @@ def __init__(self, spatial_sigma, color_sigma):
399408
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))
400409

401410
def forward(self, input_tensor, guidance_tensor):
411+
if len(input_tensor.shape) < 3:
412+
raise ValueError(
413+
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
414+
)
402415
if input_tensor.shape[1] != 1:
403416
raise ValueError(
404-
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
417+
f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. "
405418
"Please use multiple parallel filter layers if you want "
406419
"to filter multiple channels."
407420
)
408421
if input_tensor.shape != guidance_tensor.shape:
409422
raise ValueError(
410-
"Shape of input image must equal shape of guidance image."
411-
f"Got {input_tensor.shape} and {guidance_tensor.shape}."
423+
f"Shape of input image must equal shape of guidance image, got {input_tensor.shape} and {guidance_tensor.shape}."
412424
)
413425

414426
len_input = len(input_tensor.shape)
427+
spatial_dims = len_input - 2
415428

416429
# C++ extension so far only supports 5-dim inputs.
417-
if len_input == 3:
430+
if spatial_dims == 1:
418431
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
419432
guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
420-
elif len_input == 4:
433+
elif spatial_dims == 2:
421434
input_tensor = input_tensor.unsqueeze(4)
422435
guidance_tensor = guidance_tensor.unsqueeze(4)
423436

424-
if self.len_spatial_sigma != len_input:
425-
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
437+
if self.len_spatial_sigma != spatial_dims:
438+
raise ValueError(
439+
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
440+
)
426441

427442
prediction = TrainableJointBilateralFilterFunction.apply(
428443
input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
429444
)
430445

431446
# Make sure to return tensor of the same shape as the input.
432-
if len_input == 3:
447+
if spatial_dims == 1:
433448
prediction = prediction.squeeze(4).squeeze(3)
434-
elif len_input == 4:
449+
elif spatial_dims == 2:
435450
prediction = prediction.squeeze(4)
436451

437452
return prediction

0 commit comments

Comments
 (0)