[support] YaRN + Gemma 2 27B finetuning: training instability at shorter context lengths and seq_length configuration confusion #3063

@yaoyu-33

Description

Problem

A user is integrating YaRN (Yet Another RoPE extensioN) with Gemma 2 27B to support a 32k context length during finetuning. They have manually added YaRN support to the model and also created a Hugging Face conversion path. Training runs and the loss is initially stable, but it becomes unstable after some number of training steps.

Key observations:

  1. Shorter context lengths cause instability — training at context lengths below 32k (e.g., 4k) is unstable, potentially indicating a bug in the configuration propagation.
  2. Longer context (32k) appears to work — loss is more stable with the full 32k context window.
  3. Confusion about where sequence length is configured — it is unclear whether YaRN is correctly recognizing/respecting the desired context length in the NeMo Megatron Bridge setup.

Environment

  • NeMo 25.11 container (for Megatron Bridge path)
  • NeMo 25.09 container with additional libraries installed (for HF path)
  • Megatron script configured for 32k tokens (same configuration previously used for 4k, with only seq_length adjusted)

Analysis — Potential Configuration Issues

1. seq_length propagation in Bridge

In Megatron Bridge, the sequence length flows through multiple layers:

  • GPTModelProvider.seq_length (default: 1024) → passed as max_sequence_length to MCoreGPTModel
  • Gemma2ModelProvider.seq_length inherits from GPTModelProvider and overrides the default to 8192
  • When using YaRN, yarn_original_max_position_embeddings defines the original (pre-extension) context window

The user must ensure seq_length on the provider is set to the target extended length (32k), not the original 8k. If using a recipe, the seq_length override must propagate correctly.
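To make the propagation chain concrete, here is a minimal, self-contained sketch. The class names mirror the providers described above, but the dataclasses and the build function are illustrative stand-ins, not the real Bridge classes; only the defaults (1024 / 8192) are taken from the issue text.

```python
from dataclasses import dataclass

# Illustrative stand-ins for the Bridge provider classes; defaults
# mirror the values quoted in this issue (1024 for the GPT base,
# 8192 for the Gemma 2 override). Not the real implementations.
@dataclass
class GPTModelProvider:
    seq_length: int = 1024

@dataclass
class Gemma2ModelProvider(GPTModelProvider):
    seq_length: int = 8192
    yarn_original_max_position_embeddings: int = 8192

def build_mcore_kwargs(provider: GPTModelProvider) -> dict:
    # The provider's seq_length is what reaches MCoreGPTModel as
    # max_sequence_length -- this is the value that must be 32k.
    return {"max_sequence_length": provider.seq_length}

provider = Gemma2ModelProvider()   # default is the pre-extension 8192
provider.seq_length = 32768        # must be the extended target length
kwargs = build_mcore_kwargs(provider)
assert kwargs["max_sequence_length"] == 32768
```

If the override is dropped anywhere along this chain, the model silently builds with the 8k (or even 1k) default while YaRN still scales positions for 32k.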

2. Gemma 2 Bridge and YaRN auto-detection

The base MegatronModelBridge.hf_config_to_provider_kwargs() method (model_bridge.py) does handle rope_scaling from HF config:

  • Detects rope_scaling.type == "yarn" and sets position_embedding_type = "yarn"
  • Maps factor, original_max_position_embeddings, beta_fast, beta_slow, mscale, mscale_all_dim to yarn_* provider fields

However, Gemma2Bridge.provider_bridge() calls super().provider_bridge() and then overwrites several provider fields (normalization, activation, etc.). The YaRN fields should survive since they are not explicitly overwritten, but this needs verification — especially for a manually-modified HF config where rope_scaling was added post-hoc.
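The detection logic described above can be sketched as follows. This is an illustrative reconstruction of what hf_config_to_provider_kwargs() is described as doing, not the actual code in model_bridge.py; the yarn_* field-naming convention is taken from the issue text.

```python
def rope_scaling_to_provider_kwargs(rope_scaling: dict) -> dict:
    """Map an HF-style rope_scaling block to yarn_* provider fields.

    Sketch of the detection behavior described above; the real
    hf_config_to_provider_kwargs() in model_bridge.py may differ.
    """
    if rope_scaling.get("type") != "yarn":
        return {}
    kwargs = {"position_embedding_type": "yarn"}
    # The six HF keys named in the issue, prefixed to provider fields.
    for hf_key in ("factor", "original_max_position_embeddings",
                   "beta_fast", "beta_slow", "mscale", "mscale_all_dim"):
        if hf_key in rope_scaling:
            kwargs[f"yarn_{hf_key}"] = rope_scaling[hf_key]
    return kwargs

# A manually added rope_scaling block like the one in this issue:
hf_rope_scaling = {"type": "yarn", "factor": 4,
                   "original_max_position_embeddings": 8192}
provider_kwargs = rope_scaling_to_provider_kwargs(hf_rope_scaling)
```

A useful verification step is to dump the resolved provider after Gemma2Bridge.provider_bridge() runs and confirm these yarn_* fields survived the field overwrites.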

3. Gemma2DotProductAttention limitations

The Gemma 2 provider uses a custom unfused dot-product attention (Gemma2DotProductAttention) to support attn_logit_softcapping. This attention implementation:

  • Does not support context parallelism (asserts context_parallel_size == 1)
  • Does not support packed sequences
  • Uses a sliding window (4096 tokens) on alternating layers

For 32k context, the sliding window interaction with YaRN's extended RoPE may need special handling. The alternating global/local attention pattern could also interact unexpectedly with longer sequences.
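To make the local/global interaction concrete, a small illustrative sketch. The 4096-token window and the alternating pattern come from the description above; which parity of layer gets the sliding window is an assumption and may be reversed in the actual implementation.

```python
WINDOW = 4096  # sliding-window size quoted above

def attended_span(layer_idx: int, seq_length: int) -> int:
    """Tokens visible to one query position on a given layer.

    Assumption: even-indexed layers use the 4096-token sliding window
    and odd-indexed layers attend globally (the convention may be
    reversed in the real Gemma2DotProductAttention).
    """
    if layer_idx % 2 == 0:
        return min(WINDOW, seq_length)
    return seq_length

# At 32k, half the layers never see positions beyond a 4k window, so
# YaRN's extended RoPE positions matter mostly on the global layers.
assert attended_span(0, 32768) == 4096
assert attended_span(1, 32768) == 32768
```

Note that at seq_length=4096 the local and global layers behave identically, so any instability specific to 4k points away from the window itself and back toward YaRN's position scaling.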

4. seq_length vs recipe configuration

When using run_recipe.py, the seq_length is typically set via:

model.seq_length=32768

But users coming from raw Megatron-LM scripts may set --seq-length in the arg parser (MLM compat path), which maps differently. The user should verify which path they are using and ensure the value reaches GPTModelProvider.seq_length.
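The two configuration paths can be contrasted side by side. The exact CLI shape is an assumption (only model.seq_length=32768 and --seq-length appear in the issue); the elided arguments are placeholders.

```shell
# Bridge recipe path (dotted override, as described in this issue):
#   python run_recipe.py ... model.seq_length=32768
#
# Raw Megatron-LM compat path (argparse flag, maps differently):
#   ... --seq-length 32768
#
# Whichever path is used, confirm the value actually lands on
# GPTModelProvider.seq_length before the model is built, e.g. by
# logging the resolved provider config at startup.
```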

Additional Issue: Data Format

The user also notes a data format mismatch:

  • HF finetuning script expects prompt / completion format (e.g., cjvt/GaMS-Nemotron-Chat dataset)
  • Megatron Bridge SFT expects the standard messages format (list of {"role": ..., "content": ...} dicts)

Users migrating from HF finetuning need to convert their datasets to the messages format.
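A minimal conversion sketch for the single-turn case. The field names "prompt" and "completion" match the HF-side format named above; real datasets (e.g. cjvt/GaMS-Nemotron-Chat) may use different keys or multi-turn structure, so treat this as a starting point only.

```python
def to_messages(example: dict) -> dict:
    """Convert one prompt/completion example to the messages format
    expected by Megatron Bridge SFT.

    Sketch: assumes single-turn data with "prompt" and "completion"
    keys; adapt the key names to the actual dataset schema.
    """
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["completion"]},
        ]
    }

converted = to_messages({"prompt": "Hello?", "completion": "Hi!"})
```

Applying this over a JSONL dataset (one example per line) before launching Bridge SFT avoids the format mismatch.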

Requested Actions

  1. Document where seq_length is defined and how it propagates through the YaRN + Bridge workflow
  2. Investigate whether there is a known issue with YaRN at shorter context lengths (particularly when seq_length < yarn_original_max_position_embeddings)
  3. Review the Gemma 2 provider's compatibility with YaRN — specifically the Gemma2DotProductAttention sliding window + softcapping interaction with extended RoPE
  4. Clarify whether Gemma 2 + YaRN is a tested/supported combination in Bridge

Reproduction

  1. Use Gemma 2 27B with a modified HF config that adds rope_scaling with type: "yarn" and factor: 4 (to extend 8k → 32k)
  2. Set seq_length=32768 in the Bridge recipe
  3. Run finetuning — observe stable loss initially, then instability
  4. Reduce seq_length to 4096 — observe worse instability
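Step 1 above amounts to patching the HF config with a block like the following. The key names follow the general HF rope_scaling convention; whether a manually patched Gemma 2 config accepts exactly these keys is an assumption.

```python
import json

# rope_scaling block added to the Gemma 2 27B HF config in step 1.
# Keys follow the HF rope_scaling convention (assumed, since Gemma 2
# does not ship with YaRN support out of the box).
rope_scaling = {
    "type": "yarn",
    "factor": 4,                              # 8192 * 4 = 32768
    "original_max_position_embeddings": 8192,
}
assert (rope_scaling["factor"]
        * rope_scaling["original_max_position_embeddings"]) == 32768
print(json.dumps(rope_scaling, indent=2))
```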

Affected area

area:model, area:training

Metadata

    Labels

    needs-triage (New item needs classification and ownership), support
