diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py
index 87759d59..f840f7c9 100644
--- a/_unittests/ut_torch_export_patches/test_patch_transformers.py
+++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py
@@ -18,7 +18,10 @@
 from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
-from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    patch_qwen2_5,
+    patch_funnel,
+)
 from onnx_diagnostic.export.api import to_onnx


@@ -787,6 +790,42 @@ def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
         self.assertLess(results.diffs[0]["abs"], 1e-5)

+    @unittest.skipIf(not patch_funnel, "Funnel not part of this transformers")
+    def test_model_funnel(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+            patched_FunnelAttentionStructure,
+            patched_FunnelRelMultiheadAttention,
+        )
+
+        pos = torch.tensor([0, 4, 5, 8], dtype=torch.long)
+        stride = 2
+        config = transformers.models.funnel.modeling_funnel.FunnelConfig()
+        original = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure(config)
+        patched = patched_FunnelAttentionStructure()
+        self.assertEqualArray(
+            original.relative_pos(pos, stride=stride), patched.relative_pos(pos, stride=stride)
+        )
+
+        rmha = transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention(
+            config, 2
+        )
+        patched = patched_FunnelRelMultiheadAttention()
+        patched.config = config
+        for att in ["block_index", "r_r_bias", "scale", "r_kernel"]:
+            setattr(patched, att, getattr(rmha, att))
+        inputs = dict(
+            position_embeds=[
+                [torch.rand((24, 768)), None],
+                [torch.rand((12, 768)), torch.rand((24, 768))],
+                [torch.rand((6, 768)), torch.rand((12, 768))],
+            ],
+            q_head=torch.rand((2, 12, 12, 64)),
+            context_len=12,
+        )
+        expected = rmha.relative_positional_attention(**inputs)
+        got = patched.relative_positional_attention(**inputs)
+        self.assertEqualArray(expected, got)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/ci_models/export_qwen25_vl.py b/onnx_diagnostic/ci_models/export_qwen25_vl.py
index 028176f7..42a022bf 100644
--- a/onnx_diagnostic/ci_models/export_qwen25_vl.py
+++ b/onnx_diagnostic/ci_models/export_qwen25_vl.py
@@ -59,6 +59,7 @@
 import os
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Tuple
 from .ci_helpers import (
     check_for_discrepancies_and_log_everything_into_a_json_file,
@@ -301,7 +302,11 @@ def main(
     print(f"-- config._attn_implementation={model.config._attn_implementation}")
     print(f"-- model.dtype={model.dtype}")
     print(f"-- model.device={model.device}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+    try:
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+    except OSError as e:
+        warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
+        return
     print(f"-- processor={type(processor)}")

     export_inputs, other_inputs = None, None
diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py
index cfd142dc..78678cae 100644
--- a/onnx_diagnostic/tasks/image_text_to_text.py
+++ b/onnx_diagnostic/tasks/image_text_to_text.py
@@ -13,6 +13,10 @@
 __TASK__ = "image-text-to-text"


+def should_have_vision_config(config):
+    return config.architectures != ["FuyuForCausalLM"]
+
+
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+        if should_have_vision_config(config):
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -491,7 +496,8 @@
             "vision_config",
         )
         text_config = False
-    check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+    if should_have_vision_config(config):
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         head_dim=(
             16
@@ -552,17 +558,21 @@
         ),
         width=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         height=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         num_channels=(
             3
-            if config is None
+            if config is None or not should_have_vision_config(config)
             else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
         ),
         pad_token_id=(
diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py
index 365695b0..4f829bb6 100644
--- a/onnx_diagnostic/tasks/text2text_generation.py
+++ b/onnx_diagnostic/tasks/text2text_generation.py
@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
+        default_layer_types = [
+            "sliding_attention",
+            "full_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        config.encoder.num_hidden_layers = 4
+        config.encoder.layer_types = (
+            default_layer_types if config is None else config.encoder.layer_types[:4]
+        )
+        config.decoder.num_hidden_layers = 4
+        config.decoder.layer_types = (
+            default_layer_types if config is None else config.decoder.layer_types[:4]
+        )
+
     update_config(config, kwargs)
     return kwargs

@@ -177,55 +193,75 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

     If the configuration is None, the function selects typical dimensions.
""" + path = 1 if config is not None: - check_hasattr( - config, - "vocab_size", - "hidden_size", - "num_attention_heads", - ("num_hidden_layers", "num_layers"), - ("n_positions", "d_model"), - ( - "num_key_value_heads", - "num_heads", - ("decoder_attention_heads", "encoder_attention_heads"), - ), - ) - # exceptions = { - # "PLBartForConditionalGeneration": ( - # lambda c: c.encoder_attention_heads + c.decoder_attention_heads - # ) - # } - kwargs = dict( - batch_size=2, - sequence_length=30, - sequence_length2=3, - head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"), - head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"), - dummy_max_token_id=31999 if config is None else config.vocab_size - 1, - num_hidden_layers=( - 8 if config is None else _pick(config, "num_hidden_layers", "num_layers") - ), - num_key_value_heads_encoder=( - 16 - if config is None - else _pick( + if hasattr(config, "num_attention_heads"): + check_hasattr( config, - "encoder_attention_heads", - "num_key_value_heads", - "num_heads", + "vocab_size", + "hidden_size", + "num_attention_heads", + ("num_hidden_layers", "num_layers"), + ("n_positions", "d_model"), + ( + "num_key_value_heads", + "num_heads", + ("decoder_attention_heads", "encoder_attention_heads"), + ), ) - ), - num_key_value_heads_decoder=( - 16 - if config is None - else _pick( - config, - "decoder_attention_heads", - "num_key_value_heads", - "num_heads", - ) - ), - encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"), - ) + else: + check_hasattr(config, "encoder", "decoder") + path = 2 + + if path == 1: + kwargs = dict( + batch_size=2, + sequence_length=30, + sequence_length2=3, + head_dim_encoder=( + 16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim") + ), + head_dim_decoder=( + 16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim") + ), + dummy_max_token_id=31999 if config is None else config.vocab_size - 1, + num_hidden_layers=( + 8 if config is None else _pick(config, "num_hidden_layers", "num_layers") + ), + num_key_value_heads_encoder=( + 16 + if config is None + else _pick( + config, + "encoder_attention_heads", + "num_key_value_heads", + "num_heads", + ) + ), + num_key_value_heads_decoder=( + 16 + if config is None + else _pick( + config, + "decoder_attention_heads", + "num_key_value_heads", + "num_heads", + ) + ), + encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"), + ) + else: + kwargs = dict( + batch_size=2, + sequence_length=30, + sequence_length2=3, + dummy_max_token_id=config.encoder.vocab_size - 1, + num_key_value_heads_encoder=config.encoder.num_key_value_heads, + num_key_value_heads_decoder=config.decoder.num_key_value_heads, + num_hidden_layers=len(config.encoder.layer_types), + head_dim_encoder=config.encoder.head_dim, + head_dim_decoder=config.decoder.head_dim, + encoder_dim=256, + ) + return kwargs, get_inputs diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 1ab8deba..b9336879 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]: state_size=8 if config is None else getattr(config, "state_size", None), conv_kernel=4 if config is None else getattr(config, "conv_kernel", None), ) + elif config.__class__.__name__ == "FunnelConfig": + # does not support num_hidden_layers + kwargs = dict() else: kwargs = dict( head_dim=getattr( 
diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py
new file mode 100644
index 00000000..e1043b96
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py
@@ -0,0 +1,80 @@
+import torch
+
+try:
+    import transformers.models.funnel.modeling_funnel
+
+    patch_funnel = True
+except ImportError:
+    patch_funnel = False
+
+if patch_funnel:
+    from transformers.models.funnel.modeling_funnel import _relative_shift_gather
+
+    class patched_FunnelAttentionStructure(torch.nn.Module):
+        _PATCHES_ = ["relative_pos"]
+        _PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
+
+        def relative_pos(
+            self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
+        ) -> torch.Tensor:
+            if pooled_pos is None:
+                pooled_pos = pos
+            ref_point = pooled_pos[0] - pos[0]
+            # PATCHED
+            num_remove = shift * pooled_pos.shape[0]
+            max_dist = ref_point + num_remove * stride
+            min_dist = pooled_pos[0] - pos[-1]
+            return torch.arange(
+                max_dist.to(torch.long),
+                (min_dist - 1).to(torch.long),
+                torch.tensor(-stride, dtype=torch.long),
+                dtype=torch.long,
+                device=pos.device,
+            )
+
+    class patched_FunnelRelMultiheadAttention(torch.nn.Module):
+        _PATCHES_ = ["relative_positional_attention"]
+        _PATCHED_CLASS_ = (
+            transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
+        )
+
+        def relative_positional_attention(
+            self, position_embeds, q_head, context_len, cls_mask=None
+        ):
+            """Relative attention score for the positional encodings"""
+            # q_head has shape batch_size x sea_len x n_head x d_head
+            if self.config.attention_type == "factorized":
+                phi, pi, psi, omega = position_embeds
+                # Shape n_head x d_head
+                u = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape batch_size x sea_len x n_head x d_model
+                q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+                q_r_attention_1 = q_r_attention * phi[:, None]
+                q_r_attention_2 = q_r_attention * pi[:, None]
+
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = torch.einsum(
+                    "bind,jd->bnij", q_r_attention_1, psi
+                ) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
+            else:
+                shift = 2 if q_head.shape[1] != context_len else 1
+                r = position_embeds[self.block_index][shift - 1]
+                # Shape n_head x d_head
+                v = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape max_rel_len x n_head x d_model
+                r_head = torch.einsum("td,dnh->tnh", r, w_r)
+                # Shape batch_size x n_head x seq_len x max_rel_len
+                positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+            if cls_mask is not None:
+                # PATCHED
+                positional_attn = positional_attn * cls_mask
+            return positional_attn
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index faa4bf9f..07417855 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,29 +1,37 @@
 # transformers
 from typing import List
 from .patch_helper import _has_transformers
-
 from ._patch_transformers_attention import (
     patched_sdpa_attention_forward,
     patched_model_bart_eager_attention_forward,
     patched_modeling_marian_eager_attention_forward,
 )
+from ._patch_transformers_generation_mixin import patched_GenerationMixin
+from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
+from ._patch_transformers_rotary_embedding import (
+    patched__compute_dynamic_ntk_parameters,
+    patched_dynamic_rope_update,
+    patched_GemmaRotaryEmbedding,
+    patched_LlamaRotaryEmbedding,
+    patched_MistralRotaryEmbedding,
+    patched_MixtralRotaryEmbedding,
+    patched_PhiRotaryEmbedding,
+)
+from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
+from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+
+# transformers dependent patches
 from ._patch_transformers_cache_utils import patch_parse_processor_args

 if patch_parse_processor_args:
     from ._patch_transformers_cache_utils import patched_parse_processor_args
-
-from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
-
 from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache

 if patch_DynamicLayer:
     from ._patch_transformers_dynamic_cache import patched_DynamicLayer

 if patch_DynamicCache:
     from ._patch_transformers_dynamic_cache import patched_DynamicCache
-
-from ._patch_transformers_generation_mixin import patched_GenerationMixin
-
 from ._patch_transformers_masking_utils import patch_masking_utils

 if patch_masking_utils:
@@ -33,15 +41,7 @@
         patched_sdpa_mask_recent_torch,
     )

-from ._patch_transformers_rotary_embedding import (
-    patched__compute_dynamic_ntk_parameters,
-    patched_dynamic_rope_update,
-    patched_GemmaRotaryEmbedding,
-    patched_LlamaRotaryEmbedding,
-    patched_MistralRotaryEmbedding,
-    patched_MixtralRotaryEmbedding,
-    patched_PhiRotaryEmbedding,
-)
+# transformers models dependent patches

 if _has_transformers("4.51"):
     from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
@@ -54,16 +54,11 @@
 if _has_transformers("4.53"):
     from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding

-# Models
-
 from ._patch_transformers_gemma3 import patch_gemma3

 if patch_gemma3:
     from ._patch_transformers_gemma3 import patched_Gemma3Model

-from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
-
-
 from ._patch_transformers_qwen2 import patch_qwen2

 if patch_qwen2:
@@ -80,14 +75,17 @@
     patched_Qwen2_5_VLModel,
     PLUGS as PLUGS_Qwen25,
 )
-
 from ._patch_transformers_qwen3 import patch_qwen3

 if patch_qwen3:
     from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock

+from ._patch_transformers_funnel import patch_funnel
-
-
-from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+if patch_funnel:
+    from ._patch_transformers_funnel import (
+        patched_FunnelAttentionStructure,
+        patched_FunnelRelMultiheadAttention,
+    )
+


 def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]:  # noqa: F821
diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py
index 445f6c02..71f7a376 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_api.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -184,7 +184,18 @@ def _trygetattr(config, attname):
     return None


+def rewrite_architecture_name(name: Optional[str]) -> Optional[str]:
+    if name == "ConditionalDETRForObjectDetection":
+        return "ConditionalDetrForObjectDetection"
+    return name
+
+
 def architecture_from_config(config) -> Optional[str]:
+    """Guesses the architecture (class) of the model described by this config."""
+    return rewrite_architecture_name(_architecture_from_config(config))
+
+
+def _architecture_from_config(config) -> Optional[str]:
     """Guesses the architecture (class) of the model described by this config."""
     if isinstance(config, dict):
         if "_class_name" in config:
diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py
index 299c37eb..054f8ede 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_data.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -5,7 +5,10 @@

 __date__ = "2025-06-21"

-__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
+__data_arch_values__ = {
+    "ConditionalDETRForObjectDetection": dict(image_size=224),
+    "ResNetForImageClassification": dict(image_size=224),
+}

 __data_arch__ = textwrap.dedent(
     """
@@ -32,6 +35,7 @@
     ConvNextV2Model,image-feature-extraction
     CosmosTransformer3DModel,image-to-video
     CvtModel,feature-extraction
+    ClvpModelForConditionalGeneration,audio-feature-extraction
     DPTModel,image-feature-extraction
     Data2VecAudioModel,feature-extraction
     Data2VecTextModel,feature-extraction
@@ -49,6 +53,8 @@
     ElectraModel,feature-extraction
     EsmModel,feature-extraction
     FalconMambaForCausalLM,text-generation
+    FunnelBaseModel,feature-extraction
+    FuyuForCausalLM,image-text-to-text
     GLPNModel,image-feature-extraction
     GPT2LMHeadModel,text-generation
     GPTBigCodeModel,feature-extraction
@@ -63,6 +69,7 @@
     Glm4vMoeForConditionalGeneration,image-text-to-text
     GraniteForCausalLM,text-generation
     GroupViTModel,feature-extraction
+    HeliumForCausalLM,text-generation
     HieraForImageClassification,image-classification
     HubertModel,feature-extraction
     IBertModel,feature-extraction
@@ -136,6 +143,7 @@
     SwinModel,image-feature-extraction
     Swinv2Model,image-feature-extraction
     T5ForConditionalGeneration,text2text-generation
+    T5GemmaForConditionalGeneration,text2text-generation
     TableTransformerModel,image-feature-extraction
     TableTransformerForObjectDetection,object-detection
     UNet2DConditionModel,text-to-image