# Add Support for Nemotron diffusion #3105
@@ -0,0 +1,135 @@
# Nemotron Diffusion

This directory contains recipes for training and running Nemotron Diffusion language models (dLLMs) based on Ministral-3 (3B, 8B, and 14B). The full workflow is:

0. **Bridge (Checkpoint Conversion)** — convert a HuggingFace Ministral-3 checkpoint to Megatron-Bridge format, and export trained checkpoints back to HuggingFace.
1. **Continuous Pretraining (CPT)** — run standard autoregressive pretraining of the base Ministral-3 model on additional data.
2. **AR-to-DLM** — convert the CPT checkpoint into a diffusion language model using the block diffusion paradigm.
3. **Inference** — run text generation from a trained checkpoint.

---
## Stage 1: Continuous Pretraining (CPT)

CPT fine-tunes a pretrained Ministral-3 model on new data using the standard autoregressive cross-entropy loss. This stage adapts the model to the target domain before diffusion training.

**Example script:**
```bash
torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/continuous_pretraining.py \
    --model-size 3b \
    --hf-path mistralai/Ministral-3-3B-Base-2512 \
    --data-paths /path/to/dclm/merged_tokenized_text_document \
    --config-file examples/diffusion/recipes/nemotron_diffusion/conf/cpt_3b.yaml
```

Available config files: [`conf/cpt_3b.yaml`](conf/cpt_3b.yaml), [`conf/cpt_8b.yaml`](conf/cpt_8b.yaml), [`conf/cpt_14b.yaml`](conf/cpt_14b.yaml).

---
## Stage 2: AR-to-DLM

This stage converts the CPT checkpoint into a diffusion LM. It replaces the standard attention with `NemotronDiffusionAttention` and trains with a combined diffusion + AR loss.

**Key recipe:** `examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py`

The model is built via `NemotronDiffusionModelProvider`, which extends `Ministral3ModelProvider` with:

- `dlm_paradigm = "sbd_block_diff"` — semi-block diffusion with block masking
- `block_size = 64` — number of tokens per diffusion block
- `mask_token_id = 100` — token ID used for masking during diffusion
- `dlm_loss_weight = 0.3`, `ar_loss_weight = 1.0` — loss weighting between the diffusion and AR objectives
- `NemotronDiffusionAttention` — replaces core attention to support block-causal masking
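The block-causal masking that `NemotronDiffusionAttention` supports can be illustrated with a small, hypothetical sketch (not the actual implementation): each token attends fully within its own block of `block_size` tokens and causally to all earlier blocks.

```python
def block_causal_mask(seq_len: int, block_size: int) -> list[list[bool]]:
    """Illustrative block-causal mask: token i may attend to token j iff
    j's block index is <= i's block index (full attention inside a block,
    causal attention across blocks). Hypothetical sketch only; see
    NemotronDiffusionAttention for the real implementation."""
    return [
        [(j // block_size) <= (i // block_size) for j in range(seq_len)]
        for i in range(seq_len)
    ]

mask = block_causal_mask(seq_len=8, block_size=4)
# Tokens in the first block (0-3) cannot see the second block (4-7):
assert mask[0][5] is False
# Tokens in the second block see everything up to and including their own block:
assert mask[5][2] is True and mask[5][7] is True
```

With the default weights above, the combined training objective works out to `1.0 * L_ar + 0.3 * L_dlm`.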
The CPT checkpoint from Stage 1 is passed via `checkpoint.pretrained_checkpoint`. Setting `checkpoint.finetune=true` skips loading the optimizer state from the CPT stage.

**Example launch:**
```bash
torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
    --model-size 3b \
    --hf-path mistralai/Ministral-3-3B-Base-2512 \
    --data-paths /path/to/dclm/merged_tokenized_text_document \
    --config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_3b_dlm.yaml \
    checkpoint.finetune=true \
    checkpoint.pretrained_checkpoint=/path/to/cpt_checkpoint
```

Available config files: [`conf/ar_to_dlm_3b_dlm.yaml`](conf/ar_to_dlm_3b_dlm.yaml), [`conf/ar_to_dlm_8b_dlm.yaml`](conf/ar_to_dlm_8b_dlm.yaml).
---

## Inference

The script [`inference_nemotron.py`](inference_nemotron.py) runs text generation from a trained Megatron-format NemotronDiffusion checkpoint. Both dLLM (block diffusion) and AR modes are supported.
### dLLM mode (default)

```bash
torchrun --nproc_per_node=4 examples/diffusion/recipes/nemotron_diffusion/inference_nemotron.py \
    --megatron-path /path/to/checkpoints/ar_to_dlm_8b \
    --hf-model mistralai/Ministral-3-8B-Base-2512 \
    --prompts "The capital of France is" \
    --gen-length 256 --block-length 32 --diffusion-steps 256 \
    --tp 4
```
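Assuming, for illustration, that the denoising steps are split evenly across generation blocks (check `inference_nemotron.py` for the actual scheduling), the flags above imply:

```python
gen_length = 256       # --gen-length: total tokens to generate
block_length = 32      # --block-length: tokens per diffusion block
diffusion_steps = 256  # --diffusion-steps: total denoising steps

num_blocks = gen_length // block_length          # generation proceeds block by block
steps_per_block = diffusion_steps // num_blocks  # denoising steps spent per block
print(num_blocks, steps_per_block)  # prints: 8 32
```

Fewer steps per block trade generation quality for speed; this even-split accounting is an assumption for illustration only.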
### AR mode

```bash
python examples/diffusion/recipes/nemotron_diffusion/inference_nemotron.py \
    --megatron-path /path/to/checkpoints/ar_to_dlm_3b \
    --hf-model mistralai/Ministral-3-3B-Base-2512 \
    --mode ar \
    --prompts "Once upon a time" \
    --max-new-tokens 128
```

The `--tp` argument must match the tensor parallelism degree of the saved checkpoint (e.g. `--tp 4` for 8B checkpoints saved with TP=4). `--hf-model` is used for the tokenizer and model config only — weights are loaded from `--megatron-path`.
---

## Checkpoint Conversion (Bridge)

The `NemotronDiffusionBridge` converts between the HuggingFace `Mistral3ForConditionalGeneration` format and the Megatron-Bridge distributed checkpoint format. It handles:

- **Language model weights** — mapped between HF (`language_model.model.*`) and Megatron (`language_model.decoder.*`) with proper QKV merging and tensor-parallel sharding.
- **Vision encoder weights** (`vision_tower.**`) — replicated across tensor-parallel ranks (no sharding needed).
- **Multimodal projector weights** (`multi_modal_projector.**`) — replicated similarly.
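As a hypothetical sketch of the prefix mapping described above (the real bridge additionally merges Q/K/V projections and shards tensors across TP ranks, which is omitted here):

```python
def remap_hf_key(hf_key: str) -> str:
    """Hypothetical sketch of the bridge's key renaming; QKV merging and
    tensor-parallel sharding are omitted."""
    hf_prefix = "language_model.model."
    megatron_prefix = "language_model.decoder."
    if hf_key.startswith(hf_prefix):
        return megatron_prefix + hf_key[len(hf_prefix):]
    # vision_tower.* and multi_modal_projector.* keys keep their names:
    # those weights are replicated across TP ranks rather than sharded.
    return hf_key

print(remap_hf_key("language_model.model.layers.0.self_attn.o_proj.weight"))
# -> language_model.decoder.layers.0.self_attn.o_proj.weight
```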
The conversion script is [`convert_checkpoints.py`](convert_checkpoints.py).
### Import: HuggingFace → Megatron

```bash
python examples/diffusion/recipes/nemotron_diffusion/convert_checkpoints.py import \
    --hf-model mistralai/Ministral-3-3B-Base-2512 \
    --megatron-path /path/to/checkpoints/hf_to_mb_3b \
    --torch-dtype bfloat16
```

The Megatron checkpoint is written under `--megatron-path` (e.g. `.../hf_to_mb_3b/iter_0000000/`). Use the parent directory for CPT training with `checkpoint.load`.
For the 8B model (TP=4):
```bash
python examples/diffusion/recipes/nemotron_diffusion/convert_checkpoints.py import \
    --hf-model mistralai/Ministral-3-8B-Base-2512 \
    --megatron-path /path/to/checkpoints/hf_to_mb_8b \
    --torch-dtype bfloat16
```
### Export: Megatron → HuggingFace

Export a trained Megatron checkpoint back to HuggingFace format. A reference HF model is required to provide config and tokenizer artifacts:

```bash
python examples/diffusion/recipes/nemotron_diffusion/convert_checkpoints.py export \
    --hf-model mistralai/Ministral-3-3B-Base-2512 \
    --megatron-path /path/to/checkpoints/ar_to_dlm_3b \
    --hf-path /path/to/checkpoints/mb_to_hf_3b
```

The `--hf-model` argument is used as the reference for config, tokenizer, and any non-LM artifacts. The exported directory contains a self-contained HuggingFace model.

**Note:** If the reference HF model does not include vision tower weights (e.g. an LM-only checkpoint), warnings of the form `Can't find vision_tower.* in hf_keys` are expected and benign — the LM weights are still exported correctly.

---
@@ -0,0 +1,175 @@
#!/usr/bin/env python3
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NemotronDiffusion diffusion LM pretraining.

Uses the sbd_block_diff diffusion paradigm via DGPTStep.
Use --hf-path to override the HuggingFace model ID or local model path.

Examples:
    3B model, first job from AR checkpoint (finetune=true skips optimizer state):
    $ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
        --hf-path mistralai/Ministral-3-3B-Base-2512 \
        --config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_3b_dlm.yaml \
        --data-paths /path/to/dclm/merged_tokenized_text_document \
        checkpoint.pretrained_checkpoint=/path/to/hf_to_mb_3b \
        checkpoint.finetune=true

    3B model, subsequent jobs (resume from DLM checkpoint):
    $ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
        --hf-path mistralai/Ministral-3-3B-Base-2512 \
        --config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_3b_dlm.yaml \
        --data-paths /path/to/dclm/merged_tokenized_text_document

    8B model with TP=4:
    $ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
        --hf-path mistralai/Ministral-3-8B-Base-2512 \
        --config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_8b_dlm.yaml \
        --data-paths /path/to/dclm/merged_tokenized_text_document \
        checkpoint.pretrained_checkpoint=/path/to/hf_to_mb_8b \
        checkpoint.finetune=true

    14B model with TP=8:
    $ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
        --hf-path mistralai/Ministral-3-14B-Base-2512 \
        --config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_14b_dlm.yaml \
        --data-paths /path/to/dclm/merged_tokenized_text_document \
        checkpoint.pretrained_checkpoint=/path/to/hf_to_mb_14b \
        checkpoint.finetune=true
"""
> **Contributor:** add commands for example usage here, similar to other examples
>
> **Author:** done
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Tuple

import torch
from omegaconf import OmegaConf

# Register NemotronDiffusionBridge, overriding the base Ministral3Bridge so that
# AutoBridge returns NemotronDiffusionModelProvider (with NemotronDiffusionAttention).
import megatron.bridge.diffusion.conversion.nemotron_diffusion.nemotron_diffusion_bridge  # noqa: F401
from megatron.bridge.diffusion.models.common.dgpt_step import DGPTStep
from megatron.bridge.diffusion.recipes.nemotron_diffusion.ar_to_dlm import (
    nemotron_diffusion3_pretrain_config as pretrain_config,
)
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.utils.omegaconf_utils import (
    apply_overrides,
    create_omegaconf_dict_config,
    parse_hydra_overrides,
)
from megatron.bridge.utils.common_utils import get_rank_safe


logger: logging.Logger = logging.getLogger(__name__)

SCRIPT_DIR: Path = Path(__file__).parent.parent.resolve()
DEFAULT_CONFIG_FILENAME: str = "train_local.yaml"
DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "override_configs" / DEFAULT_CONFIG_FILENAME


def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
    """Parse command-line arguments for the AR-to-DLM conversion script."""
    parser = argparse.ArgumentParser(
        description="NemotronDiffusion diffusion LM pretraining (no distillation)",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--hf-path",
        type=str,
        default=None,
        help="HuggingFace model ID or local path to model weights. Overrides the default for the selected model size.",
    )
    parser.add_argument(
        "--config-file",
        type=str,
        default=str(DEFAULT_CONFIG_FILE_PATH),
        help="Path to YAML override file.",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    parser.add_argument(
        "--data-paths",
        type=str,
        nargs="*",
        default=None,
        help="List of dataset file paths (space or comma-separated).",
    )
    parser.add_argument(
        "--data-args-path",
        type=str,
        default=None,
        help="Path to file containing data arguments.",
    )

    args, cli_dotlist_overrides = parser.parse_known_args()

    if args.data_paths:
        flattened_paths = []
        for path in args.data_paths:
            if "," in path:
                flattened_paths.extend(path.split(","))
            else:
                flattened_paths.append(path)
        args.data_paths = [p.strip() for p in flattened_paths if p.strip()]

    return args, cli_dotlist_overrides


def main() -> None:
    """Entry point for AR-to-DLM conversion and continued pretraining."""
    args, cli_overrides = parse_cli_args()
    cfg: ConfigContainer = pretrain_config(
        data_paths=args.data_paths,
        data_args_path=args.data_args_path,
        hf_path=args.hf_path,
    )

    if get_rank_safe() == 0:
        cfg.print_yaml()

    merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg)

    if args.config_file:
        if not os.path.exists(args.config_file):
            logger.error(f"Override YAML file not found: {args.config_file}")
            sys.exit(1)
        yaml_overrides_omega = OmegaConf.load(args.config_file)
        merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega)

    if cli_overrides:
        merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides)

    final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True)
    apply_overrides(cfg, final_overrides_as_dict, excluded_fields)

    if get_rank_safe() == 0:
        logger.info("--- Final Merged Configuration ---")
        cfg.print_yaml()
        logger.info("----------------------------------")

    pretrain(config=cfg, forward_step_func=DGPTStep())

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
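The `--data-paths` handling in `parse_cli_args` accepts both space- and comma-separated lists. Its flattening logic, extracted standalone for illustration:

```python
def flatten_data_paths(paths: list[str]) -> list[str]:
    """Mirror of the flattening logic in parse_cli_args: split any
    comma-separated entries, then strip whitespace and drop empties."""
    flattened = []
    for path in paths:
        if "," in path:
            flattened.extend(path.split(","))
        else:
            flattened.append(path)
    return [p.strip() for p in flattened if p.strip()]

assert flatten_data_paths(["/a,/b", "/c"]) == ["/a", "/b", "/c"]
assert flatten_data_paths(["/a, ", ""]) == ["/a"]
```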
> **Contributor:** we generally don't want to have these yaml config files in the repo. Can we include all the hyperparameters in the recipe function? Then, if anything else is needed, use a hydra override in the example command.
>
> **Author:** sure, makes sense. I'll remove them. Btw, other diffusion models like wan and flux seem to have their yaml files in the repo: https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/diffusion/recipes/wan/conf
@@ -0,0 +1,75 @@
# NemotronDiffusion 14B diffusion LM pretraining config

cluster_config:
  hostname: nrt
  account_name: ${oc.env:USER}
  job_name: nemotron_diffusion_14b_sbd64_dlm
  user_path: /lustre/fsw/portfolios/coreai/users
  exp_dir: ${cluster_config.user_path}/${cluster_config.account_name}/megatron_exp/${cluster_config.job_name}

model:
  seq_length: 4096
  block_size: 64
  max_position_embeddings: ${model.seq_length}
  tensor_model_parallel_size: 8
  sequence_parallelism: false
  pipeline_model_parallel_size: 1
  cross_entropy_loss_fusion: true
  cross_entropy_fusion_impl: te

train:
  micro_batch_size: 1
  global_batch_size: 512
  train_iters: 12500
  eval_iters: 10
  eval_interval: 1000
  exit_duration_in_mins: 220

optimizer:
  use_distributed_optimizer: true
  lr: 1e-5
  min_lr: 1e-6
  weight_decay: 0.1
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1e-8
  clip_grad: 1.0

scheduler:
  lr_decay_style: WSD
  lr_warmup_fraction: 0.01
  lr_wsd_decay_iters: 2500
  lr_warmup_iters: 0
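The `WSD` (warmup-stable-decay) style holds the learning rate flat after warmup and decays to `min_lr` over the final `lr_wsd_decay_iters` iterations. A simplified sketch assuming linear decay and the values above (the exact decay shape and warmup handling in Megatron may differ):

```python
def wsd_lr(step, lr=1e-5, min_lr=1e-6, total_iters=12500,
           warmup_iters=0, decay_iters=2500):
    """Simplified warmup-stable-decay schedule; illustrative only."""
    if step < warmup_iters:                       # linear warmup
        return lr * (step + 1) / warmup_iters
    decay_start = total_iters - decay_iters
    if step < decay_start:                        # stable phase at peak lr
        return lr
    frac = (step - decay_start) / decay_iters     # linear decay to min_lr
    return lr - (lr - min_lr) * frac

print(wsd_lr(0), wsd_lr(9999), wsd_lr(12500))
```

With these values the rate stays at 1e-5 through iteration 9999, then decays toward 1e-6 over the last 2500 iterations.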
checkpoint:
  save: ${cluster_config.exp_dir}
  load: ${cluster_config.exp_dir}
  save_interval: 1000
  auto_detect_ckpt_format: true

dist:
  distributed_timeout_minutes: 240

ddp:
  use_distributed_optimizer: true

mixed_precision: bf16_mixed

logger:
  log_interval: 10
  tensorboard_dir: ${cluster_config.user_path}/${cluster_config.account_name}/megatron_exp/tensorboard
  wandb_project: megatron
  wandb_exp_name: ${cluster_config.job_name}
  wandb_save_dir: ${cluster_config.exp_dir}/wandb

tokenizer:
  tokenizer_type: HuggingFaceTokenizer
  tokenizer_model: mistralai/Ministral-3-14B-Base-2512

dataset:
  sequence_length: ${model.seq_length}
  split: "950,50,0"
  path_to_cache: ${cluster_config.user_path}/${cluster_config.account_name}/megatron_exp/data_cache
  mmap_bin_files: false
  dataloader_type: cyclic
  num_workers: 10