Validation breaking with FlashAttention #2140

@francesco-bertolotti

Bug description

Enabling validation while FlashAttention is in use leads to a Dynamo error:

    File "torchtitan/components/validate.py", line 158, in validate
      predictions = model_parts[0](inputs)
    File "torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
      return inner()
    File "venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1830, in inner
      result = forward_call(*args, **kwargs)
    File "torchtitan/models/qwen3/model/model.py", line 536, in forward
      h = layer(h, self.rope_cache, attention_masks)
    File "venv/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 441, in __call__
      return super().__call__(*args, **kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
      return inner()
    File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1830, in inner
      result = forward_call(*args, **kwargs)
    File ".venv/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 939, in compile_wrapper
      raise e.with_traceback(None) from e.__cause__  # User compiler error
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  torch._dynamo.exc.Unsupported: Observed exception
    Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
    Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
    Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

    Developer debug context: raised exception AssertionError([ConstantVariable(str: 'None')])

I traced the bug to the validation step: it does not pass attention_masks to the model, whereas the training step does. To confirm, I generated the attention_masks the same way the trainer does and passed them to the model’s forward pass during validation; with that change, validation ran without the error. I’ll prepare a PR shortly.
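The failure mode and the fix can be illustrated with a small, framework-free sketch. It is not torchtitan code: `build_block_causal_mask`, `toy_model`, `validate_step`, and `EOS_ID` are hypothetical names invented for this example, and the mask is a plain nested list rather than whatever structure the real varlen attention path expects. The only point is that a model which requires masks must receive them from the validation path too.

```python
from typing import Callable, List, Optional

EOS_ID = 0  # assumed end-of-document token id, for this toy example only

Mask = List[List[bool]]


def build_block_causal_mask(tokens: List[int], eos_id: int = EOS_ID) -> Mask:
    """Token i may attend to token j only if j <= i and both share a document."""
    doc_ids = []
    current = 0
    for tok in tokens:
        doc_ids.append(current)
        if tok == eos_id:
            current += 1  # the EOS token closes its own document
    n = len(tokens)
    return [[j <= i and doc_ids[i] == doc_ids[j] for j in range(n)] for i in range(n)]


def toy_model(tokens: List[int], attention_masks: Optional[Mask] = None) -> List[int]:
    """Stand-in for a mask-dependent model: it refuses to run without masks."""
    assert attention_masks is not None, "attention_masks is required"
    # Return how many positions each token is allowed to attend to.
    return [sum(row) for row in attention_masks]


def validate_step(model: Callable, tokens: List[int], needs_masks: bool) -> List[int]:
    """Build the masks as the training step would, then forward them to the model."""
    masks = build_block_causal_mask(tokens) if needs_masks else None
    return model(tokens, attention_masks=masks)


if __name__ == "__main__":
    batch = [5, 6, EOS_ID, 7, 8]  # two documents packed into one sequence
    print(validate_step(toy_model, batch, needs_masks=True))
```

Calling `validate_step(toy_model, batch, needs_masks=False)` reproduces the shape of the bug: the model's assertion fires because no masks were supplied.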

Versions

Torch version

2.10.0.dev20251210+cu126

Experiment toml

[job]
dump_folder = "data/experiments/baseline"
description = "Qwen 3 4B training with FlashAttention"

[profiling]
enable_profiling = false
save_traces_folder = "data/experiments/baseline/profile_trace"
profile_freq = 100

[metrics]
log_freq = 1
enable_wandb = true
enable_tensorboard = false
save_tb_folder = "tensorboard"

[model]
name = "qwen3"
flavor = "flash4B" # custom flavor with varlen

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2000  # lr scheduler warm up, 20% total steps
decay_type = "cosine"
min_lr_factor = 3e-5

[training]
local_batch_size = 2
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 30
dataset = "c4"
dtype = "bfloat16"

[validation]
enable = true
dataset = "synth_validation"
seq_len = 8192
freq = 10
steps = -1

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
last_save_model_only = false
export_dtype = "bfloat16"
async_mode = "async"
keep_latest_k = 10

[activation_checkpoint]
mode = "selective"  # ["none", "selective", "full"]
selective_ac_option = "op"  # "int" = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=true
components = ["model", "loss"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

qwen3 flavor

    "flash4B": Qwen3ModelArgs(
        vocab_size=151936,
        max_seq_len=4096,
        head_dim=128,
        dim=2560,
        n_layers=36,
        n_heads=32,
        n_kv_heads=8,
        qk_norm=True,
        hidden_dim=9728,
        rope_theta=1000000,
        enable_weight_tying=True,
        attn_type="varlen",
        attn_mask_type="block_causal",
    )

Command

PYTHONPATH=./ uv run torchrun --nnodes 1 --nproc-per-node 4 torchtitan/train.py --job.config-file path/to/config.toml --model.tokenizer-path path/to/tokenizer
