Skip to content

Commit e424767

Browse files
committed
Improve segmentation visualization, typing and docstrings
1 parent a2d3fff commit e424767

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

BraTS/utils.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
from pathlib import Path
2+
from typing import Union
3+
24
import matplotlib.pyplot as plt
35
import nibabel as nib
6+
import numpy as np
47

5-
DATA_FOLDER = "data"
8+
DATA_FOLDER = Path("data")
69

710

811
def visualize_segmentation_data(
9-
data_folder: str = DATA_FOLDER,
12+
data_folder: Union[str, Path] = DATA_FOLDER,
1013
subject_id: str = "BraTS-GLI-00001-000",
1114
slice_index: int = 75,
1215
):
1316
"""Visualize the MRI modalities for a given slice index
1417
1518
Args:
16-
data_folder (str, optional): Path to the folder containing the t1, t1c, t2 & flair file. Defaults to DATA_FOLDER.
19+
data_folder (Union[str, Path]): Path to the folder containing the t1, t1c, t2 & flair file. Defaults to DATA_FOLDER.
1720
slice_index (int, optional): Slice to be visualized (first index in data of shape (155, 240, 240)). Defaults to 75.
1821
"""
1922
_, axes = plt.subplots(1, 4, figsize=(12, 10))
@@ -29,14 +32,14 @@ def visualize_segmentation_data(
2932

3033

3134
def visualize_inpainting_data(
32-
data_folder: str = DATA_FOLDER,
35+
data_folder: Union[str, Path] = DATA_FOLDER,
3336
subject_id: str = "BraTS-GLI-00001-000",
3437
slice_index: int = 75,
3538
):
3639
"""Visualize the MRI modalities for a given slice index
3740
3841
Args:
39-
data_folder (str, optional): Path to the folder containing the t1n and mask files. Defaults to DATA_FOLDER.
42+
data_folder (Union[str, Path]): Path to the folder containing the t1n and mask files. Defaults to DATA_FOLDER.
4043
slice_index (int, optional): Slice to be visualized (first index in data of shape (155, 240, 240)). Defaults to 75.
4144
"""
4245
_, axes = plt.subplots(1, 2, figsize=(6, 10))
@@ -51,32 +54,39 @@ def visualize_inpainting_data(
5154
axes[i].axis("off")
5255

5356

54-
def visualize_segmentation(modality_file: str, segmentation_file: str):
57+
def visualize_segmentation(
58+
modality_file: Union[str, Path], segmentation_file: Union[str, Path]
59+
):
5560
"""Visualize the MRI modality and the segmentation
5661
5762
Args:
58-
modality_file (str): Path to the desired modality file
59-
segmentation_file (str): Path to the segmentation file
63+
modality_file (Union[str, Path]): Path to the desired modality file
64+
segmentation_file (Union[str, Path]): Path to the segmentation file
6065
"""
6166
modality_np = nib.load(modality_file).get_fdata().transpose(2, 1, 0)
6267
seg_np = nib.load(segmentation_file).get_fdata().transpose(2, 1, 0)
68+
6369
_, ax = plt.subplots(1, 2, figsize=(8, 4))
6470

6571
slice_index = modality_np.shape[0] // 2 # You can choose any slice here
72+
73+
# Mask out background (0) in the segmentation
74+
seg_slice = seg_np[slice_index, :, :]
6675
ax[0].imshow(modality_np[slice_index, :, :], cmap="gray")
6776
ax[1].imshow(modality_np[slice_index, :, :], cmap="gray")
68-
ax[1].imshow(seg_np[slice_index, :, :], cmap="plasma", alpha=0.3)
77+
ax[1].imshow(seg_slice, cmap="plasma", alpha=np.where(seg_slice > 0, 0.3, 0))
78+
6979
for ax in ax:
7080
ax.axis("off")
7181
plt.tight_layout()
7282

7383

74-
def visualize_inpainting(t1n_voided: str, prediction: str):
84+
def visualize_inpainting(t1n_voided: Union[str, Path], prediction: Union[str, Path]):
7585
"""Visualize the inpainting results
7686
7787
Args:
78-
t1n_voided (str): Voided T1 modality file
79-
prediction (str): Inpainting prediction file
88+
t1n_voided (Union[str, Path]): Voided T1 modality file
89+
prediction (Union[str, Path]): Inpainting prediction file
8090
"""
8191
voided_np = nib.load(t1n_voided).get_fdata().transpose(2, 1, 0)
8292
inpainting_np = nib.load(prediction).get_fdata().transpose(2, 1, 0)
@@ -91,15 +101,17 @@ def visualize_inpainting(t1n_voided: str, prediction: str):
91101

92102

93103
def visualize_missing_mri_t2w(
94-
synthesized_t2w: str,
95-
data_folder: str = DATA_FOLDER,
104+
synthesized_t2w: Union[str, Path],
105+
data_folder: Union[str, Path] = DATA_FOLDER,
96106
subject_id: str = "BraTS-GLI-00001-000",
97107
slice_index: int = 75,
98108
):
99109
"""Visualize the MRI modalities for a given slice index
100110
101111
Args:
102-
data_folder (str, optional): Path to the folder containing the t1, t1c, t2 & flair file. Defaults to DATA_FOLDER.
112+
synthesized_t2w (Union[str, Path]): Path to the synthesized T2w file
113+
data_folder (Union[str, Path], optional): Path to the folder containing the t1, t1c, t2 & flair file. Defaults to DATA_FOLDER.
114+
subject_id (str, optional): Subject ID to visualize. Defaults to "BraTS-GLI-00001-000".
103115
slice_index (int, optional): Slice to be visualized (first index in data of shape (155, 240, 240)). Defaults to 75.
104116
"""
105117
_, axes = plt.subplots(1, 5, figsize=(12, 10))

0 commit comments

Comments
 (0)