Skip to content

Commit 303f541

Browse files
author
Mohit Soni
committed
Comments Addressed
Signed-off-by: Mohit Soni <mohisoni@qti.qualcomm.com>
1 parent 0856180 commit 303f541

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -302,10 +302,13 @@ def get_video_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
302302
- output_names (List[str]): Names of model outputs
303303
"""
304304
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
305+
num_frames = constants.WAN_ONNX_EXPORT_LATENT_FRAMES
306+
latent_height = constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P
307+
latent_width = constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P
305308

306309
# VAE decoder takes latent representation as input
307310
example_inputs = {
308-
"latent_sample": torch.randn(bs, 16, 21, 12, 16),
311+
"latent_sample": torch.randn(bs, 16, num_frames, latent_height, latent_width),
309312
"return_dict": False,
310313
}
311314

@@ -339,6 +342,8 @@ def export(
339342
Returns:
340343
str: Path to the exported ONNX model
341344
"""
345+
self.model.config["_use_default_values"].sort()
346+
342347
return self._export(
343348
example_inputs=inputs,
344349
output_names=output_names,

QEfficient/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
import torch
2323
from diffusers import WanPipeline
24+
from tqdm import tqdm
2425

2526
from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE, QEffWanUnifiedTransformer
2627
from QEfficient.diffusers.pipelines.pipeline_utils import (
@@ -121,7 +122,6 @@ def __init__(self, model, **kwargs):
121122
)
122123

123124
self.vae_decoder.get_onnx_params = self.vae_decoder.get_video_onnx_params
124-
self.vae_decoder.model.config["_use_default_values"].sort()
125125
# Extract patch dimensions from transformer configuration
126126
_, self.patch_height, self.patch_width = self.transformer.model.config.patch_size
127127

@@ -227,7 +227,7 @@ def export(
227227
"""
228228

229229
# Export each module with video-specific parameters
230-
for module_name, module_obj in self.modules.items():
230+
for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
231231
# Get ONNX export configuration with video dimensions
232232
example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
233233

@@ -308,6 +308,7 @@ def compile(
308308
path is None
309309
for path in [
310310
self.transformer.onnx_path,
311+
self.vae_decoder.onnx_path,
311312
]
312313
):
313314
self.export(use_onnx_subfunctions=use_onnx_subfunctions)
@@ -343,13 +344,11 @@ def compile(
343344
"num_frames": latent_frames, # Latent frames
344345
},
345346
],
346-
"vae_decoder": [
347-
{
348-
"num_frames": latent_frames,
349-
"latent_height": latent_height,
350-
"latent_width": latent_width,
351-
}
352-
],
347+
"vae_decoder": {
348+
"num_frames": latent_frames,
349+
"latent_height": latent_height,
350+
"latent_width": latent_width,
351+
},
353352
}
354353

355354
# Use generic utility functions for compilation

examples/diffusers/wan/wan_config.json

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,12 +35,12 @@
3535
},
3636
"vae_decoder":
3737
{
38-
"specializations": [
38+
"specializations":
3939
{
4040
"batch_size": 1,
4141
"num_channels": 16
4242
}
43-
],
43+
,
4444
"compilation":
4545
{
4646
"onnx_path": null,

0 commit comments

Comments
 (0)