Skip to content

Commit f502cc5

Browse files
committed
update version
Signed-off-by: Can-Zhao <[email protected]>
1 parent 0b101ab commit f502cc5

File tree

1 file changed

+5
-18
lines changed

1 file changed

+5
-18
lines changed

generation/maisi/scripts/download_model_data.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def fetch_to_hf_path_cmd(
8787

8888
def download_model_data(generate_version, root_dir, model_only=False):
8989
# TODO: remove the `files` after the files are uploaded to the NGC
90-
if generate_version == "ddpm-ct" or generate_version == "rflow-ct":
90+
if generate_version == 'maisi3d-ddpm' or generate_version == 'maisi3d-rflow':
9191
files = [
9292
{
9393
"path": "models/autoencoder_v1.pt",
@@ -118,24 +118,11 @@ def download_model_data(generate_version, root_dir, model_only=False):
118118
"filename": "datasets/all_masks_flexible_size_and_spacing_4000.zip",
119119
},
120120
]
121-
elif generate_version == "rflow-mr":
122-
files = [
123-
{
124-
"path": "models/autoencoder_v2.pt",
125-
"repo_id": "nvidia/NV-Generate-MR",
126-
"filename": "models/autoencoder_v2.pt",
127-
},
128-
{
129-
"path": "models/diff_unet_3d_rflow-mr.pt",
130-
"repo_id": "nvidia/NV-Generate-MR",
131-
"filename": "models/diff_unet_3d_rflow-mr.pt",
132-
},
133-
]
134121
else:
135122
raise ValueError(
136-
f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}."
123+
f"generate_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {generate_version}."
137124
)
138-
if generate_version == "ddpm-ct":
125+
if generate_version == 'maisi3d-ddpm':
139126
files += [
140127
{
141128
"path": "models/diff_unet_3d_ddpm-ct.pt",
@@ -156,7 +143,7 @@ def download_model_data(generate_version, root_dir, model_only=False):
156143
"filename": "datasets/candidate_masks_flexible_size_and_spacing_3000.json",
157144
},
158145
]
159-
elif generate_version == "rflow-ct":
146+
elif generate_version == 'maisi3d-rflow':
160147
files += [
161148
{
162149
"path": "models/diff_unet_3d_rflow-ct.pt",
@@ -193,7 +180,7 @@ def download_model_data(generate_version, root_dir, model_only=False):
193180
parser.add_argument(
194181
"--version",
195182
type=str,
196-
default="rflow-ct",
183+
default='maisi3d-rflow',
197184
)
198185
parser.add_argument(
199186
"--root_dir",

0 commit comments

Comments
 (0)