Nemo Mbridge Resiliency

Resiliency features in Megatron Bridge, including fault tolerance, straggler detection, in-process restart, preemption, and the re-run state machine.

Published by @NVIDIA

Resiliency

Stable docs: @docs/training/resiliency.md, @docs/training/checkpointing.md
Card: @skills/nemo-mbridge-resiliency/card.yaml

Enablement

Fault tolerance (Slurm only)

Option 1: NeMo Run plugin (recommended)
from megatron.bridge.recipes.run_plugins import FaultTolerancePlugin
import nemo_run as run

task = run.Script(...)
run_plugins = [
    FaultTolerancePlugin(
        enable_ft_package=True,
        calc_ft_timeouts=True,
        num_in_job_restarts=3,
        num_job_retries_on_failure=2,
        initial_rank_heartbeat_timeout=1800,
        rank_heartbeat_timeout=300,
    )
]
# `executor` must be a NeMo Run executor defined elsewhere (for example, a Slurm executor).
run.run(task, plugins=run_plugins, executor=executor)

| Plugin parameter | Default | Description |
| --- | --- | --- |
| num_in_job_restarts | 3 | Max restarts within the same job |
| num_job_retries_on_failure | 2 | Max new job launches on failure |
| initial_rank_heartbeat_timeout | 1800 | First heartbeat timeout (seconds) |
| rank_heartbeat_timeout | 300 | Subsequent heartbeat timeout (seconds) |

Option 2: Direct config + ft_launcher
from megatron.bridge.training.config import FaultToleranceConfig

cfg.ft = FaultToleranceConfig(
    enable_ft_package=True,
    calc_ft_timeouts=True,
    simulate_fault=False,
    simulated_fault_type="random",
)

Launch with ft_launcher (not torchrun):

export GROUP_RANK=0  # required for non-Slurm
ft_launcher \
    --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    --nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
    --ft-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
    --ft-rank_out_of_section_timeout=300 \
    your_training_script.py

| Config parameter | Default | Description |
| --- | --- | --- |
| enable_ft_package | False | Enable fault tolerance |
| calc_ft_timeouts | False | Auto-compute optimal timeouts |
| simulate_fault | False | Enable fault simulation for testing |
| simulated_fault_type | "random" | "rank_hung", "rank_killed", or "random" |
| simulated_fault_rank | None | Specific rank to fault (random if None) |
| simulated_fault_base_delay | 0 | Base delay before simulating fault |

Section-based timeout monitoring covers setup, training steps, checkpointing, and out-of-section time independently. Timeouts are saved to ft_state.json for subsequent runs when calc_ft_timeouts=True.
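
To make the section model concrete, here is a minimal sketch of how the `--ft-rank_section_timeouts` value from the launch command could be parsed and applied per section. `parse_section_timeouts` and `timed_out` are hypothetical helpers written for illustration; they are not part of ft_launcher or Megatron Bridge, and the real monitor may behave differently.

```python
from typing import Optional


def parse_section_timeouts(spec: str) -> dict[str, float]:
    """Parse 'setup:600,step:180' into {'setup': 600.0, 'step': 180.0}."""
    timeouts: dict[str, float] = {}
    for item in spec.split(","):
        name, _, seconds = item.partition(":")
        if not name or not seconds:
            raise ValueError(f"malformed section timeout: {item!r}")
        timeouts[name.strip()] = float(seconds)
    return timeouts


def timed_out(
    section: Optional[str],
    elapsed: float,
    section_timeouts: dict[str, float],
    out_of_section_timeout: float,
) -> bool:
    """Return True if `elapsed` seconds exceeds the limit that applies.

    `section=None` means the rank is between sections, so the
    out-of-section timeout is used instead of a per-section one.
    """
    if section is None:
        limit = out_of_section_timeout
    else:
        limit = section_timeouts[section]
    return elapsed > limit


# Values copied from the ft_launcher command above.
timeouts = parse_section_timeouts("setup:600,step:180,checkpointing:420")
```

With these values, a rank that spends 200 seconds inside a training step exceeds the 180-second step limit, while 200 seconds spent between sections stays under the 300-second out-of-section limit.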

NVRx straggler detection

from megatron.bridge.training.config import NVRxStragglerDetectionConfig

cfg.nvrx_straggler = NVRxStragglerDetectionConfig(
    enabled=True,
    report_time_interval=300.0,
    calc_relative_gpu_perf=True,
    calc_individual_gpu_perf=True,
    num_gpu_perf_scores_to_print=5,
    gpu_relative_perf_threshold=0.7,
    gpu_individual_perf_threshold=0.7,
    stop_if_detected=False,
    enable_logging=True,
)

| Parameter | Default | Description |
| --- | --- | --- |
| enabled | False | Enable straggler detection |
| report_time_interval | 300.0 | Seconds between straggler checks |
| calc_relative_gpu_perf | True | Compare ranks against each other |
| calc_individual_gpu_perf | True | Track per-rank degradation over time |
| gpu_relative_perf_threshold | 0.7 | Threshold for relative performance (0-1) |
| gpu_individual_perf_threshold | 0.7 | Threshold for individual performance (0-1) |
| stop_if_detected | False | Terminate training on straggler |
| num_gpu_perf_scores_to_print | 5 | Number of best/worst scores to print |
| profiling_interval | 1 | Profiling interval for detector |
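
The two thresholds are easier to reason about with numbers. The sketch below assumes a relative score is the fastest rank's step time divided by this rank's step time, and an individual score is the rank's best step time so far divided by its current step time. Both formulas are assumptions made for illustration; NVRx's actual scoring may differ, and all function names here are hypothetical.

```python
def relative_perf_scores(step_times: dict[int, float]) -> dict[int, float]:
    """Score each rank against the fastest rank (1.0 = as fast as the best)."""
    best = min(step_times.values())
    return {rank: best / t for rank, t in step_times.items()}


def individual_perf_score(current_time: float, best_time_so_far: float) -> float:
    """Score a rank against its own best step time, capped at 1.0."""
    return min(1.0, best_time_so_far / current_time)


def find_stragglers(scores: dict[int, float], threshold: float = 0.7) -> list[int]:
    """Return the ranks whose score falls below the threshold."""
    return sorted(rank for rank, score in scores.items() if score < threshold)


# Hypothetical per-rank step times in seconds.
step_times = {0: 1.00, 1: 1.05, 2: 1.60, 3: 0.98}
scores = relative_perf_scores(step_times)
stragglers = find_stragglers(scores)  # threshold 0.7, as in the table above
```

Under these assumptions rank 2 scores about 0.61 and is the only rank below the 0.7 threshold. A lower threshold tolerates more skew before a rank is reported.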

Preemption

Plugin (Slurm)
from megatron.bridge.recipes.run_plugins import PreemptionPlugin

plugins = [
    PreemptionPlugin(
        preempt_time=60,
        enable_exit_handler=True,
        enable_exit_handler_for_data_loader=False,
    )
]

| Plugin parameter | Default | Description |
| --- | --- | --- |
| preempt_time | 60 | Seconds before job limit to send signal |
| enable_exit_handler | True | Enable signal handler in training |
| enable_exit_handler_for_data_loader | False | Enable for dataloader workers |

Direct config
import signal
cfg.train.exit_signal_handler = True
cfg.train.exit_signal = signal.SIGTERM
cfg.train.exit_signal_handler_for_dataloader = False
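
The general pattern behind an exit signal handler can be sketched in plain Python: the handler only records that the signal arrived, and the training loop checks the flag at a step boundary, saves a checkpoint, and stops. `ExitSignalFlag` and `train` are hypothetical names for illustration; this is not the code in sig_utils.py.

```python
import signal


class ExitSignalFlag:
    """Records that an exit signal arrived so the loop can stop cleanly."""

    def __init__(self, signum: int = signal.SIGTERM):
        self.signum = signum
        self.received = False

    def install(self) -> None:
        # signal.signal() must be called from the main thread.
        signal.signal(self.signum, self.handle)

    def handle(self, signum, frame) -> None:
        # Keep the handler trivial: set a flag and return.
        self.received = True


def train(num_steps, flag, run_step, save_checkpoint) -> int:
    """Run up to num_steps; checkpoint and stop early once the flag is set.

    Returns the number of steps that completed.
    """
    for step in range(1, num_steps + 1):
        run_step(step)
        if flag.received:
            save_checkpoint(step)
            return step
    return num_steps
```

Checking the flag only between steps means the checkpoint is never taken in the middle of an optimizer update, at the cost of waiting for the current step to finish.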

Re-run state machine (experimental)

from megatron.bridge.training.config import RerunStateMachineConfig

cfg.rerun_state_machine = RerunStateMachineConfig(
    rerun_mode="validate_results",
    check_for_nan_in_loss=True,
    check_for_spiky_loss=False,
    spiky_loss_factor=10.0,
)

| Parameter | Default | Description |
| --- | --- | --- |
| rerun_mode | "disabled" | "disabled", "validate_results", "report_determinism_stats" |
| check_for_nan_in_loss | True | Check for NaN in loss |
| check_for_spiky_loss | False | Check for unexpectedly large loss |
| spiky_loss_factor | 10.0 | Loss flagged if > factor * max observed (increase for large models) |
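
The two loss checks reduce to a small predicate. The sketch below follows the rule stated for spiky_loss_factor (a loss is flagged when it exceeds the factor times the largest loss observed so far); `loss_is_suspicious` is a hypothetical helper, not the state machine's own implementation.

```python
import math
from typing import Optional


def loss_is_suspicious(
    loss: float,
    max_observed: Optional[float],
    spiky_loss_factor: float = 10.0,
    check_for_nan: bool = True,
    check_for_spiky: bool = False,
) -> bool:
    """Return True if the loss should trigger validation."""
    if check_for_nan and math.isnan(loss):
        return True
    # Without any history there is nothing to compare a spike against.
    if check_for_spiky and max_observed is not None:
        return loss > spiky_loss_factor * max_observed
    return False
```

For example, with a largest observed loss of 2.0 and the default factor of 10.0, a loss of 15.0 passes and a loss of 50.0 is flagged, but only when the spiky-loss check is enabled.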

Exit codes: 16 means the job should be resumed so the state machine can disambiguate the failure; 17 means result validation failed.
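
A job wrapper could branch on these exit codes as sketched below. The codes 16 and 17 come from the text above; the constant names, the `next_action` helper, and the actions chosen for each code are assumptions for illustration only.

```python
EXIT_RESUME_TO_DISAMBIGUATE = 16
EXIT_FAILED_VALIDATION = 17


def next_action(returncode: int) -> str:
    """Map a training process exit code to a follow-up action."""
    if returncode == 0:
        return "done"
    if returncode == EXIT_RESUME_TO_DISAMBIGUATE:
        # Relaunch from the checkpoint so the re-run can complete.
        return "resume"
    if returncode == EXIT_FAILED_VALIDATION:
        # Assumed policy: stop and inspect rather than retry blindly.
        return "stop"
    return "investigate"
```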

In-process restart (experimental)

from megatron.bridge.training.config import InProcessRestartConfig

cfg.inprocess_restart = InProcessRestartConfig(
    enabled=True,
    granularity="node",
    soft_timeout=60.0,
    hard_timeout=90.0,
)

| Parameter | Default | Description |
| --- | --- | --- |
| enabled | False | Enable in-process restart |
| active_world_size | None | Ranks executing workload (rest are warm reserves) |
| granularity | "node" | "node" or "rank" restart granularity |
| max_iterations | None | Max restart attempts (None = unlimited) |
| soft_timeout | 60.0 | Detect GIL-released hangs (seconds) |
| hard_timeout | 90.0 | Force-terminate hung ranks (seconds) |
| heartbeat_interval | 30.0 | Heartbeat interval (seconds) |
| heartbeat_timeout | 60.0 | Missing heartbeat timeout (seconds) |
| barrier_timeout | 120.0 | Distributed barrier timeout (seconds) |
| completion_timeout | 120.0 | Completion barrier timeout (seconds) |
| empty_cuda_cache | True | Clear CUDA cache during restart |
| max_rank_faults | None | Max rank faults before terminating |
| monitor_process_logdir | None | Directory for monitor logs |

Required environment variables:

export TORCH_CPP_LOG_LEVEL=error
export TORCH_NCCL_RETHROW_CUDA_ERRORS=0
export NCCL_NVLS_ENABLE=0

The PyTorch NCCL watchdog timeout must exceed hard_timeout. NeMo-Run's Slurm Executor is not supported; launch directly with srun --kill-on-bad-exit=0.
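
These requirements can be checked before launch. The following pre-flight sketch verifies the three environment variables and the watchdog rule; `check_inprocess_restart` is a hypothetical helper, and the NCCL watchdog timeout is passed in as a plain number because how it is configured depends on your setup.

```python
REQUIRED_ENV = {
    "TORCH_CPP_LOG_LEVEL": "error",
    "TORCH_NCCL_RETHROW_CUDA_ERRORS": "0",
    "NCCL_NVLS_ENABLE": "0",
}


def check_inprocess_restart(environ, hard_timeout, nccl_watchdog_timeout):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    for name, expected in REQUIRED_ENV.items():
        actual = environ.get(name)
        if actual != expected:
            problems.append(f"{name} should be {expected!r}, got {actual!r}")
    # The watchdog must outlast hard_timeout so recovery gets a chance to run.
    if nccl_watchdog_timeout <= hard_timeout:
        problems.append(
            f"NCCL watchdog timeout ({nccl_watchdog_timeout}s) must exceed "
            f"hard_timeout ({hard_timeout}s)"
        )
    return problems
```

A typical call would pass `os.environ` as the first argument and fail the launch script if the returned list is not empty.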

Async checkpoint save

cfg.checkpoint.async_save = True
cfg.checkpoint.ckpt_format = "torch_dist"
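
Because the two settings must agree, a small guard run before training can turn a mismatch into an explicit error. `validate_async_save` is a hypothetical helper for illustration and is not part of Megatron Bridge.

```python
def validate_async_save(async_save: bool, ckpt_format: str) -> None:
    """Raise ValueError if async save is requested with an unsupported format."""
    if async_save and ckpt_format != "torch_dist":
        raise ValueError(
            f"async_save=True requires ckpt_format='torch_dist', got {ckpt_format!r}"
        )
```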

Local checkpointing (NVRx)

cfg.checkpoint.non_persistent_local_ckpt_dir = "/local/scratch/ckpt"
cfg.checkpoint.non_persistent_local_ckpt_algo = "fully_parallel"

Code Anchors

Fault tolerance

  • Config: src/megatron/bridge/training/config.py (FaultToleranceConfig)
  • Runtime: src/megatron/bridge/training/fault_tolerance.py
  • Plugin: src/megatron/bridge/recipes/run_plugins.py (FaultTolerancePlugin)
  • Perf plugin: scripts/performance/nemo-mbridge-resiliency_plugins.py
  • Tests: tests/unit_tests/training/test_fault_tolerance.py
  • Example: examples/training_features/nemo-mbridge-resiliency/fault_tolerance/

Straggler detection

  • Config: src/megatron/bridge/training/config.py (NVRxStragglerDetectionConfig)
  • Runtime: src/megatron/bridge/training/nvrx_straggler.py
  • Train loop: src/megatron/bridge/training/train.py (check_nvrx_straggler_detection)
  • Tests: tests/unit_tests/training/test_nvrx_straggler.py, tests/functional_tests/training/test_nvrx_straggler.py
  • Example: examples/training_features/nemo-mbridge-resiliency/straggler_detection/

In-process restart

  • Config: src/megatron/bridge/training/config.py (InProcessRestartConfig)
  • Runtime: src/megatron/bridge/training/inprocess_restart.py
  • Entry point: src/megatron/bridge/training/pretrain.py (maybe_wrap_for_inprocess_restart)
  • Tests: tests/unit_tests/training/test_inprocess_restart.py, tests/functional_tests/training/test_inprocess_restart.py

Preemption

  • Plugin: src/megatron/bridge/recipes/run_plugins.py (PreemptionPlugin)
  • Signal handler: src/megatron/bridge/training/utils/sig_utils.py
  • Tests: tests/unit_tests/recipes/test_run_plugins.py

Re-run state machine

  • Config: src/megatron/bridge/training/config.py (RerunStateMachineConfig)
  • Init: src/megatron/bridge/training/initialize.py (init_rerun_state)

Checkpointing

  • Async save: src/megatron/bridge/training/checkpointing.py (schedule_async_save)
  • Local ckpt: src/megatron/bridge/training/checkpointing.py (LocalCheckpointManager)
  • Tests: tests/functional_tests/training/test_local_checkpointing.py

Pitfalls

  1. ft_launcher, not torchrun: Direct FaultToleranceConfig requires ft_launcher. Using torchrun silently disables FT. For non-Slurm, set GROUP_RANK=0.

  2. Async save requires torch_dist: async_save=True works only with ckpt_format="torch_dist". With any other format the save either fails silently or raises an error.

  3. IPR + NeMo-Run: In-process restart (IPR) is not compatible with NeMo-Run or Slurm preemption plugins. It also requires specific PyTorch/NCCL versions and the environment variables listed under "In-process restart".

  4. NVRx vs legacy straggler: Two detectors exist. Use NVRx (nvrx_straggler); do not enable both.

  5. stop_if_detected default: NVRx logs but does not stop training by default. Set stop_if_detected=True for automatic termination.

  6. NCCL watchdog vs hard_timeout: For IPR, the NCCL watchdog timeout must exceed hard_timeout; otherwise PyTorch kills the process before recovery can run.

  7. Rerun state machine is alpha: Use check_for_nan_in_loss=True for NaN detection, but don't rely on full rerun workflows yet.

Verification

Fault tolerance

./examples/training_features/nemo-mbridge-resiliency/fault_tolerance/run_fault_tolerance.sh
./examples/training_features/nemo-mbridge-resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault

Look for [FaultTolerance] / [RankMonitorServer] log lines that report section timeouts. With --simulate-fault, the simulated fault should trigger a restart from checkpoint.

Straggler detection

uv run python -m torch.distributed.run --nproc_per_node=2 \
    examples/training_features/nemo-mbridge-resiliency/straggler_detection/straggler_detection_example.py

Look for GPU relative performance and GPU individual performance reports with per-rank scores.

Async checkpoint

Look for Scheduling async checkpoint save in logs. Training iterations should continue while checkpoint files are being written.
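
The behavior to expect (training continues while files are written) can be illustrated with a simplified background writer. This sketch uses a thread and JSON purely for illustration; it is not how the torch_dist format or schedule_async_save works internally, and `BackgroundSaver` is a hypothetical name.

```python
import copy
import json
import threading
from pathlib import Path


class BackgroundSaver:
    """Writes a snapshot of the state on a worker thread."""

    def __init__(self):
        self._thread = None

    def schedule(self, state: dict, path: Path) -> None:
        # Allow at most one save in flight.
        self.wait()
        # Copy first so later training updates do not leak into the file.
        snapshot = copy.deepcopy(state)
        self._thread = threading.Thread(target=self._write, args=(snapshot, path))
        self._thread.start()

    @staticmethod
    def _write(snapshot: dict, path: Path) -> None:
        path.write_text(json.dumps(snapshot))

    def wait(self) -> None:
        """Block until the pending save, if any, has finished."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
```

The caller keeps mutating `state` right after `schedule` returns, and calls `wait` before exiting so the last checkpoint is complete on disk.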

In-process restart

pytest tests/functional_tests/training/test_inprocess_restart.py -v

Requires compatible PyTorch/NCCL versions.

