@@ -87,7 +87,7 @@ def fetch_to_hf_path_cmd(
8787
8888def download_model_data (generate_version , root_dir , model_only = False ):
8989 # TODO: remove the `files` after the files are uploaded to the NGC
90- if generate_version == "ddpm-ct" or generate_version == "rflow-ct" :
90+ if generate_version == 'maisi3d-ddpm' or generate_version == 'maisi3d-rflow' :
9191 files = [
9292 {
9393 "path" : "models/autoencoder_v1.pt" ,
@@ -118,24 +118,11 @@ def download_model_data(generate_version, root_dir, model_only=False):
118118 "filename" : "datasets/all_masks_flexible_size_and_spacing_4000.zip" ,
119119 },
120120 ]
121- elif generate_version == "rflow-mr" :
122- files = [
123- {
124- "path" : "models/autoencoder_v2.pt" ,
125- "repo_id" : "nvidia/NV-Generate-MR" ,
126- "filename" : "models/autoencoder_v2.pt" ,
127- },
128- {
129- "path" : "models/diff_unet_3d_rflow-mr.pt" ,
130- "repo_id" : "nvidia/NV-Generate-MR" ,
131- "filename" : "models/diff_unet_3d_rflow-mr.pt" ,
132- },
133- ]
134121 else :
135122 raise ValueError (
136- f"generate_version has to be chosen from ['ddpm-ct ', 'rflow-ct', ' rflow-mr '], yet got { generate_version } ."
123+ f"generate_version has to be chosen from ['maisi3d-ddpm ', 'maisi3d- rflow'], yet got { generate_version } ."
137124 )
138- if generate_version == "ddpm-ct" :
125+ if generate_version == 'maisi3d-ddpm' :
139126 files += [
140127 {
141128 "path" : "models/diff_unet_3d_ddpm-ct.pt" ,
@@ -156,7 +143,7 @@ def download_model_data(generate_version, root_dir, model_only=False):
156143 "filename" : "datasets/candidate_masks_flexible_size_and_spacing_3000.json" ,
157144 },
158145 ]
159- elif generate_version == "rflow-ct" :
146+ elif generate_version == 'maisi3d-rflow' :
160147 files += [
161148 {
162149 "path" : "models/diff_unet_3d_rflow-ct.pt" ,
@@ -193,7 +180,7 @@ def download_model_data(generate_version, root_dir, model_only=False):
193180 parser .add_argument (
194181 "--version" ,
195182 type = str ,
196- default = "rflow-ct" ,
183+ default = 'maisi3d-rflow' ,
197184 )
198185 parser .add_argument (
199186 "--root_dir" ,
0 commit comments