Conversation
Will review in a bit.

sayakpaul left a comment
Looking really promising. I left some comments, LMK if they make sense.
Additionally, if we could wrap the loss computations for the different phases into different functions, I think that will be easier to read. LMK what you think.
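One way to read this suggestion (a sketch only; the function and variable names below are illustrative, not from the PR): pull the per-phase loss math out of the training loop so each branch of the loop reduces to a single call.

```python
# Hypothetical helpers for the two training phases (names are assumptions,
# not the PR's actual code). The loop body would then just dispatch on phase.
import torch
import torch.nn.functional as F


def compute_generator_loss(disc, fake_images):
    # Generator phase: push the discriminator to score generated samples as real.
    logits = disc(fake_images)
    return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))


def compute_discriminator_loss(disc, real_images, fake_images):
    # Discriminator phase: real -> 1, fake -> 0; detach fakes so no gradient
    # flows back into the generator during the "D" update.
    real_logits = disc(real_images)
    fake_logits = disc(fake_images.detach())
    real_loss = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    return 0.5 * (real_loss + fake_loss)
```

With helpers like these, the "G" and "D" branches of the loop each become a one-line loss computation followed by the optimizer step.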
@@ -0,0 +1,1823 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
Feel free to add SANA Sprint team here too :)
if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False
complex_human_instruction = [

Suggested change:
- complex_human_instruction = [
+ COMPLEX_HUMAN_INSTRUCTION = [
return False
class Text2ImageDataset:
Do we have an example dataset with which it would work?
)
# add meta-data to dataloader instance for convenience
self._train_dataloader.num_batches = num_batches
self._train_dataloader.num_samples = num_samples
Could use num_train_examples here no?
disc.eval()
models_to_accumulate = [transformer]
with accelerator.accumulate(models_to_accumulate):
    with torch.no_grad():
We can then remove this context manager.
images = None
del pipeline

# Save the lora layers
We are not doing LoRA. So, this can be safely omitted.
cfg_y = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
cfg_y_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

cfg_pretrain_pred = pretrained_model(
As another optimization, we could keep the pretrained_model in CPU once this computation is done and load to GPU again when needed.
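A minimal sketch of that offload pattern (assumptions: `pretrained_model` is the frozen teacher from the script and `device` is the accelerator device; the helper name is hypothetical):

```python
# Keep the frozen teacher on CPU between uses; move it to the accelerator
# only for its forward pass, then park it back on CPU to free GPU memory.
import torch


def run_on_gpu_then_offload(model, inputs, device):
    model.to(device)
    with torch.no_grad():  # the pretrained teacher is frozen, no grads needed
        out = model(inputs)
    model.to("cpu")
    return out
```

The trade-off is an extra host-to-device transfer per call, which is usually acceptable when the teacher is only queried once per step and GPU memory is the bottleneck.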
phase = "G"

optimizer_D.step()
optimizer_D.zero_grad(set_to_none=True)
I think set_to_none is by default True.
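This can be verified directly; since PyTorch 2.0, `Optimizer.zero_grad()` defaults to `set_to_none=True`, so the explicit keyword is redundant (though harmless as documentation):

```python
# Demonstrates the default zero_grad() behavior in PyTorch >= 2.0:
# gradients are set to None rather than zeroed in place.
import torch

param = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([param], lr=0.1)

param.sum().backward()
assert param.grad is not None

opt.zero_grad()  # no keyword needed; defaults to set_to_none=True
assert param.grad is None
```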
lr_scheduler.step()
optimizer_G.zero_grad(set_to_none=True)

elif phase == "D":
So this alternates between two phases in the same training step, right? If so, I would add a comment.
Also, should we let the users control the step interval in which the discriminator should be updated? Or not really?
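If the interval were made configurable, the schedule could look something like this (a sketch; `disc_update_interval` is a hypothetical CLI argument, not something in the PR):

```python
# Hypothetical phase schedule: the generator updates every step, the
# discriminator every `disc_update_interval` steps. With interval 1 this
# reproduces the current behavior of alternating G and D within each step.
def phase_schedule(num_steps, disc_update_interval=1):
    phases = []
    for step in range(num_steps):
        phases.append("G")
        if (step + 1) % disc_update_interval == 0:
            phases.append("D")
    return phases
```

For example, `phase_schedule(4, 2)` yields `["G", "G", "D", "G", "G", "D"]`, i.e. the discriminator is touched only every other step.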
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Thanks for your thorough review and helpful suggestions! I'll carefully go through them and incorporate the changes when I'm back. Really appreciate it!

Please don't hesitate to ping me for running tests, etc.

Adding here:

Let's duplicate the files in diffusers here. @scxue
Initial implementation of SANA-Sprint training script adapted for Diffusers.
This needs further refinement and optimization. @lawrence-cj @sayakpaul