Skip to content

Commit d1cf66f

Browse files
committed
fix config for qwen
1 parent a52db71 commit d1cf66f

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

onnx_diagnostic/ci_models/export_qwen25_vl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ def main(
280280
).eval()
281281
data = dict(model=model)
282282
config = model.config
283+
if not hasattr(config, "bos_token_id") or not config.bos_token_id:
284+
config.bos_token_id = 151643
285+
if not hasattr(config, "eos_token_id") or not config.eos_token_id:
286+
config.eos_token_id = 151645
283287
else:
284288
print("-- random model")
285289
data = get_untrained_model(model_id, second_input=second_input, verbose=1)

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,21 @@ def qwen_sdpa_attention(
256256
return attn_output
257257

258258
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
259+
import onnx_ir
260+
259261
first_float_tensor = next(
260262
a
261263
for a in args
262-
if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
264+
if a is not None
265+
and a.dtype
266+
in {
267+
torch.float16,
268+
torch.float32,
269+
torch.bfloat16,
270+
onnx_ir.DataType.BFLOAT16,
271+
onnx_ir.DataType.FLOAT16,
272+
onnx_ir.DataType.FLOAT,
273+
}
263274
)
264275
dtype = first_float_tensor.dtype
265276
strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()

0 commit comments

Comments (0)