Bug description
Enabling validation while flash attention is enabled leads to a dynamo error:
File "torchtitan/components/validate.py", line 158, in validate
predictions = model_parts[0](inputs)
File "torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
File "venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1830, in inner
result = forward_call(*args, **kwargs)
File "torchtitan/models/qwen3/model/model.py", line 536, in forward
h = layer(h, self.rope_cache, attention_masks)
File "venv/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 441, in __call__
return super().__call__(*args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
File ".venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1830, in inner
result = forward_call(*args, **kwargs)
File ".venv/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 939, in compile_wrapper
raise e.with_traceback(None) from e.__cause__ # User compiler error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: raised exception AssertionError([ConstantVariable(str: 'None')])
I traced the bug to the validation step: it isn’t passing attention_masks to the model, whereas the training step does. I tested this by generating the attention_masks the same way the trainer does and providing them to the model’s forward pass. Everything worked as expected. I’ll prepare a PR shortly.
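For reference, this is roughly the change I tested in torchtitan/components/validate.py, next to the failing forward call. The helper name get_attention_masks, the tokenizer/extra_inputs arguments, and the gating on attn_type are my reading of how train.py builds the masks, so treat it as a sketch rather than the exact patch:

# Inside Validator.validate(), replacing the bare call at line 158.
# Assumption: the trainer builds masks via model.get_attention_masks(...)
# gated on the attention type; mirroring that here makes the validation
# forward pass match the training step.
attention_masks = None
if self.model_args.attn_type in ("flex", "varlen"):  # mask-based attention paths
    attention_masks = model_parts[0].get_attention_masks(
        input_batch=inputs,
        tokenizer=self.tokenizer,
        extra_inputs=extra_inputs,
    )

# Pass the masks through exactly as the train step does.
predictions = model_parts[0](inputs, attention_masks=attention_masks)

With the masks supplied, the compiled varlen path no longer hits the assertion and validation runs cleanly.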
Versions
Torch version
2.10.0.dev20251210+cu126
Experiment toml
[job]
dump_folder = "data/experiments/baseline"
description = "Qwen 3 4B training with FlashAttention"
[profiling]
enable_profiling = false
save_traces_folder = "data/experiments/baseline/profile_trace"
profile_freq = 100
[metrics]
log_freq = 1
enable_wandb = true
enable_tensorboard = false
save_tb_folder = "tensorboard"
[model]
name = "qwen3"
flavor = "flash4B" # custom flavor with varlen
[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8
[lr_scheduler]
warmup_steps = 2000 # lr scheduler warm up, 20% total steps
decay_type = "cosine"
min_lr_factor = 3e-5
[training]
local_batch_size = 2
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 30
dataset = "c4"
dtype = "bfloat16"
[validation]
enable = true
dataset = "synth_validation"
seq_len = 8192
freq = 10
steps = -1
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
context_parallel_degree = 1
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
last_save_model_only = false
export_dtype = "bfloat16"
async_mode = "async"
keep_latest_k = 10
[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
[compile]
enable = true
components = ["model", "loss"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
qwen3 flavor
"flash4B": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=4096,
head_dim=128,
dim=2560,
n_layers=36,
n_heads=32,
n_kv_heads=8,
qk_norm=True,
hidden_dim=9728,
rope_theta=1000000,
enable_weight_tying=True,
attn_type="varlen",
attn_mask_type="block_causal",
)
Command
PYTHONPATH=./ uv run torchrun --nnodes 1 --nproc-per-node 4 torchtitan/train.py --job.config-file path/to/config.toml --model.tokenizer-path path/to/tokenizer