Skip to content

Commit b37d6d1

Browse files
committed
add patches for Funnel
1 parent b6ac5b3 commit b37d6d1

File tree

5 files changed

+148
-25
lines changed

5 files changed

+148
-25
lines changed

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
1919
from onnx_diagnostic.torch_export_patches import torch_export_patches
2020
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
21-
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
21+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
22+
patch_qwen2_5,
23+
patch_funnel,
24+
)
2225
from onnx_diagnostic.export.api import to_onnx
2326

2427

@@ -787,6 +790,42 @@ def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
787790
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
788791
self.assertLess(results.diffs[0]["abs"], 1e-5)
789792

793+
@unittest.skipIf(not patch_funnel, "Funnel not part of this transformers")
794+
def test_model_funnel(self):
795+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
796+
patched_FunnelAttentionStructure,
797+
patched_FunnelRelMultiheadAttention,
798+
)
799+
800+
pos = torch.tensor([0, 4, 5, 8], dtype=torch.long)
801+
stride = 2
802+
config = transformers.models.funnel.modeling_funnel.FunnelConfig()
803+
original = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure(config)
804+
patched = patched_FunnelAttentionStructure()
805+
self.assertEqualArray(
806+
original.relative_pos(pos, stride=stride), patched.relative_pos(pos, stride=stride)
807+
)
808+
809+
rmha = transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention(
810+
config, 2
811+
)
812+
patched = patched_FunnelRelMultiheadAttention()
813+
patched.config = config
814+
for att in ["block_index", "r_r_bias", "scale", "r_kernel"]:
815+
setattr(patched, att, getattr(rmha, att))
816+
inputs = dict(
817+
position_embeds=[
818+
[torch.rand((24, 768)), None],
819+
[torch.rand((12, 768)), torch.rand((24, 768))],
820+
[torch.rand((6, 768)), torch.rand((12, 768))],
821+
],
822+
q_head=torch.rand((2, 12, 12, 64)),
823+
context_len=12,
824+
)
825+
expected = rmha.relative_positional_attention(**inputs)
826+
got = patched.relative_positional_attention(**inputs)
827+
self.assertEqualArray(expected, got)
828+
790829

791830
if __name__ == "__main__":
792831
unittest.main(verbosity=2)

onnx_diagnostic/tasks/text_generation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
4040
state_size=8 if config is None else getattr(config, "state_size", None),
4141
conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
4242
)
43+
elif config.__class__.__name__ == "FunnelConfig":
44+
# does not support num_hidden_layers
45+
kwargs = dict()
4346
else:
4447
kwargs = dict(
4548
head_dim=getattr(
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import torch
2+
3+
try:
4+
import transformers.models.funnel.modeling_funnel
5+
6+
patch_funnel = True
7+
except ImportError:
8+
patch_funnel = False
9+
10+
if patch_funnel:
11+
from transformers.models.funnel.modeling_funnel import _relative_shift_gather
12+
13+
class patched_FunnelAttentionStructure(torch.nn.Module):
14+
_PATCHES_ = ["relative_pos"]
15+
_PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
16+
17+
def relative_pos(
18+
self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
19+
) -> torch.Tensor:
20+
if pooled_pos is None:
21+
pooled_pos = pos
22+
ref_point = pooled_pos[0] - pos[0]
23+
# PATCHED
24+
num_remove = shift * pooled_pos.shape[0]
25+
max_dist = ref_point + num_remove * stride
26+
min_dist = pooled_pos[0] - pos[-1]
27+
return torch.arange(
28+
max_dist.to(torch.long),
29+
(min_dist - 1).to(torch.long),
30+
torch.tensor(-stride, dtype=torch.long),
31+
dtype=torch.long,
32+
device=pos.device,
33+
)
34+
35+
class patched_FunnelRelMultiheadAttention(torch.nn.Module):
36+
_PATCHES_ = ["relative_positional_attention"]
37+
_PATCHED_CLASS_ = (
38+
transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
39+
)
40+
41+
def relative_positional_attention(
42+
self, position_embeds, q_head, context_len, cls_mask=None
43+
):
44+
"""Relative attention score for the positional encodings"""
45+
# q_head has shape batch_size x sea_len x n_head x d_head
46+
if self.config.attention_type == "factorized":
47+
phi, pi, psi, omega = position_embeds
48+
# Shape n_head x d_head
49+
u = self.r_r_bias * self.scale
50+
# Shape d_model x n_head x d_head
51+
w_r = self.r_kernel
52+
53+
# Shape batch_size x sea_len x n_head x d_model
54+
q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
55+
q_r_attention_1 = q_r_attention * phi[:, None]
56+
q_r_attention_2 = q_r_attention * pi[:, None]
57+
58+
# Shape batch_size x n_head x seq_len x context_len
59+
positional_attn = torch.einsum(
60+
"bind,jd->bnij", q_r_attention_1, psi
61+
) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
62+
else:
63+
shift = 2 if q_head.shape[1] != context_len else 1
64+
r = position_embeds[self.block_index][shift - 1]
65+
# Shape n_head x d_head
66+
v = self.r_r_bias * self.scale
67+
# Shape d_model x n_head x d_head
68+
w_r = self.r_kernel
69+
70+
# Shape max_rel_len x n_head x d_model
71+
r_head = torch.einsum("td,dnh->tnh", r, w_r)
72+
# Shape batch_size x n_head x seq_len x max_rel_len
73+
positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
74+
# Shape batch_size x n_head x seq_len x context_len
75+
positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
76+
77+
if cls_mask is not None:
78+
# PATCHED
79+
positional_attn = positional_attn * cls_mask
80+
return positional_attn

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
11
# transformers
22
from typing import List
33
from .patch_helper import _has_transformers
4-
54
from ._patch_transformers_attention import (
65
patched_sdpa_attention_forward,
76
patched_model_bart_eager_attention_forward,
87
patched_modeling_marian_eager_attention_forward,
98
)
9+
from ._patch_transformers_generation_mixin import patched_GenerationMixin
10+
from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
11+
from ._patch_transformers_rotary_embedding import (
12+
patched__compute_dynamic_ntk_parameters,
13+
patched_dynamic_rope_update,
14+
patched_GemmaRotaryEmbedding,
15+
patched_LlamaRotaryEmbedding,
16+
patched_MistralRotaryEmbedding,
17+
patched_MixtralRotaryEmbedding,
18+
patched_PhiRotaryEmbedding,
19+
)
20+
from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
21+
from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
22+
23+
# transformers dependant patches
1024

1125
from ._patch_transformers_cache_utils import patch_parse_processor_args
1226

1327
if patch_parse_processor_args:
1428
from ._patch_transformers_cache_utils import patched_parse_processor_args
15-
16-
from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
17-
1829
from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache
1930

2031
if patch_DynamicLayer:
2132
from ._patch_transformers_dynamic_cache import patched_DynamicLayer
2233
if patch_DynamicCache:
2334
from ._patch_transformers_dynamic_cache import patched_DynamicCache
24-
25-
from ._patch_transformers_generation_mixin import patched_GenerationMixin
26-
2735
from ._patch_transformers_masking_utils import patch_masking_utils
2836

2937
if patch_masking_utils:
@@ -33,15 +41,7 @@
3341
patched_sdpa_mask_recent_torch,
3442
)
3543

36-
from ._patch_transformers_rotary_embedding import (
37-
patched__compute_dynamic_ntk_parameters,
38-
patched_dynamic_rope_update,
39-
patched_GemmaRotaryEmbedding,
40-
patched_LlamaRotaryEmbedding,
41-
patched_MistralRotaryEmbedding,
42-
patched_MixtralRotaryEmbedding,
43-
patched_PhiRotaryEmbedding,
44-
)
44+
# transformers models dependant patches
4545

4646
if _has_transformers("4.51"):
4747
from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
@@ -54,16 +54,11 @@
5454
if _has_transformers("4.53"):
5555
from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding
5656

57-
# Models
58-
5957
from ._patch_transformers_gemma3 import patch_gemma3
6058

6159
if patch_gemma3:
6260
from ._patch_transformers_gemma3 import patched_Gemma3Model
6361

64-
from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
65-
66-
6762
from ._patch_transformers_qwen2 import patch_qwen2
6863

6964
if patch_qwen2:
@@ -80,14 +75,17 @@
8075
patched_Qwen2_5_VLModel,
8176
PLUGS as PLUGS_Qwen25,
8277
)
83-
8478
from ._patch_transformers_qwen3 import patch_qwen3
8579

8680
if patch_qwen3:
8781
from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock
82+
from ._patch_transformers_funnel import patch_funnel
8883

89-
90-
from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
84+
if patch_funnel:
85+
from ._patch_transformers_funnel import (
86+
patched_FunnelAttentionStructure,
87+
patched_FunnelRelMultiheadAttention,
88+
)
9189

9290

9391
def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]: # noqa: F821

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ConvNextV2Model,image-feature-extraction
3333
CosmosTransformer3DModel,image-to-video
3434
CvtModel,feature-extraction
35+
ClvpModelForConditionalGeneration,audio-feature-extraction
3536
DPTModel,image-feature-extraction
3637
Data2VecAudioModel,feature-extraction
3738
Data2VecTextModel,feature-extraction
@@ -49,6 +50,7 @@
4950
ElectraModel,feature-extraction
5051
EsmModel,feature-extraction
5152
FalconMambaForCausalLM,text-generation
53+
FunnelBaseModel,feature-extraction
5254
GLPNModel,image-feature-extraction
5355
GPT2LMHeadModel,text-generation
5456
GPTBigCodeModel,feature-extraction
@@ -63,6 +65,7 @@
6365
Glm4vMoeForConditionalGeneration,image-text-to-text
6466
GraniteForCausalLM,text-generation
6567
GroupViTModel,feature-extraction
68+
HeliumForCausalLM,text-generation
6669
HieraForImageClassification,image-classification
6770
HubertModel,feature-extraction
6871
IBertModel,feature-extraction

0 commit comments

Comments (0)