Automatic recovery from training instability #2

@mmshad

Description

Context

KempnerForge detects training instability (NaN/Inf loss) but only halts the run after a consecutive-NaN threshold. There is no automatic recovery path, so every instability event requires manual intervention. For long-running production jobs that's expensive.

Current behavior:

  • NaNDetector.check_loss in kempnerforge/resilience/health.py syncs a NaN flag across ranks via a small all_reduce before returning.
  • On NaN, the training loop in scripts/train.py zeros grads, skips the optimizer step, and continues. If consecutive_nans >= max_consecutive, it logs "Too many consecutive NaNs — stopping" and breaks out of the loop.
  • The detector is constructed in scripts/train.py with hardcoded action="warn" and max_consecutive=10; neither value is exposed via TOML.
  • NaNDetector.should_rollback is a property flag, not an automated action; the loop just stops when it trips.
  • NaNDetector.check_gradients helper exists but is not called from the training loop (noted in docs/resilience/nan-detection.md).
  • NCCL health check defaults to disabled: nccl_health_check_interval: int = 0 in TrainConfig.
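For context, the cross-rank flag sync that check_loss performs can be sketched roughly as follows. This is an assumed shape, not NaNDetector's actual implementation; the function name and signature here are illustrative.

```python
import torch
import torch.distributed as dist

def loss_is_bad_on_any_rank(loss: torch.Tensor) -> bool:
    """Return True if the loss is NaN/Inf on any rank.

    Hypothetical sketch of the flag sync described above; the real
    NaNDetector.check_loss API may differ.
    """
    flag = torch.tensor(
        [0.0 if torch.isfinite(loss).all() else 1.0],
        device=loss.device,
    )
    if dist.is_available() and dist.is_initialized():
        # One small all_reduce keeps every rank's skip/stop decision
        # consistent, so ranks never diverge on whether to step.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

The MAX reduction means a single bad rank poisons the step everywhere, which is the desired behavior: all ranks must skip together or collectives deadlock.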

Items to consider

  • Automatic checkpoint rollback: On should_rollback, drop to the last known-good checkpoint (or N steps back) and resume instead of breaking.
  • LR reduction on rollback: Temporarily reduce LR (e.g., ×0.5) after rollback and ramp back over N steps.
  • Gradient anomaly detection in the hot path: Wire NaNDetector.check_gradients (or a cheaper grad-norm spike check) into the training loop so explosions are caught before they reach loss.
  • OOM recovery with batch size reduction: Catch CUDA OOM, reduce micro-batch size or grad-accum for one step, retry.
  • Configurable recovery policy via TOML: Expose NaN action (warn/skip/raise), max_consecutive, rollback depth, max retries, LR reduction factor.
  • NCCL health check default: Consider a non-zero default for nccl_health_check_interval so production runs detect hung collectives without opt-in.
  • Structured recovery event log: Log each recovery action (rollback step, LR change, skip reason) for post-mortem. NaNState.nan_steps already captures NaN step indices but not recovery actions.
  • on_instability training hook: TrainingHook in kempnerforge/training/hooks.py currently exposes on_train_begin, on_step_end, on_eval_end, on_checkpoint_save, and on_train_end. Add one more for instability so researchers can plug in custom recovery.
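The rollback-plus-LR-ramp idea could be captured in a small policy object like the sketch below. All names here (RollbackPolicy, lr_scale) are hypothetical, not existing KempnerForge APIs; the checkpoint load itself is out of scope.

```python
from dataclasses import dataclass

@dataclass
class RollbackPolicy:
    """Hypothetical policy for post-rollback LR reduction and ramp-up."""
    lr_factor: float = 0.5   # multiply the base LR by this right after rollback
    ramp_steps: int = 100    # linearly ramp back to the base LR over N steps
    max_rollbacks: int = 3   # give up (stop the run) after this many rollbacks

    def lr_scale(self, steps_since_rollback: int) -> float:
        """Scale factor to apply to the base LR while recovering."""
        if steps_since_rollback >= self.ramp_steps:
            return 1.0
        frac = steps_since_rollback / self.ramp_steps
        # Linear interpolation from lr_factor back up to 1.0.
        return self.lr_factor + (1.0 - self.lr_factor) * frac
```

The training loop would multiply each optimizer group's LR by lr_scale(step - rollback_step) after resuming from the known-good checkpoint.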
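The OOM-recovery item might look like the retry wrapper below. run_step and halve_batch are hypothetical callables standing in for the real step function and micro-batch splitter; torch.cuda.OutOfMemoryError requires PyTorch 1.13+.

```python
import torch

def step_with_oom_retry(run_step, batch, halve_batch, max_retries: int = 2):
    """Run one training step; on CUDA OOM, shrink the batch and retry.

    Illustrative sketch: run_step(batch) executes forward/backward/step,
    halve_batch(batch) returns a smaller batch (e.g. half the micro-batch,
    with grad-accum adjusted to keep the effective batch size).
    """
    for attempt in range(max_retries + 1):
        try:
            return run_step(batch)
        except torch.cuda.OutOfMemoryError:
            if attempt == max_retries:
                raise  # persistent OOM: surface it rather than loop forever
            torch.cuda.empty_cache()  # release cached blocks before retrying
            batch = halve_batch(batch)
```

A real implementation would also log a structured recovery event per retry, per the recovery-event-log item above.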
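A TOML surface for the recovery policy might look like this. The table and key names are illustrative only; they are not existing TrainConfig fields.

```toml
# Hypothetical [resilience] table; key names are illustrative, not
# KempnerForge's actual config schema.
[resilience]
nan_action = "skip"               # one of: warn | skip | raise
max_consecutive_nans = 10         # stop/rollback threshold
rollback_depth_steps = 500        # roll back at most this many steps
max_recovery_retries = 3          # give up after this many rollbacks
lr_reduction_factor = 0.5         # LR multiplier applied after rollback
nccl_health_check_interval = 300  # seconds; 0 disables the check
```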
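The on_instability hook could follow the shape below. The event payload and the string-action return contract are assumptions; the existing TrainingHook callbacks are elided.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class InstabilityEvent:
    """Hypothetical payload passed to on_instability."""
    step: int
    kind: str            # e.g. "nan_loss", "grad_spike", "oom"
    consecutive: int = 1  # consecutive occurrences of this event kind

class TrainingHook:
    # Existing callbacks (on_train_begin, on_step_end, on_eval_end,
    # on_checkpoint_save, on_train_end) elided for brevity.
    def on_instability(self, event: InstabilityEvent) -> Optional[str]:
        """Return a recovery action ("skip", "rollback", "stop"),
        or None to defer to the default policy."""
        return None

class StopAfterThreeHook(TrainingHook):
    """Example custom policy: tolerate two NaN steps, then stop."""
    def on_instability(self, event: InstabilityEvent) -> Optional[str]:
        return "stop" if event.consecutive >= 3 else "skip"
```

Returning None rather than an action keeps custom hooks composable with the framework's built-in policy.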

Priority

Low for now. Detect-and-stop is safe and predictable; automatic recovery adds complexity and risk of silently training on bad state. Revisit when multi-day jobs make operator intervention cost too high.
