Context
KempnerForge detects training instability (NaN/Inf loss) but only halts the run after a consecutive-NaN threshold. There is no automatic recovery path, so every instability event requires manual intervention. For long-running production jobs that's expensive.
Current behavior:
- `NaNDetector.check_loss` in `kempnerforge/resilience/health.py` syncs a NaN flag across ranks via a small `all_reduce` before returning (see the sketch after this list).
- On NaN, the training loop in `scripts/train.py` zeros grads, skips the optimizer step, and continues. If `consecutive_nans >= max_consecutive`, it logs "Too many consecutive NaNs — stopping" and breaks out of the loop.
- The detector is constructed in `scripts/train.py` with hardcoded `action="warn"` and `max_consecutive=10`; neither is exposed via TOML.
- `NaNDetector.should_rollback` is a property flag, not an automated action; the loop just stops when it trips.
- The `NaNDetector.check_gradients` helper exists but is not called from the training loop (noted in `docs/resilience/nan-detection.md`).
- The NCCL health check defaults to disabled: `nccl_health_check_interval: int = 0` in `TrainConfig`.
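For orientation, a minimal sketch of the cross-rank sync described in the first bullet, assuming a plain `torch.distributed` setup (the function name and shape here are illustrative, not the actual `check_loss` implementation):

```python
import torch
import torch.distributed as dist


def loss_is_nonfinite_anywhere(loss: torch.Tensor) -> bool:
    """Mirror of the described check_loss behavior: each rank computes a
    local NaN/Inf flag, then a small all_reduce(MAX) makes the verdict
    identical on every rank so they all skip (or stop) together."""
    flag = torch.tensor(
        0.0 if torch.isfinite(loss).all() else 1.0, device=loss.device
    )
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

The MAX reduction matters: without it, one rank could skip the step while the others apply it, desynchronizing optimizer state.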
Items to consider
- Automatic checkpoint rollback: on `should_rollback`, drop to the last known-good checkpoint (or N steps back) and resume instead of breaking (sketched below).
- LR reduction on rollback: temporarily reduce LR (e.g., ×0.5) after rollback and ramp back over N steps (sketched below).
- Gradient anomaly detection in the hot path: wire `NaNDetector.check_gradients` (or a cheaper grad-norm spike check, sketched below) into the training loop so explosions are caught before they reach the loss.
- OOM recovery with batch size reduction: catch CUDA OOM, reduce micro-batch size or grad-accum for one step, retry (sketched below).
- Configurable recovery policy via TOML: expose NaN action (`warn`/`skip`/`raise`), `max_consecutive`, rollback depth, max retries, and LR reduction factor (sketched below).
- NCCL health check default: consider a non-zero default for `nccl_health_check_interval` so production runs detect hung collectives without opt-in.
- Structured recovery event log: log each recovery action (rollback step, LR change, skip reason) for post-mortem analysis (sketched below). `NaNState.nan_steps` already captures NaN step indices but not recovery actions.
- `on_instability` training hook: `TrainingHook` in `kempnerforge/training/hooks.py` currently exposes `on_train_begin`, `on_step_end`, `on_eval_end`, `on_checkpoint_save`, and `on_train_end`. Add one more for instability so researchers can plug in custom recovery (sketched below).
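One possible shape for the rollback path from the first item. `trainer.last_good_checkpoint`, `trainer.load_checkpoint`, `trainer.rollback_count`, and `detector.reset` are assumed helpers, not existing KempnerForge API:

```python
def maybe_rollback(detector, trainer, max_retries: int = 3) -> bool:
    """If the detector trips, restore the last known-good checkpoint and
    resume instead of breaking out of the training loop."""
    if not detector.should_rollback:
        return False
    if trainer.rollback_count >= max_retries:
        # Give up rather than loop forever on a fundamentally unstable run.
        raise RuntimeError("exceeded rollback retries; stopping run")
    path = trainer.last_good_checkpoint()  # newest checkpoint older than the first NaN step
    trainer.load_checkpoint(path)          # restores model, optimizer, scheduler, step counter
    trainer.rollback_count += 1
    detector.reset()                       # clear consecutive-NaN state before retrying
    return True
```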
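A rollback alone often replays the same divergence, hence the LR cut in the second item. A sketch, assuming it runs after the regular scheduler writes `lr` into the param groups each step:

```python
class RollbackLRDamper:
    """Scale LR by `factor` right after a rollback, then ramp linearly back
    to the scheduler's value over `ramp_steps`. Hypothetical helper."""

    def __init__(self, optimizer, factor: float = 0.5, ramp_steps: int = 1000):
        self.optimizer = optimizer
        self.factor = factor
        self.ramp_steps = ramp_steps
        self.steps_left = 0

    def on_rollback(self) -> None:
        self.steps_left = self.ramp_steps

    def step(self) -> None:
        """Apply after scheduler.step(), which rewrites group['lr']."""
        if self.steps_left == 0:
            return
        frac = self.steps_left / self.ramp_steps  # 1.0 -> 0.0 over the ramp
        scale = 1.0 - (1.0 - self.factor) * frac  # factor -> 1.0
        for group in self.optimizer.param_groups:
            group["lr"] *= scale
        self.steps_left -= 1
```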
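If calling `check_gradients` every step is too expensive, the cheaper grad-norm spike check from the third item could look like this (a sketch; `grad_norm_spiked` is not an existing helper):

```python
import math
from collections import deque
from statistics import median

import torch


def grad_norm_spiked(model: torch.nn.Module, history: deque, k: float = 10.0) -> bool:
    """Flag a step whose global grad norm is non-finite or exceeds k times
    the recent median. `history` is a bounded deque owned by the caller,
    e.g. deque(maxlen=100)."""
    norms = [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]
    total = torch.linalg.vector_norm(torch.stack(norms)).item() if norms else 0.0
    spiked = not math.isfinite(total) or (
        len(history) == history.maxlen and total > k * median(history)
    )
    if math.isfinite(total):
        history.append(total)  # keep blown-up steps out of the baseline
    return spiked
```

This still launches one small kernel per parameter; if even that is too hot, reuse the total norm the gradient clipper already computes (`clip_grad_norm_` returns it).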
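The OOM item could wrap the step in a retry, as below. `run_step` is an assumed closure over forward/backward/optimizer; a real version would also rescale gradient accumulation so the effective batch size is preserved:

```python
import torch


def step_with_oom_fallback(run_step, batch, max_splits: int = 2):
    """Catch CUDA OOM and retry the step with a halved micro-batch,
    up to `max_splits` times."""
    for attempt in range(max_splits + 1):
        try:
            return run_step(batch)
        except torch.cuda.OutOfMemoryError:
            if attempt == max_splits:
                raise  # out of retries; surface the OOM
            torch.cuda.empty_cache()                  # release cached blocks before retrying
            batch = batch[: max(1, len(batch) // 2)]  # shrink the micro-batch for this step
```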
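The TOML surface from the policy item might map onto a small dataclass hung off `TrainConfig`. Field names below are suggestions, not existing config keys:

```python
from dataclasses import dataclass


@dataclass
class RecoveryConfig:
    """Candidate [recovery] table, e.g.:

        [recovery]
        nan_action = "skip"        # warn | skip | raise
        max_consecutive = 10
        rollback_steps = 500       # how far back a rollback may reach
        max_retries = 3
        lr_reduction_factor = 0.5
    """

    nan_action: str = "warn"   # current hardcoded default
    max_consecutive: int = 10  # current hardcoded default
    rollback_steps: int = 500
    max_retries: int = 3
    lr_reduction_factor: float = 0.5
```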
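For the structured event log, one lightweight shape (names hypothetical) that complements `NaNState.nan_steps` by recording responses, not just detections:

```python
import json
import time
from dataclasses import asdict, dataclass


@dataclass
class RecoveryEvent:
    """One structured record per recovery action."""

    step: int
    action: str   # "skip" | "rollback" | "lr_reduce" | "oom_retry"
    detail: str   # e.g. checkpoint path, new LR, retry count
    timestamp: float = 0.0


def log_recovery(event: RecoveryEvent, path: str = "recovery_events.jsonl") -> None:
    """Append as JSON lines so post-mortems can grep and parse the sequence."""
    event.timestamp = event.timestamp or time.time()
    with open(path, "a") as f:
        f.write(json.dumps(asdict(event)) + "\n")
```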
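The new hook from the last item might look like the following; only `on_instability` is new, and its signature is a suggestion:

```python
class TrainingHook:
    """Existing hooks (on_train_begin, on_step_end, on_eval_end,
    on_checkpoint_save, on_train_end) elided; only the proposed
    addition is shown."""

    def on_instability(self, step: int, kind: str, state) -> None:
        """Called when instability is detected (kind: "nan_loss",
        "grad_spike", "oom", ...) before any built-in recovery runs,
        so a custom hook can checkpoint, alert, or override the
        recovery policy."""
```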
Priority
Low for now. Detect-and-stop is safe and predictable; automatic recovery adds complexity and a risk of silently training on bad state. Revisit when multi-day jobs make the cost of operator intervention too high.