Skip to content

Commit 67281cb

Browse files
authored
add example to export experts part (#410)
* add example to export experts part * add one more example * add example * more improvements * doc * fix * fix patch * fix * fix * fix again
1 parent 3e7d212 commit 67281cb

19 files changed

+887
-139
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
*.data
12
*.pyc
23
*.pyd
34
*.dylib

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.9.1
55
+++++
66

7+
* :pr:`410`: add patch for `_get_range_constraints`
78
* :pr:`409`: improves ModelBuilder wrapper
89
* :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0, 5.2.0 (see https://github.com/huggingface/transformers/pull/43765/)
910

_doc/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def linkcode_resolve(domain, info):
141141
("py:class", "torch.fx.proxy.TracerBase"),
142142
("py:class", "torch.FloatTensor"),
143143
("py:class", "torch.LongTensor"),
144+
("py:class", "torch.export._trace.ExportArtifact"),
144145
("py:class", "torch.utils._pytree.Context"),
145146
("py:class", "torch.utils._pytree.KeyEntry"),
146147
("py:class", "torch.utils._pytree.TreeSpec"),
@@ -211,7 +212,7 @@ def linkcode_resolve(domain, info):
211212

212213
if int(os.environ.get("UNITTEST_GOING", "0")):
213214
sphinx_gallery_conf["ignore_pattern"] = (
214-
".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)).*"
215+
".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*"
215216
)
216217
elif pv.Version(torch.__version__) < pv.Version("2.8"):
217218
sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*"
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
"""
.. _l-plot-optimind-export-input-observer:

Export OptiMind-SFT with InputObserver
======================================

This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer`
for model `microsoft/OptiMind-SFT <https://huggingface.co/microsoft/OptiMind-SFT>`_.
We only export class ``GptOssExperts``.

Let's create a random model
+++++++++++++++++++++++++++
"""

import pandas
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver

device = "cuda"
model_id = "microsoft/OptiMind-SFT"
print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)
# Shrink the model to two hidden layers so the example stays fast;
# weights are random (from_config), only the architecture matters here.
config.num_hidden_layers = 2
config.layer_types = config.layer_types[:2]
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
print(f"the model is created with {len(list(model.named_modules()))} submodules.")
model = model.to(device)

# %%
# We need to only export class GptOssExperts
# ++++++++++++++++++++++++++++++++++++++++++

# Walk the module tree and keep (any) one instance of GptOssExperts:
# all expert blocks share the same class, one is enough to export.
export_module = None
for _name, sub in model.named_modules():
    if sub.__class__.__name__ == "GptOssExperts":
        export_module = sub

assert export_module is not None, (
    f"Unable to find a submodule from class GptOssExperts in "
    f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
)

# %%
# Let's run the model and capture inputs and outputs


def generate_text(
    prompt,
    model,
    tokenizer,
    max_length=50,
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    do_sample=True,
):
    """Runs ``model.generate`` on *prompt* and returns the decoded text.

    :param prompt: input string fed to the tokenizer
    :param model: a causal LM exposing ``generate``
    :param tokenizer: the matching tokenizer
    :param max_length: maximum total length of the generated sequence
    :param temperature: sampling temperature (near 0 is almost greedy)
    :param top_k: top-k sampling parameter
    :param top_p: nucleus sampling parameter
    :param do_sample: sample instead of greedy decoding
    :return: the generated text, special tokens stripped
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # ``device`` is the module-level device defined above.
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text


prompt = "Continue: it rains, what should I do?"
observer = InputObserver()
# The observer hooks ``export_module`` while generation runs so that it
# records every set of inputs/outputs the submodule receives.
with (
    register_additional_serialization_functions(patch_transformers=True),
    observer(export_module),
):
    generate_text(prompt, model, tokenizer)


# %%
# Export
# ++++++
#
# First, what was inferred.

args = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
print(f"args={string_type(args, with_shape=True, with_device=True)}")
print(f"dynamic_shapes={dynamic_shapes}")

# %%
# Next, the export.


filename = "plot_export_optimind_experts_input_observer.onnx"
with torch_export_patches(patch_transformers=True):
    to_onnx(
        export_module,
        args=args,
        filename=filename,
        dynamic_shapes=dynamic_shapes,
        exporter="custom",
        verbose=1,
    )

# %%
# Let's measure the discrepancies.
data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True)
df = pandas.DataFrame(data)
df.to_excel("plot_export_optimind_input_observer.xlsx")
print(df)

# %%
# Let's show the errors.
for row in data:
    if not row["SUCCESS"] and "error" in row:
        print(row["error"])


# %%
doc.save_fig(doc.plot_dot(filename), "plot_export_optimind_experts_input_observer.png", dpi=400)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
"""
.. _l-plot-tiny-llm-attention-export-input-observer:

Export attention from arnir0/Tiny-LLM with InputObserver
========================================================

This shows how to only export attention from model
`arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
It uses what was shown in example
:ref:`l-plot-tiny-llm-export-input-observer`.

Let's create a random model
+++++++++++++++++++++++++++
"""

import pandas
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver

device = "cuda"
model_id = "arnir0/Tiny-LLM"
print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
print(f"the model is created with {len(list(model.named_modules()))} submodules.")
model = model.to(device).to(torch.float16)

# %%
# We need to only export class LlamaAttention
# +++++++++++++++++++++++++++++++++++++++++++

# Walk the module tree and keep (any) one instance of LlamaAttention:
# every layer's attention shares the class, one instance is enough.
export_module = None
for _name, sub in model.named_modules():
    if sub.__class__.__name__ == "LlamaAttention":
        export_module = sub

assert export_module is not None, (
    f"Unable to find a submodule from class LlamaAttention in "
    f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
)

# %%
# Let's run the model and capture the inputs and outputs of the attention part.


def generate_text(
    prompt,
    model,
    tokenizer,
    max_length=50,
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    do_sample=True,
):
    """Runs ``model.generate`` on *prompt* and returns the decoded text.

    :param prompt: input string fed to the tokenizer
    :param model: a causal LM exposing ``generate``
    :param tokenizer: the matching tokenizer
    :param max_length: maximum total length of the generated sequence
    :param temperature: sampling temperature (near 0 is almost greedy)
    :param top_k: top-k sampling parameter
    :param top_p: nucleus sampling parameter
    :param do_sample: sample instead of greedy decoding
    :return: the generated text, special tokens stripped
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # ``device`` is the module-level device defined above.
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text


prompt = "Continue: it rains, what should I do?"
observer = InputObserver()
# The observer hooks ``export_module`` while generation runs so that it
# records every set of inputs/outputs the attention submodule receives.
with (
    register_additional_serialization_functions(patch_transformers=True),
    observer(export_module),
):
    generate_text(prompt, model, tokenizer)


# %%
# Export
# ++++++
#
# First, what was inferred.

kwargs = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
print("attention type:", type(export_module))
print(f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}")
print(f"dynamic_shapes={dynamic_shapes}")

# %%
# Next, the export.


filename = "plot_export_tiny_llm_attention_input_observer.onnx"
with torch_export_patches(patch_torch=True, patch_transformers=True):
    to_onnx(
        export_module,
        args=(),
        kwargs=kwargs,
        filename=filename,
        dynamic_shapes=dynamic_shapes,
        exporter="custom",
        verbose=1,
    )

# %%
# Let's measure the discrepancies.
data = observer.check_discrepancies(
    filename, progress_bar=True, atol=1e-2, include_io=True, skip_none=True
)
df = pandas.DataFrame(data)
df.to_excel("plot_export_tiny_llm_attention_input_observer.xlsx")
print(df)

# %%
# Let's show the errors.
for row in data:
    if not row["SUCCESS"] and "error" in row:
        print(row["error"])


# %%
doc.save_fig(doc.plot_dot(filename), "plot_export_tiny_llm_attention_input_observer.png", dpi=400)

0 commit comments

Comments
 (0)