**File:** examples/diffusion/recipes/nemotron_diffusion/README.md (new, +135 lines)
# Nemotron Diffusion

This directory contains recipes for training and running Nemotron Diffusion language models (dLLMs) based on Ministral-3 (3B, 8B, 14B). The full workflow is:

0. **Bridge (Checkpoint Conversion)** — convert a HuggingFace Ministral-3 checkpoint to Megatron-Bridge format, and export trained checkpoints back to HuggingFace.
1. **Continuous Pretraining (CPT)** — adapt the base Ministral-3 model with standard autoregressive pretraining on additional data.
2. **AR-to-DLM** — convert the CPT checkpoint into a diffusion language model using the block diffusion paradigm.
3. **Inference** — run text generation from a trained checkpoint.

---


## Stage 1: Continuous Pretraining (CPT)

CPT fine-tunes a pretrained Ministral-3 model on new data using standard autoregressive cross-entropy loss. This stage adapts the model to the target domain before diffusion training.
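As a point of reference, the autoregressive objective here is ordinary next-token cross-entropy. A minimal pure-Python sketch (illustrative only; the recipe uses Megatron's fused cross-entropy implementation):

```python
import math

def next_token_ce(token_probs):
    """Mean negative log-likelihood over the sequence: each entry is the
    probability the model assigned to the token that actually came next."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)

# A model that assigns probability 0.5 to every next token has loss ln 2.
print(round(next_token_ce([0.5, 0.5, 0.5]), 4))  # 0.6931
```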

**Example script:**
```bash
torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/continuous_pretraining.py \
--model-size 3b \
--hf-path mistralai/Ministral-3-3B-Base-2512 \
--data-paths /path/to/dclm/merged_tokenized_text_document \
--config-file examples/diffusion/recipes/nemotron_diffusion/conf/cpt_3b.yaml
```

Available config files: [`conf/cpt_3b.yaml`](conf/cpt_3b.yaml), [`conf/cpt_8b.yaml`](conf/cpt_8b.yaml), [`conf/cpt_14b.yaml`](conf/cpt_14b.yaml).

---

## Stage 2: AR-to-DLM

This stage converts the CPT checkpoint into a diffusion LM. It replaces the standard attention with `NemotronDiffusionAttention` and trains with a combined diffusion + AR loss.

**Key recipe:** `examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py`

The model is built via `NemotronDiffusionModelProvider`, which extends `Ministral3ModelProvider` with:
- `dlm_paradigm = "sbd_block_diff"` — semi-block diffusion with block masking
- `block_size = 64` — number of tokens per diffusion block
- `mask_token_id = 100` — token ID used for masking during diffusion
- `dlm_loss_weight = 0.3`, `ar_loss_weight = 1.0` — loss weighting between diffusion and AR objectives
- `NemotronDiffusionAttention` replaces core attention to support block-causal masking
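With these defaults, the training objective combines the two losses roughly as a weighted sum. A sketch under that assumption (the exact combination is implemented inside `DGPTStep` and may differ in detail):

```python
def combined_loss(dlm_loss, ar_loss, dlm_loss_weight=0.3, ar_loss_weight=1.0):
    """Weighted sum of the diffusion and AR objectives (assumed form;
    the actual combination lives in DGPTStep)."""
    return ar_loss_weight * ar_loss + dlm_loss_weight * dlm_loss

print(round(combined_loss(dlm_loss=4.0, ar_loss=2.0), 2))  # 3.2
```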

The CPT checkpoint from Stage 1 is passed via `checkpoint.pretrained_checkpoint`. Setting `checkpoint.finetune=true` skips loading the optimizer state from the CPT stage.

**Example launch:**
```bash
torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
--model-size 3b \
--hf-path mistralai/Ministral-3-3B-Base-2512 \
--data-paths /path/to/dclm/merged_tokenized_text_document \
--config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_3b_dlm.yaml \
checkpoint.finetune=true \
checkpoint.pretrained_checkpoint=/path/to/cpt_checkpoint
```

Available config files: [`conf/ar_to_dlm_3b_dlm.yaml`](conf/ar_to_dlm_3b_dlm.yaml), [`conf/ar_to_dlm_8b_dlm.yaml`](conf/ar_to_dlm_8b_dlm.yaml).

---

## Inference

The script [`inference_nemotron.py`](inference_nemotron.py) runs text generation from a trained Megatron-format NemotronDiffusion checkpoint. Both dLLM (block diffusion) and AR modes are supported.

### dLLM mode (default)

```bash
torchrun --nproc_per_node=4 examples/diffusion/recipes/nemotron_diffusion/inference_nemotron.py \
--megatron-path /path/to/checkpoints/ar_to_dlm_8b \
--hf-model mistralai/Ministral-3-8B-Base-2512 \
--prompts "The capital of France is" \
--gen-length 256 --block-length 32 --diffusion-steps 256 \
--tp 4
```
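These flags imply a per-block step budget. A sketch of the arithmetic, assuming diffusion steps are split evenly across blocks (the split policy is an assumption, not confirmed by the script):

```python
def block_budget(gen_length, block_length, diffusion_steps):
    """Number of diffusion blocks the generation spans, and the per-block
    step budget if steps are divided evenly (assumed policy)."""
    assert gen_length % block_length == 0, "gen length should be a multiple of block length"
    num_blocks = gen_length // block_length
    return num_blocks, diffusion_steps // num_blocks

print(block_budget(gen_length=256, block_length=32, diffusion_steps=256))  # (8, 32)
```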

### AR mode

```bash
python examples/diffusion/recipes/nemotron_diffusion/inference_nemotron.py \
--megatron-path /path/to/checkpoints/ar_to_dlm_3b \
--hf-model mistralai/Ministral-3-3B-Base-2512 \
--mode ar \
--prompts "Once upon a time" \
--max-new-tokens 128
```

The `--tp` argument must match the tensor parallelism degree of the saved checkpoint (e.g. `--tp 4` for 8B checkpoints saved with TP=4). `--hf-model` is used for the tokenizer and model config only — weights are loaded from `--megatron-path`.

---


## Checkpoint Conversion (Bridge)

The `NemotronDiffusionBridge` converts between HuggingFace `Mistral3ForConditionalGeneration` and Megatron-Bridge distributed checkpoint format. It handles:

- **Language model weights** — mapped between HF (`language_model.model.*`) and Megatron (`language_model.decoder.*`) with proper QKV merging and tensor-parallel sharding.
- **Vision encoder weights** (`vision_tower.**`) — replicated across tensor-parallel ranks (no sharding needed).
- **Multimodal projector weights** (`multi_modal_projector.**`) — replicated similarly.
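The prefix remapping above can be pictured with a toy helper. This is a hypothetical sketch for illustration only: the real `NemotronDiffusionBridge` additionally merges per-head Q/K/V tensors and shards weights across tensor-parallel ranks, which this version omits:

```python
# Hypothetical prefix remap, illustration only; the actual mapping
# (including QKV merging and TP sharding) is done by NemotronDiffusionBridge.
PREFIX_MAP = {
    "language_model.model.": "language_model.decoder.",
}
REPLICATED_PREFIXES = ("vision_tower.", "multi_modal_projector.")

def map_hf_key(hf_key: str) -> str:
    """Map a HuggingFace weight name to its Megatron counterpart.
    Vision/projector weights keep their names and are replicated across ranks."""
    for prefix in REPLICATED_PREFIXES:
        if hf_key.startswith(prefix):
            return hf_key
    for src, dst in PREFIX_MAP.items():
        if hf_key.startswith(src):
            return dst + hf_key[len(src):]
    return hf_key

print(map_hf_key("language_model.model.layers.0.mlp.up_proj.weight"))
# language_model.decoder.layers.0.mlp.up_proj.weight
```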

The conversion script is [`convert_checkpoints.py`](convert_checkpoints.py).

### Import: HuggingFace → Megatron

```bash
python examples/diffusion/recipes/nemotron_diffusion/convert_checkpoints.py import \
--hf-model mistralai/Ministral-3-3B-Base-2512 \
--megatron-path /path/to/checkpoints/hf_to_mb_3b \
--torch-dtype bfloat16
```

The Megatron checkpoint is written under `--megatron-path` (e.g. `.../hf_to_mb_3b/iter_0000000/`). Use the parent directory for CPT training with `checkpoint.load`.
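The `iter_0000000` naming follows Megatron's zero-padded convention. A small helper (hypothetical, for illustration) to build the path for a given iteration:

```python
from pathlib import Path

def iter_dir(megatron_path: str, iteration: int) -> Path:
    """Megatron checkpoints live under <root>/iter_<7-digit iteration>/;
    iteration 0 is the freshly imported checkpoint."""
    return Path(megatron_path) / f"iter_{iteration:07d}"

print(iter_dir("/path/to/checkpoints/hf_to_mb_3b", 0))
# /path/to/checkpoints/hf_to_mb_3b/iter_0000000
```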

For the 8B model (TP=4):
```bash
python examples/diffusion/recipes/nemotron_diffusion/convert_checkpoints.py import \
--hf-model mistralai/Ministral-3-8B-Base-2512 \
--megatron-path /path/to/checkpoints/hf_to_mb_8b \
--torch-dtype bfloat16
```

### Export: Megatron → HuggingFace

Export a trained Megatron checkpoint back to HuggingFace format. A reference HF model is required to provide config and tokenizer artifacts:

```bash
python examples/diffusion/recipes/nemotron_diffusion/convert_checkpoints.py export \
--hf-model mistralai/Ministral-3-3B-Base-2512 \
--megatron-path /path/to/checkpoints/ar_to_dlm_3b \
--hf-path /path/to/checkpoints/mb_to_hf_3b
```

The `--hf-model` argument is used as the reference for config, tokenizer, and any non-LM artifacts. The exported directory contains a self-contained HuggingFace model.

**Note:** If the reference HF model does not include vision tower weights (e.g. an LM-only checkpoint), warnings of the form `Can't find vision_tower.* in hf_keys` are expected and benign — the LM weights are still exported correctly.

---
**File:** examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py (new, +175 lines)
#!/usr/bin/env python3
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NemotronDiffusion diffusion LM pretraining.

Uses the sbd_block_diff diffusion paradigm via DGPTStep.
Use --hf-path to override the HuggingFace model ID or local model path.

Examples:
3B model, first job from AR checkpoint (finetune=true skips optimizer state):
$ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
--hf-path mistralai/Ministral-3-3B-Base-2512 \
--config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_3b_dlm.yaml \
--data-paths /path/to/dclm/merged_tokenized_text_document \
checkpoint.pretrained_checkpoint=/path/to/hf_to_mb_3b \
checkpoint.finetune=true

3B model, subsequent jobs (resume from DLM checkpoint):
$ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
--hf-path mistralai/Ministral-3-3B-Base-2512 \
--config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_3b_dlm.yaml \
--data-paths /path/to/dclm/merged_tokenized_text_document

8B model with TP=4:
$ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
--hf-path mistralai/Ministral-3-8B-Base-2512 \
--config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_8b_dlm.yaml \
--data-paths /path/to/dclm/merged_tokenized_text_document \
checkpoint.pretrained_checkpoint=/path/to/hf_to_mb_8b \
checkpoint.finetune=true

14B model with TP=8:
$ torchrun --nproc_per_node=8 examples/diffusion/recipes/nemotron_diffusion/ar_to_dlm.py \
--hf-path mistralai/Ministral-3-14B-Base-2512 \
--config-file examples/diffusion/recipes/nemotron_diffusion/conf/ar_to_dlm_14b_dlm.yaml \
--data-paths /path/to/dclm/merged_tokenized_text_document \
checkpoint.pretrained_checkpoint=/path/to/hf_to_mb_14b \
checkpoint.finetune=true
"""

> **Reviewer:** Add commands for example usage here, similar to other examples.
>
> **Author:** Done.


import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Tuple

import torch
from omegaconf import OmegaConf

# Register NemotronDiffusionBridge, overriding the base Ministral3Bridge so that
# AutoBridge returns NemotronDiffusionModelProvider (with NemotronDiffusionAttention).
import megatron.bridge.diffusion.conversion.nemotron_diffusion.nemotron_diffusion_bridge # noqa: F401
from megatron.bridge.diffusion.models.common.dgpt_step import DGPTStep
from megatron.bridge.diffusion.recipes.nemotron_diffusion.ar_to_dlm import (
nemotron_diffusion3_pretrain_config as pretrain_config,
)
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.utils.omegaconf_utils import (
apply_overrides,
create_omegaconf_dict_config,
parse_hydra_overrides,
)
from megatron.bridge.utils.common_utils import get_rank_safe


logger: logging.Logger = logging.getLogger(__name__)

SCRIPT_DIR: Path = Path(__file__).parent.parent.resolve()
DEFAULT_CONFIG_FILENAME: str = "train_local.yaml"
DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "override_configs" / DEFAULT_CONFIG_FILENAME


def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
"""Parse command-line arguments for the AR-to-DLM conversion script."""
parser = argparse.ArgumentParser(
description="NemotronDiffusion diffusion LM pretraining (no distillation)",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--hf-path",
type=str,
default=None,
help="HuggingFace model ID or local path to model weights. Overrides the default for the selected model size.",
)
parser.add_argument(
"--config-file",
type=str,
default=str(DEFAULT_CONFIG_FILE_PATH),
help="Path to YAML override file.",
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument(
"--data-paths",
type=str,
nargs="*",
default=None,
help="List of dataset file paths (space or comma-separated).",
)
parser.add_argument(
"--data-args-path",
type=str,
default=None,
help="Path to file containing data arguments.",
)

args, cli_dotlist_overrides = parser.parse_known_args()

if args.data_paths:
flattened_paths = []
for path in args.data_paths:
if "," in path:
flattened_paths.extend(path.split(","))
else:
flattened_paths.append(path)
args.data_paths = [p.strip() for p in flattened_paths if p.strip()]

return args, cli_dotlist_overrides


def main() -> None:
"""Entry point for AR-to-DLM conversion and continued pretraining."""
args, cli_overrides = parse_cli_args()
cfg: ConfigContainer = pretrain_config(
data_paths=args.data_paths,
data_args_path=args.data_args_path,
hf_path=args.hf_path,
)

if get_rank_safe() == 0:
cfg.print_yaml()

merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg)

if args.config_file:
if not os.path.exists(args.config_file):
logger.error(f"Override YAML file not found: {args.config_file}")
sys.exit(1)
yaml_overrides_omega = OmegaConf.load(args.config_file)
merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega)

if cli_overrides:
merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides)

final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True)
apply_overrides(cfg, final_overrides_as_dict, excluded_fields)

if get_rank_safe() == 0:
logger.info("--- Final Merged Configuration ---")
cfg.print_yaml()
logger.info("----------------------------------")

pretrain(config=cfg, forward_step_func=DGPTStep())

if torch.distributed.is_initialized():
torch.distributed.barrier()
torch.distributed.destroy_process_group()


if __name__ == "__main__":
main()
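The `parse_hydra_overrides` call in `main()` accepts Hydra-style dotlist overrides such as `checkpoint.finetune=true`. A minimal pure-Python sketch of that semantics (the real implementation delegates to Hydra/OmegaConf and handles much more, e.g. proper type coercion and list syntax):

```python
def apply_dotlist(cfg: dict, overrides: list[str]) -> dict:
    """Apply 'a.b.c=value' style overrides to a nested dict (toy sketch)."""
    for item in overrides:
        dotted, _, raw = item.partition("=")
        # naive scalar coercion; Hydra/OmegaConf do this properly
        value = {"true": True, "false": False}.get(raw.lower(), raw)
        node = cfg
        *parents, leaf = dotted.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return cfg

cfg = {"checkpoint": {"finetune": False}}
apply_dotlist(cfg, ["checkpoint.finetune=true"])
print(cfg["checkpoint"]["finetune"])  # True
```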

> **Reviewer:** We generally don't want to have these YAML config files in the repo. Can we include all the hyperparameters in the recipe function? Then, if anything else is needed, use a Hydra override in the example command.
>
> **Author:** Sure, makes sense; I'll remove them. By the way, other diffusion models like Wan and Flux seem to have their YAML files in the repo: https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/diffusion/recipes/wan/conf

**File:** new YAML config (+75 lines; path not shown in the diff view)
# NemotronDiffusion 14B diffusion LM pretraining config

cluster_config:
hostname: nrt
account_name: ${oc.env:USER}
job_name: nemotron_diffusion_14b_sbd64_dlm
user_path: /lustre/fsw/portfolios/coreai/users
exp_dir: ${cluster_config.user_path}/${cluster_config.account_name}/megatron_exp/${cluster_config.job_name}

model:
seq_length: 4096
block_size: 64
max_position_embeddings: ${model.seq_length}
tensor_model_parallel_size: 8
sequence_parallelism: false
pipeline_model_parallel_size: 1
cross_entropy_loss_fusion: true
cross_entropy_fusion_impl: te

train:
micro_batch_size: 1
global_batch_size: 512
train_iters: 12500
eval_iters: 10
eval_interval: 1000
exit_duration_in_mins: 220

optimizer:
use_distributed_optimizer: true
lr: 1e-5
min_lr: 1e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1e-8
clip_grad: 1.0

scheduler:
lr_decay_style: WSD
lr_warmup_fraction: 0.01
lr_wsd_decay_iters: 2500
lr_warmup_iters: 0

checkpoint:
save: ${cluster_config.exp_dir}
load: ${cluster_config.exp_dir}
save_interval: 1000
auto_detect_ckpt_format: true

dist:
distributed_timeout_minutes: 240

ddp:
use_distributed_optimizer: true

mixed_precision: bf16_mixed

logger:
log_interval: 10
tensorboard_dir: ${cluster_config.user_path}/${cluster_config.account_name}/megatron_exp/tensorboard
wandb_project: megatron
wandb_exp_name: ${cluster_config.job_name}
wandb_save_dir: ${cluster_config.exp_dir}/wandb

tokenizer:
tokenizer_type: HuggingFaceTokenizer
tokenizer_model: mistralai/Ministral-3-14B-Base-2512

dataset:
sequence_length: ${model.seq_length}
split: "950,50,0"
path_to_cache: ${cluster_config.user_path}/${cluster_config.account_name}/megatron_exp/data_cache
mmap_bin_files: false
dataloader_type: cyclic
num_workers: 10
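`lr_decay_style: WSD` is a warmup-stable-decay schedule: the learning rate ramps up over the warmup fraction, holds flat, then decays over the final `lr_wsd_decay_iters`. A sketch of that shape with the values above, assuming linear ramps (Megatron supports several decay shapes, so treat this as illustrative):

```python
def wsd_lr(step, train_iters=12500, lr=1e-5, min_lr=1e-6,
           warmup_fraction=0.01, wsd_decay_iters=2500):
    """Warmup-Stable-Decay learning-rate schedule (linear ramps assumed)."""
    warmup_iters = int(train_iters * warmup_fraction)  # 125 steps here
    decay_start = train_iters - wsd_decay_iters        # step 10000 here
    if step < warmup_iters:                            # linear warmup
        return lr * step / warmup_iters
    if step < decay_start:                             # stable plateau
        return lr
    frac = (step - decay_start) / wsd_decay_iters      # linear decay
    return lr - (lr - min_lr) * frac

print(wsd_lr(5000))   # 1e-05 (plateau)
print(wsd_lr(12500))  # close to min_lr (1e-06)
```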