Skip to content

Commit f4fbaa8

Browse files
authored
export a method to onnx in order to export using method generate (#375)
* export a method to onnx * mypy * mypy * fix missing args * add one example * fix * fix * fix * fix * doc * disable
1 parent 7cc8de9 commit f4fbaa8

9 files changed

Lines changed: 562 additions & 8 deletions

File tree

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.8
55
+++++
66

7+
* :pr:`375`: export a method to onnx in order to export using method generate
78
* :pr:`376`: fix patched lazy_initialization for transformers>=5
89
* :pr:`372`: fix patch on rotary embedding
910
* :pr:`371`: fix make_fake_with_dynamic_dimensions

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ Enlightening Examples
7373

7474
* `Export microsoft/phi-2
7575
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
76+
* `Export a model through method generate (with Tiny-LLM)
77+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_method_generate.html>`_
7678

7779
**Torch Export**
7880

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
.. _l-plot-tiny-llm-export-method-generate:
3+
4+
Export a model through method generate (with Tiny-LLM)
5+
======================================================
6+
7+
The main issue when exporting a LLM is the example on HuggingFace is
8+
based on method generate but we only need to export the forward method.
9+
Example :ref:`l-plot-tiny-llm-export` gives details on how to guess
10+
dummy inputs and dynamic shapes to do so.
11+
Let's see how to simplify that.
12+
13+
Dummy Example
14+
+++++++++++++
15+
16+
Let's use the example provided on
17+
`arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
18+
"""
19+
20+
from transformers import AutoModelForCausalLM, AutoTokenizer
21+
from onnx_diagnostic import doc
22+
from onnx_diagnostic.export.api import method_to_onnx
23+
24+
25+
MODEL_NAME = "arnir0/Tiny-LLM"
26+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
27+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
28+
29+
30+
def generate_text(
31+
prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
32+
):
33+
inputs = tokenizer.encode(prompt, return_tensors="pt")
34+
35+
outputs = model.generate(
36+
inputs,
37+
max_length=max_length,
38+
temperature=temperature,
39+
top_k=top_k,
40+
top_p=top_p,
41+
do_sample=True,
42+
)
43+
44+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
45+
return generated_text
46+
47+
# Define your prompt
48+
49+
50+
prompt = "Continue: it rains..."
51+
generated_text = generate_text(prompt, model, tokenizer)
52+
print("-----------------")
53+
print(generated_text)
54+
print("-----------------")
55+
56+
# %%
57+
# Replace forward method
58+
# ++++++++++++++++++++++
59+
#
60+
# We now modify the model to export the model by replacing the forward method.
61+
filename = "plot_export_tiny_llm_method_generate.onnx"
62+
forward_replacement = method_to_onnx(
63+
model,
64+
method_name="forward",
65+
exporter="custom",
66+
filename=filename,
67+
patch_kwargs=dict(patch_transformers=True),
68+
verbose=1,
69+
convert_after_n_calls=3,
70+
skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
71+
dynamic_shapes={
72+
"cache_position": {0: "total_sequence_length"},
73+
"past_key_values": [
74+
{0: "batch_size", 2: "past_sequence_length"},
75+
{0: "batch_size", 2: "past_sequence_length"},
76+
],
77+
"input_ids": {0: "batch_size", 1: "sequence_length"},
78+
},
79+
)
80+
81+
# %%
82+
# The lambda function cannot be skipped as
83+
# forward_replacement is a module.
84+
85+
print(f"type(forward_replacement)={type(forward_replacement)}")
86+
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
87+
88+
89+
# %%
90+
# Let's call generate again.
91+
generated_text = generate_text(prompt, model, tokenizer)
92+
print(generated_text)
93+
94+
95+
# %%
96+
97+
doc.plot_legend("Tiny-LLM\nforward inputs\through generate", "torch.export.export", "tomato")

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Enlightening Examples
8585
**Where to start to export a model**
8686

8787
* :ref:`l-plot-export_tiny_phi2`
88+
* :ref:`l-plot-tiny-llm-export-method-generate`
8889

8990
**Exporter Recipes**
9091

_unittests/ut_export/test_api.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
has_transformers,
88
ignore_warnings,
99
requires_transformers,
10+
requires_experimental_experiment,
1011
)
1112
from onnx_diagnostic.helpers import max_diff
1213
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
1314
from onnx_diagnostic.helpers.rt_helper import make_feeds
1415
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
1516
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1617
from onnx_diagnostic.torch_export_patches import torch_export_patches
17-
from onnx_diagnostic.export.api import to_onnx
18+
from onnx_diagnostic.export.api import to_onnx, method_to_onnx
1819

1920

2021
class TestValidate(ExtTestCase):
@@ -114,6 +115,136 @@ def test_tiny_llm_to_onnx(self):
114115

115116
self.clean_dump()
116117

118+
@requires_experimental_experiment("0.1")
119+
def test_method_to_onnx_args(self):
120+
class Model(torch.nn.Module):
121+
def forward(self, x, y):
122+
return x + y
123+
124+
filename = self.get_dump_file("test_method_to_onnx_args.onnx")
125+
inputs = [
126+
(torch.randn((5, 6)), torch.randn((1, 6))),
127+
(torch.randn((7, 7)), torch.randn((1, 7))),
128+
]
129+
model = Model()
130+
method_to_call = method_to_onnx(model, exporter="custom", filename=filename)
131+
expecteds = []
132+
for args in inputs:
133+
expecteds.append(method_to_call(*args))
134+
self.assertExists(filename)
135+
src = method_to_call._method_src
136+
self.assertIn("f(self, x, y):", src)
137+
self.assertIn("return self._method_call(x=x, y=y)", src)
138+
self.assertEqual(len(list(method_to_call.named_modules())), 2)
139+
sess = self.check_ort(filename)
140+
input_names = [i.name for i in sess.get_inputs()]
141+
for expected, args in zip(expecteds, inputs):
142+
feeds = make_feeds(input_names, args, use_numpy=True)
143+
got = sess.run(None, feeds)
144+
self.assertEqualArray(expected, got[0])
145+
self.clean_dump()
146+
147+
@requires_experimental_experiment("0.1")
148+
def test_method_to_onnx_kwargs(self):
149+
class Model(torch.nn.Module):
150+
def forward(self, x=None, y=None):
151+
return x + y
152+
153+
filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx")
154+
inputs = [
155+
dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))),
156+
dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))),
157+
]
158+
model = Model()
159+
method_to_call = method_to_onnx(model, exporter="custom", filename=filename)
160+
expecteds = []
161+
for kwargs in inputs:
162+
expecteds.append(method_to_call(**kwargs))
163+
self.assertExists(filename)
164+
src = method_to_call._method_src
165+
self.assertIn("f(self, x=None, y=None):", src)
166+
self.assertIn("return self._method_call(x=x, y=y)", src)
167+
self.assertEqual(len(list(method_to_call.named_modules())), 2)
168+
sess = self.check_ort(filename)
169+
input_names = [i.name for i in sess.get_inputs()]
170+
for expected, kwargs in zip(expecteds, inputs):
171+
feeds = make_feeds(input_names, kwargs, use_numpy=True)
172+
got = sess.run(None, feeds)
173+
self.assertEqualArray(expected, got[0])
174+
self.clean_dump()
175+
176+
@requires_experimental_experiment("0.1")
177+
def test_method_to_onnx_kwargs_patch(self):
178+
class Model(torch.nn.Module):
179+
def forward(self, x=None, y=None):
180+
return x + y
181+
182+
filename = self.get_dump_file("test_method_to_onnx_kwargs_patch.onnx")
183+
inputs = [
184+
dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))),
185+
dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))),
186+
]
187+
model = Model()
188+
method_to_call = method_to_onnx(
189+
model,
190+
exporter="custom",
191+
filename=filename,
192+
patch_kwargs=dict(patch_transformers=True),
193+
)
194+
expecteds = []
195+
for kwargs in inputs:
196+
expecteds.append(method_to_call(**kwargs))
197+
self.assertExists(filename)
198+
src = method_to_call._method_src
199+
self.assertIn("f(self, x=None, y=None):", src)
200+
self.assertIn("return self._method_call(x=x, y=y)", src)
201+
self.assertEqual(len(list(method_to_call.named_modules())), 2)
202+
sess = self.check_ort(filename)
203+
input_names = [i.name for i in sess.get_inputs()]
204+
for expected, kwargs in zip(expecteds, inputs):
205+
feeds = make_feeds(input_names, kwargs, use_numpy=True)
206+
got = sess.run(None, feeds)
207+
self.assertEqualArray(expected, got[0])
208+
self.clean_dump()
209+
210+
@requires_experimental_experiment("0.1")
211+
@hide_stdout()
212+
def test_method_to_onnx_mixed(self):
213+
from experimental_experiment.torch_interpreter import ExportOptions
214+
215+
class Model(torch.nn.Module):
216+
def forward(self, x, y=None):
217+
return x + y
218+
219+
filename = self.get_dump_file("test_method_to_onnx_mixed.onnx")
220+
inputs = [
221+
((torch.randn((5, 6)),), dict(y=torch.randn((1, 6)))),
222+
((torch.randn((7, 7)),), dict(y=torch.randn((1, 7)))),
223+
]
224+
model = Model()
225+
method_to_call = method_to_onnx(
226+
model,
227+
exporter="custom",
228+
filename=filename,
229+
verbose=10,
230+
exporter_kwargs=dict(export_options=ExportOptions(backed_size_oblivious=False)),
231+
)
232+
expecteds = []
233+
for args, kwargs in inputs:
234+
expecteds.append(method_to_call(*args, **kwargs))
235+
self.assertExists(filename)
236+
src = method_to_call._method_src
237+
self.assertIn("f(self, x, y=None):", src)
238+
self.assertIn("return self._method_call(x=x, y=y)", src)
239+
self.assertEqual(len(list(method_to_call.named_modules())), 2)
240+
sess = self.check_ort(filename)
241+
input_names = [i.name for i in sess.get_inputs()]
242+
for expected, (args, kwargs) in zip(expecteds, inputs):
243+
feeds = make_feeds(input_names, (args, kwargs), use_numpy=True)
244+
got = sess.run(None, feeds)
245+
self.assertEqualArray(expected, got[0])
246+
self.clean_dump()
247+
117248

118249
if __name__ == "__main__":
119250
unittest.main(verbosity=2)

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def forward(self, **kwargs):
188188
expected = "#2[((),dict(x:T1s5x6)),((),dict(x:T1s6x6))]"
189189
self.assertEqual(expected, string_type(mi.inputs, with_shape=True))
190190
ds = mi.guess_dynamic_shapes()
191-
self.assertEqual(ds, (tuple(), {"x": {0: torch.export.Dim.DYNAMIC}}))
191+
self.assertEqual((tuple(), {"x": {0: torch.export.Dim.DYNAMIC}}), ds)
192192
_a, _kw, ds = mi.move_to_kwargs(*mi.inputs[0], ds)
193193
self.assertEqual(ds, (tuple(), {"kwargs": {"x": {0: torch.export.Dim.DYNAMIC}}}))
194194
self.assertEqual(
@@ -937,6 +937,31 @@ def test_invalid_dimensions_for_export(self):
937937
backed_size_oblivious = cpl.invalid_dimensions_for_export()
938938
self.assertFalse(backed_size_oblivious)
939939

940+
def test_guess_dynamic_shapes_missing(self):
    """Guesses dynamic shapes when an optional input is missing from one call."""

    class Model(torch.nn.Module):
        def forward(self, x, y=None):
            if y is None:
                return x.abs()
            return x.abs() + y

    model = Model()
    x = torch.randn((5, 6))
    # Sanity check: the model runs without the optional input.
    self.assertNotEmpty(model(x=x))

    # First call omits y; the two others provide it with varying batch size.
    recorded_inputs = [
        (tuple(), {"x": x}),
        (tuple(), {"x": torch.randn((6, 6)), "y": torch.randn((6, 6))}),
        (tuple(), {"x": torch.randn((7, 6)), "y": torch.randn((7, 6))}),
    ]

    mi = ModelInputs(model, recorded_inputs)
    DYN = torch.export.Dim.DYNAMIC
    guessed = mi.guess_dynamic_shapes()
    self.assertEqual(guessed, ((), {"x": {0: DYN}, "y": {0: DYN}}))
    _a, _kw, guessed = mi.move_to_kwargs(*mi.inputs[-1], guessed)
    self.assertEqual(guessed, (tuple(), {"x": {0: DYN}, "y": {0: DYN}}))
964+
940965

941966
if __name__ == "__main__":
942967
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def add_test_methods(cls):
102102

103103
if (
104104
not reason
105-
and name in {"plot_export_tiny_phi2.py"}
105+
and name
106+
in {"plot_export_tiny_phi2.py", "plot_export_tiny_llm_method_generate.py"}
106107
and not has_transformers("4.55")
107108
):
108109
reason = "transformers<4.55"
@@ -124,6 +125,7 @@ def add_test_methods(cls):
124125
"plot_export_locate_issue.py",
125126
"plot_export_with_auto.py",
126127
"plot_export_tiny_llm.py",
128+
"plot_export_tiny_llm_method_generate.py",
127129
}
128130
and not has_torch("2.8")
129131
):

0 commit comments

Comments
 (0)