Complete guide for training Neural Operator Agents (NOA) with various training approaches.
For new projects, use the CNO-Trained Architecture:
- ✅ Simpler: Two independent training pipelines (no sequential dependency)
- ✅ Modular: Each component validated independently on CNO ground truth
- ✅ Efficient: No need to generate 100K+ MNO rollouts for VQ training
- ✅ Parallel: VQ-VAE and MNO can be trained simultaneously
- ✅ Easier to debug: Components trained independently
Summary:
- Train VQ-VAE on CNO ground truth features → discrete symbolic representation
- Train MNO on CNO ground truth trajectories → high-fidelity physics simulator
- Post-training validation → verify VQ reconstruction on MNO outputs
For reference, this guide also documents older training approaches:
- 3-stage sequential (deprecated): Train MNO → generate MNO features → train VQ-VAE on MNO distribution. See Independent Optimization (Deprecated) for details.
- Two-stage curriculum (deprecated): Stage 1 with token conditioning + Stage 2 VQ-led fine-tuning. See Two-Stage Curriculum Architecture for details and lessons learned.
- Simultaneous training (below): VQ-VAE alignment during NOA training. More complex, lower quality.
- Overview
- Training Architecture
- Quick Start
- Training Configuration
- Loss Functions
- Checkpointing and Resume
- Hyperparameter Tuning
- Diagnostics
- Troubleshooting
NOA training uses state-level supervision with optional VQ-VAE alignment to learn physics-native rollout generation. The training objective combines three complementary losses:
L_total = L_traj + λ_commit·L_commit + λ_latent·L_latent
- L_traj: MSE between NOA predictions and CNO ground truth (physics fidelity)
- L_commit: VQ-VAE commitment loss (manifold adherence)
- L_latent: NOA-VQ latent alignment loss (representation learning)
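As a concrete sketch, the three-term objective can be combined as below. This is a hypothetical helper, not the actual training script; tensor shapes follow the conventions used later in this guide, and the two VQ terms are optional, mirroring the `--lambda-commit` and `--enable-latent-loss` flags:

```python
import torch
import torch.nn.functional as F

def total_loss(noa_rollout, cno_rollout, z=None, z_q=None,
               noa_latents=None, vq_latents=None,
               lambda_commit=0.5, lambda_latent=0.5):
    """Combine the three NOA training losses (illustrative sketch).

    noa_rollout / cno_rollout: [B, T, C, H, W] state trajectories.
    z / z_q: pre-quantization latents and their nearest codebook vectors.
    noa_latents / vq_latents: projected NOA features and frozen VQ latents.
    """
    loss = F.mse_loss(noa_rollout, cno_rollout)  # L_traj: physics fidelity
    if z is not None and z_q is not None:
        # L_commit: manifold adherence (codebook vectors are not updated)
        loss = loss + lambda_commit * F.mse_loss(z, z_q.detach())
    if noa_latents is not None and vq_latents is not None:
        # L_latent: align NOA features with frozen VQ latents
        loss = loss + lambda_latent * F.mse_loss(noa_latents, vq_latents.detach())
    return loss
```

Passing `None` for a pair of latents skips that term entirely, which corresponds to training with `L_traj` only.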
- U-AFNO Backbone: Physics-native neural operator with spectral mixing
- Truncated BPTT: Prevents gradient explosion for long sequences (T=256)
- Three-Loss Training: Learns physics + behavioral representations simultaneously
- Checkpoint Resume: Robust resumption from training interruptions
- Diagnostic Tools: Comprehensive alignment quality evaluation
```
Input: (IC, operator_params) → CNO rollout (ground truth)
                             → NOA rollout (predicted)
                                      ↓
                    ┌─────────────────────────────────┐
                    │        Loss Computation         │
                    ├─────────────────────────────────┤
                    │ L_traj   = MSE(NOA, CNO)        │
                    │ L_commit = VQ commitment        │
                    │ L_latent = NOA ↔ VQ alignment   │
                    └─────────────────────────────────┘
                                      ↓
                    Backprop through last 32 steps (TBPTT)
```
- U-AFNO Operator: Spectral mixing in Fourier domain
- Latent Projector: Maps U-AFNO bottleneck → VQ latent space (optional, for L_latent)
- CNO Replayer: Generates ground truth trajectories from saved parameters
- VQ-VAE Encoder: Extracts behavioral features → discrete tokens (frozen)
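The truncated-BPTT rollout can be sketched as follows, with a generic `step` callable standing in for the U-AFNO operator (illustrative only; the real training loop differs in detail):

```python
import torch
import torch.nn.functional as F

def tbptt_rollout(step, state, cno_states, timesteps=256, window=32):
    """Roll out `state` for `timesteps` steps, keeping gradients only
    inside each `window`-step segment (truncated BPTT).

    step: callable mapping a state [B, C, H, W] to the next state.
    cno_states: ground-truth trajectory, [T, B, C, H, W].
    """
    total = 0.0
    for t in range(timesteps):
        state = step(state)
        total = total + F.mse_loss(state, cno_states[t])
        if (t + 1) % window == 0:
            state = state.detach()  # cut the graph: no gradients flow past here
    return total / timesteps
```

Detaching every `window` steps bounds the backprop depth, which is what keeps gradients stable for T=256 rollouts.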
Train NOA to match CNO rollouts without VQ-VAE alignment:
```bash
poetry run python scripts/dev/train_noa_state_supervised.py \
  --dataset datasets/100k_full_features.h5 \
  --n-samples 5000 \
  --epochs 10 \
  --batch-size 4 \
  --lr 3e-4 \
  --bptt-window 32 \
  --timesteps 256
```

Expected: `L_traj` decreases from ~600 → <10 over 10 epochs.
Add VQ-VAE commitment loss for better manifold adherence:
```bash
poetry run python scripts/dev/train_noa_state_supervised.py \
  --dataset datasets/100k_full_features.h5 \
  --vqvae-path checkpoints/production/100k_3family_v1 \
  --n-samples 5000 \
  --epochs 10 \
  --batch-size 4 \
  --lr 3e-4 \
  --bptt-window 32 \
  --timesteps 256 \
  --lambda-commit 0.5
```

Expected: `L_commit` stays low (~0.0005), indicating NOA outputs are VQ-tokenizable.
Enable latent alignment for representation learning:
```bash
poetry run python scripts/dev/train_noa_state_supervised.py \
  --dataset datasets/100k_full_features.h5 \
  --vqvae-path checkpoints/production/100k_3family_v1 \
  --n-samples 5000 \
  --epochs 10 \
  --batch-size 4 \
  --lr 3e-4 \
  --bptt-window 32 \
  --warmup-steps 500 \
  --timesteps 256 \
  --lambda-commit 0.5 \
  --enable-latent-loss \
  --lambda-latent 0.5 \
  --latent-sample-steps 8 \
  --save-every 200
```

Expected:
- `L_traj`: 600 → <10
- `L_commit`: ~0.0005 (stable)
- `L_latent`: 0.7 → 0.5 (alignment improving)
| Argument | Description | Recommended Value |
|---|---|---|
| `--dataset` | Path to HDF5 dataset | `datasets/100k_full_features.h5` |
| `--n-samples` | Number of training samples | 5000-10000 |
| `--epochs` | Training epochs | 10-20 |
| `--batch-size` | Batch size (GPU memory limited) | 4 |
| `--lr` | Learning rate | 3e-4 |
| `--bptt-window` | Truncated BPTT window | 32 |
| `--timesteps` | Rollout length | 256 |
| Argument | Description | Recommended Value |
|---|---|---|
| `--vqvae-path` | Path to VQ-VAE checkpoint | `checkpoints/production/100k_3family_v1` |
| `--lambda-commit` | Commitment loss weight | 0.5 |
| `--enable-latent-loss` | Enable L_latent | (flag) |
| `--lambda-latent` | Latent alignment weight | 0.1-0.5 |
| `--latent-sample-steps` | Timesteps to sample for L_latent | 3-8 |
| Argument | Description | Recommended Value |
|---|---|---|
| `--lr-schedule` | Schedule type | cosine |
| `--warmup-steps` | Warmup batches | 500 |
| Argument | Description | Recommended Value |
|---|---|---|
| `--checkpoint-dir` | Checkpoint directory | `checkpoints/noa` |
| `--save-every` | Save every N batches | 200 |
| `--early-stop-patience` | Stop if no improvement for N epochs | 2 |
| Argument | Description | Recommended Value |
|---|---|---|
| `--base-channels` | U-AFNO base channels | 32 |
| `--encoder-levels` | U-Net encoder levels | 3 |
| `--modes` | Fourier modes | 16 |
| `--afno-blocks` | AFNO blocks per level | 4 |
What it measures: How well NOA matches CNO ground truth trajectories.
Computation:
```python
L_traj = MSE(NOA_rollout, CNO_rollout)  # [B, T, C, H, W]
```

Interpretation:
- `L_traj` = 600: random initialization
- `L_traj` = 50-100: learning basic dynamics
- `L_traj` = 10-20: good physics matching
- `L_traj` < 5: excellent physics fidelity
Why it's needed: Core objective for learning operator dynamics.
What it measures: How easily VQ-VAE can tokenize NOA outputs.
Computation:
```python
features = extract_features(NOA_rollout)
z = VQ_encode(features)   # pre-quantization latents
z_q = quantize(z)         # nearest codebook vectors
L_commit = MSE(z, z_q.detach())
```

Interpretation:
- `L_commit` ≈ 0.0005: NOA outputs are on the VQ manifold (good)
- `L_commit` > 0.001: NOA drifting off the manifold (concerning)
- `L_commit` increasing: NOA is learning physics the VQ-VAE can't represent
Why it's needed: Ensures NOA outputs remain tokenizable for downstream applications.
What it measures: Alignment between NOA's internal features and VQ-VAE's learned embeddings.
Computation:
```python
# Extract NOA bottleneck features [B, 256, 8, 8]
noa_bottleneck = NOA.get_intermediate_features(state_t, "bottleneck")

# Project to VQ space [B, 780]
noa_latents = projector(noa_bottleneck)

# Get VQ latents from features
features = extract_features(NOA_rollout)
vq_latents = VQ_encode(features)

# Align (sample N timesteps for efficiency)
L_latent = MSE(mean(noa_latents_sampled), vq_latents.detach())
```

Interpretation:
- `L_latent` = 0.7: random initialization
- `L_latent` = 0.5: moderate alignment (good)
- `L_latent` = 0.3: strong alignment (excellent)
- `L_latent` < 0.1: very strong alignment (rare)
Why it's needed:
- Encourages NOA to learn VQ-VAE's behavioral representations
- Enables interpretability (NOA features → VQ codes)
- Improves transfer learning to downstream tasks
Memory tradeoff: Sampling fewer timesteps (3-8) reduces overhead from 48% → 15-20%.
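One way to implement the timestep sampling is to spread indices evenly across the trajectory, always including the first and last steps (a sketch; the script's actual index selection is an assumption):

```python
def sample_latent_steps(timesteps, n_samples):
    """Choose which rollout timesteps contribute to L_latent.

    n_samples == -1 uses every timestep; otherwise pick `n_samples`
    indices spread evenly across the trajectory, always including the
    first and last step (so n_samples=3 gives first/middle/last).
    """
    if n_samples == -1 or n_samples >= timesteps:
        return list(range(timesteps))
    if n_samples == 1:
        return [timesteps // 2]
    stride = (timesteps - 1) / (n_samples - 1)
    return [round(i * stride) for i in range(n_samples)]
```

For a T=256 rollout, `n_samples=3` yields the first, middle, and last steps, matching the table below.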
New checkpoints (saved after implementing resume) include:
```python
checkpoint = {
    "model_state_dict": ...,       # NOA weights
    "optimizer_state_dict": ...,   # Adam state
    "scheduler_state_dict": ...,   # LR schedule
    "epoch": 3,                    # Current epoch (0-indexed)
    "global_step": 675,            # Total batches processed
    "history": {                   # Loss curves
        "train_loss": [...],
        "val_loss": [...],
    },
    "best_val_loss": 12.345,       # Best validation so far
    "alignment_state": ...,        # Latent projector weights (if L_latent enabled)
    "config": {                    # Model architecture
        "base_channels": 32,
        "encoder_levels": 3,
        "modes": 16,
        "afno_blocks": 4,
    },
    "args": {...},                 # Full training args
}
```

Checkpoints are saved automatically:
- Periodic: every `--save-every` batches (e.g., `step_200.pt`, `step_400.pt`)
- Per-Epoch: after each epoch (`epoch_1.pt`, `epoch_2.pt`)
- Best Model: when validation loss improves (`best_model.pt`)
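The save-file naming scheme can be sketched as a small helper (illustrative only; the function and its signature are hypothetical, not part of the training script):

```python
def checkpoint_paths(checkpoint_dir, global_step, epoch, save_every,
                     val_loss=None, best_val_loss=float("inf"),
                     epoch_done=False):
    """Decide which checkpoint files to write at this point in training.

    Mirrors the naming scheme: step_N.pt (periodic), epoch_N.pt
    (per-epoch), best_model.pt (on validation improvement).
    """
    paths = []
    if save_every and global_step % save_every == 0:
        paths.append(f"{checkpoint_dir}/step_{global_step}.pt")
    if epoch_done:
        paths.append(f"{checkpoint_dir}/epoch_{epoch}.pt")
    if val_loss is not None and val_loss < best_val_loss:
        paths.append(f"{checkpoint_dir}/best_model.pt")
    return paths
```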
```bash
# Resume from last saved checkpoint
poetry run python scripts/dev/train_noa_state_supervised.py \
  --resume checkpoints/noa/epoch_5.pt \
  --dataset datasets/100k_full_features.h5 \
  --vqvae-path checkpoints/production/100k_3family_v1 \
  --epochs 10 \
  --batch-size 4 \
  --enable-latent-loss \
  --lambda-latent 0.5
```

What happens:
- Loads model, optimizer, and scheduler state
- Resumes from epoch 5, continues to epoch 10
- Preserves training history and best validation loss
- LR schedule continues from the correct step (no warmup restart)
- Loads projector weights if L_latent was enabled
```bash
# Resume from step checkpoint (e.g., batch 200 of epoch 1)
poetry run python scripts/dev/train_noa_state_supervised.py \
  --resume checkpoints/noa/step_200.pt \
  --dataset datasets/100k_full_features.h5 \
  --vqvae-path checkpoints/production/100k_3family_v1 \
  --epochs 5 \
  --batch-size 4 \
  --enable-latent-loss
```

What happens:
- Detects that the checkpoint is from batch 200 of epoch 1
- Skips the first 200 batches of epoch 1 (already processed)
- Continues from batch 201
- LR synced to step 200 (correct value, no warmup restart)
```
Resuming from checkpoint: checkpoints/noa/step_200.pt
✓ Loaded model weights
✓ Loaded optimizer state
✓ Loaded scheduler state
✓ Resuming from epoch 1, step 200
  (Will skip first 200 batches of epoch 1)
✓ Best val loss so far: 15.234567
✓ Loaded latent projector weights

Epoch 1/5
Skipping first 200 batches (already processed)...
[201/225] loss=13.88 traj=13.54 commit=0.000561 latent=0.534 lr=1.51e-04 8.2s/b
[202/225] loss=13.82 traj=13.48 commit=0.000560 latent=0.533 lr=1.51e-04 8.1s/b
...
```
If resuming from checkpoints saved before resume functionality was added:
```
Resuming from checkpoint: checkpoints/noa/step_200.pt
✓ Loaded model weights
✓ Loaded optimizer state
⚠ Old checkpoint format detected - inferred global_step=200 from filename
⚠ Syncing scheduler to step 200...
✓ Scheduler synced to step 200, lr=1.50e-04
✓ Resuming from epoch 1, step 200
  (Will skip first 200 batches of epoch 1)
✓ No validation history (first epoch incomplete)
⚠⚠ WARNING: Old checkpoint has no projector weights!
   Projector will restart from random initialization.
   L_latent training will be inconsistent with pre-crash training.
   Recommend: Either disable --enable-latent-loss or train from scratch.
```
Recommendations:
- If L_latent is critical: train from scratch to get consistent projector training
- If L_latent is optional: disable `--enable-latent-loss` and resume with L_commit only
- Accept inconsistency: resume with L_latent, but the projector will reinitialize (the loss curve will jump)
- Save frequently: use `--save-every 200` for large datasets
- Monitor checkpoints: check `checkpoints/noa/` periodically to ensure saves are working
- Keep best model: always preserve `best_model.pt` for deployment
- Clean up: delete old `step_*.pt` files to save disk space
- Test resume: after the first epoch, try resuming to verify the checkpoint format
What it controls: How strongly NOA is pushed toward VQ manifold.
| Value | Effect | When to Use |
|---|---|---|
| 0.0 | No VQ alignment | Testing physics learning only |
| 0.1 | Weak alignment | VQ-VAE already well-matched to data |
| 0.5 | Recommended | Standard training |
| 1.0 | Strong alignment | NOA drifting off manifold |
| 2.0+ | Very strong | Force manifold adherence (may hurt physics) |
Tuning guide:
- If `L_commit` increases during training → increase λ_commit
- If `L_traj` is not decreasing → decrease λ_commit (too much constraint)
What it controls: How strongly NOA's internal features align with VQ latents.
| Value | Effect | When to Use |
|---|---|---|
| 0.0 | No latent alignment | Baseline (L_traj + L_commit only) |
| 0.1 | Weak alignment | Initial experiments |
| 0.5 | Recommended | Strong alignment without compromising physics |
| 1.0 | Very strong | Prioritize representation learning |
| 2.0+ | Dominant | Force alignment (may hurt physics) |
Tuning guide:
- Start with 0.1; increase to 0.5 if `L_latent` plateaus
- If `L_traj` convergence slows → decrease λ_latent
- If `L_latent` doesn't decrease → increase λ_latent or `--latent-sample-steps`
Ablation results (preliminary):
- λ_latent=0.1, n_samples=3: `L_latent` 0.646 → 0.580 (plateau)
- λ_latent=0.5, n_samples=8: `L_latent` 0.705 → 0.561 (2× faster, breaks the plateau)
What it controls: How many trajectory timesteps to sample for L_latent computation.
| Value | Memory Overhead | Latent Loss Quality | When to Use |
|---|---|---|---|
| 3 | +15% | Good (first, middle, last) | Recommended, memory limited |
| 8 | +48% | Better (rich temporal context) | Strong alignment needed |
| -1 | +200% | Best (all timesteps) | Small BPTT windows only |
Tradeoff: More samples → richer alignment signal but slower training.
| GPU VRAM | Batch Size | Notes |
|---|---|---|
| 8 GB | 1-2 | May OOM with L_latent |
| 16 GB | 4 | Recommended |
| 24 GB | 8 | Faster convergence |
| 40 GB+ | 16 | Diminishing returns |
If OOM:
- Reduce `--batch-size` (4 → 2)
- Reduce `--latent-sample-steps` (8 → 3)
- Disable `--enable-latent-loss`
Recommended: Cosine annealing with warmup
```bash
--lr 3e-4 \
--lr-schedule cosine \
--warmup-steps 500
```

Why warmup: prevents early instability while the optimizer's moment estimates are still poorly calibrated.

Warmup schedule:

```
Steps 0-500:  LR ramps 3e-5 → 3e-4 (linear)
Steps 500+:   LR decays 3e-4 → 0 (cosine)
```
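The ramp-then-decay schedule above can be expressed as a per-step function (a sketch; `total_steps` and the 10% warmup floor are assumptions chosen to match the 3e-5 → 3e-4 ramp, not values from the script):

```python
import math

def lr_at_step(step, base_lr=3e-4, warmup_steps=500, total_steps=5000,
               warmup_floor=0.1):
    """Linear warmup to base_lr, then cosine decay to 0.

    Warmup starts at warmup_floor * base_lr (3e-5 with these defaults)
    rather than at zero.
    """
    if step < warmup_steps:
        frac = step / warmup_steps
        return base_lr * (warmup_floor + (1 - warmup_floor) * frac)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
```

In PyTorch this shape is typically wired up via `torch.optim.lr_scheduler.LambdaLR` or a warmup wrapper around `CosineAnnealingLR`.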
Monitor these metrics every epoch:
```
Epoch 5/10
  Train: total=8.5432 traj=8.0123 commit=0.000543 latent=0.529 [1234.5s]
  Val:   total=9.1234 traj=8.5678 commit=0.000556 latent=0.556
```
Health checks:
- ✅ `L_traj` decreasing steadily
- ✅ `L_commit` stable around 0.0005
- ✅ `L_latent` decreasing (if enabled)
- ❌ `L_commit` increasing → NOA drifting off the manifold
- ❌ `L_latent` stuck → increase λ_latent or sample more timesteps
- ❌ `L_traj` not decreasing → learning rate too high or too low
Run comprehensive diagnostic after training:
```bash
poetry run python scripts/dev/diagnose_latent_alignment.py \
  --noa-checkpoint checkpoints/noa/best_model.pt \
  --vqvae-path checkpoints/production/100k_3family_v1 \
  --dataset datasets/100k_full_features.h5 \
  --n-samples 100 \
  --timesteps 256
```

Output:
```
============================================================
L_latent Alignment Diagnostics
============================================================

1. VQ Reconstruction Quality
   Total MSE: 0.3245
   Per-category reconstruction errors:
     INITIAL : 0.2134
     SUMMARY : 0.3891
     TEMPORAL: 0.3710
   ➜ Quality Assessment: Good

2. Token Diversity
   Overall utilization: 67.3% (1234/1834 codes)
   Token entropy: 5.83
   ➜ Diversity Assessment: Good

3. Alignment Correlation
   Cosine similarity: 0.623 ± 0.084
   L_latent (MSE): 0.521
   ➜ Correlation Assessment: Moderate

4. Temporal Consistency
   Latent norm: 12.345 ± 0.678
   Coefficient of variation: 0.055
   ➜ Consistency Assessment: Good

Overall Summary
  VQ Reconstruction:    Good
  Token Diversity:      Good
  Alignment Correlation: Moderate
  Temporal Consistency: Good
  ➜ Final Verdict: GOOD - L_latent provides meaningful alignment
```
Interpretation:
- VQ Reconstruction < 0.5: NOA outputs tokenize well
- Token Diversity > 50%: NOA explores diverse behaviors
- Cosine Similarity > 0.5: Moderate alignment achieved
- CV < 0.1: Stable alignment across trajectory
See Diagnostics section for detailed metric definitions.
Symptoms:
```
RuntimeError: CUDA out of memory. Tried to allocate X.XX GiB
```
Solutions:
- Reduce batch size: `--batch-size 4` → `--batch-size 2`
- Reduce latent sampling: `--latent-sample-steps 8` → `--latent-sample-steps 3`
- Disable L_latent: remove `--enable-latent-loss`
- Check GPU usage: run `nvidia-smi` to see whether other processes are using memory
Symptoms:
```
Warning: NaN/Inf gradients at batch 157, skipping update
```
Solutions:
- Already handled: Training skips NaN batches automatically
- If frequent (>10% of batches): reduce `--lr` (3e-4 → 1e-4)
- If persistent: check the dataset for NaN values
Symptoms:
```
Epoch 5:  latent=0.65
Epoch 10: latent=0.64 (barely changed)
```
Solutions:
- Increase λ_latent: `0.1` → `0.5`
- Sample more timesteps: `--latent-sample-steps 3` → `--latent-sample-steps 8`
- Train longer: strong alignment may need 20+ epochs
- Check that the projector is learning: load a checkpoint and verify the weights changed
Symptoms:
```
Epoch 1: commit=0.0005
Epoch 5: commit=0.0015 (increasing!)
```
Root cause: NOA learning physics that VQ-VAE can't represent.
Solutions:
- Increase λ_commit: `0.5` → `1.0`
- Check VQ-VAE quality: it may need retraining on more diverse data
- Reduce λ_latent: it may be pulling NOA off the manifold
Symptom:
```
⚠⚠ WARNING: Old checkpoint has no projector weights!
   Projector will restart from random initialization.
```
Solution:
- Train from scratch (recommended for clean L_latent)
- Disable `--enable-latent-loss` and resume without L_latent
- Accept the inconsistency (the projector reinitializes)
Symptom: Loss at batch 1 of resumed training doesn't match loss at batch 225 of crashed training.
Root cause: DataLoader reshuffles data each epoch.
Solution: This is expected! Training loop now skips already-processed batches, so loss values will match the original run once it reaches batch 201.
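The batch-skipping behavior can be sketched as a small generator (hypothetical helper, not the actual training loop):

```python
def batches_for_epoch(loader, resume_epoch, resume_step, epoch, batches_per_epoch):
    """Yield (batch_index, batch), skipping batches already processed
    before a mid-epoch resume.

    resume_step is the global step stored in the checkpoint; only the
    resume epoch skips batches, later epochs run in full.
    """
    skip = resume_step % batches_per_epoch if epoch == resume_epoch else 0
    for i, batch in enumerate(loader):
        if i < skip:
            continue  # already processed before the interruption
        yield i, batch
```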
If using a different VQ-VAE architecture:
```bash
# System automatically infers VQ latent dimension
poetry run python scripts/dev/train_noa_state_supervised.py \
  --vqvae-path checkpoints/my_custom_vqvae \
  --enable-latent-loss
```

Supported: any VQ-VAE with an `.encode()` method that returns a list of latents.
Not yet supported. Stay tuned for distributed training implementation.
Use trained NOA as initialization for domain-specific tasks:
```bash
# Train on general dataset
poetry run python scripts/dev/train_noa_state_supervised.py \
  --dataset datasets/100k_full_features.h5 \
  --epochs 20 \
  --checkpoint-dir checkpoints/noa_pretrained

# Fine-tune on domain-specific data
poetry run python scripts/dev/train_noa_state_supervised.py \
  --resume checkpoints/noa_pretrained/best_model.pt \
  --dataset datasets/domain_specific.h5 \
  --epochs 5 \
  --lr 1e-4  # Lower LR for fine-tuning
```

Phase 1: Vertical Integration

Complete one domain fully before adding others:
- ✓ Reaction-diffusion (complete)
- → Fluid dynamics (next)
- → Wave equations (future)
Rationale: Deep understanding of one domain before cross-domain comparison
Each domain follows independent CNO-trained components:
Component 1: Domain VQ-VAE
- Train on CNO ground truth features for the domain
- Discover domain-specific categories via per-family clustering
- Target: L_recon < 0.05 (achieved 0.006 in RD baseline)
Component 2: Domain MNO
- Architecture: Domain-appropriate (U-AFNO for parabolic, variants for others)
- Loss: Pure MSE against CNO ground truth (L_traj + L_ic)
- Target: L_traj < 1.0
Post-Training Validation
- Generate MNO rollouts from domain parameter space
- Verify VQ reconstruction quality on MNO outputs
- Compare MNO vs CNO feature distributions
After ≥2 domains complete:
Vocabulary Alignment:
```python
# Compare codebook embeddings
rd_codebook = load_vqvae_codebook("checkpoints/vqvae/rd/")
fluids_codebook = load_vqvae_codebook("checkpoints/vqvae/fluids/")
alignment = compute_alignment(rd_codebook, fluids_codebook)
# Returns: correlation matrix, aligned pairs, semantic correspondences
```

Transfer Testing:
```python
# Train NOA on Domain A, test on Domain B
noa = train_noa(domain="rd", tokens=rd_tokens)
transfer_acc = evaluate_noa(noa, domain="fluids", tokens=fluids_tokens)
```

- Reaction-diffusion domain: Complete
- Multi-domain architecture: Research objective
- Vocabulary alignment: Awaiting second domain
- NOA Roadmap - Phase 0-3 implementation plan
- Architecture - System design
- Debugging Guide - NaN gradient troubleshooting
- VQ-VAE Training - Tokenizer training
Last Updated: 2026-01-07 Status: Phase 1 In Development (core training working, L_latent operational)