|
7 | 7 | has_transformers, |
8 | 8 | ignore_warnings, |
9 | 9 | requires_transformers, |
| 10 | + requires_experimental_experiment, |
10 | 11 | ) |
11 | 12 | from onnx_diagnostic.helpers import max_diff |
12 | 13 | from onnx_diagnostic.helpers.torch_helper import torch_deepcopy |
13 | 14 | from onnx_diagnostic.helpers.rt_helper import make_feeds |
14 | 15 | from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache |
15 | 16 | from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs |
16 | 17 | from onnx_diagnostic.torch_export_patches import torch_export_patches |
17 | | -from onnx_diagnostic.export.api import to_onnx |
| 18 | +from onnx_diagnostic.export.api import to_onnx, method_to_onnx |
18 | 19 |
|
19 | 20 |
|
20 | 21 | class TestValidate(ExtTestCase): |
@@ -114,6 +115,136 @@ def test_tiny_llm_to_onnx(self): |
114 | 115 |
|
115 | 116 | self.clean_dump() |
116 | 117 |
|
| 118 | + @requires_experimental_experiment("0.1") |
| 119 | + def test_method_to_onnx_args(self): |
| 120 | + class Model(torch.nn.Module): |
| 121 | + def forward(self, x, y): |
| 122 | + return x + y |
| 123 | + |
| 124 | + filename = self.get_dump_file("test_method_to_onnx_args.onnx") |
| 125 | + inputs = [ |
| 126 | + (torch.randn((5, 6)), torch.randn((1, 6))), |
| 127 | + (torch.randn((7, 7)), torch.randn((1, 7))), |
| 128 | + ] |
| 129 | + model = Model() |
| 130 | + method_to_call = method_to_onnx(model, exporter="custom", filename=filename) |
| 131 | + expecteds = [] |
| 132 | + for args in inputs: |
| 133 | + expecteds.append(method_to_call(*args)) |
| 134 | + self.assertExists(filename) |
| 135 | + src = method_to_call._method_src |
| 136 | + self.assertIn("f(self, x, y):", src) |
| 137 | + self.assertIn("return self._method_call(x=x, y=y)", src) |
| 138 | + self.assertEqual(len(list(method_to_call.named_modules())), 2) |
| 139 | + sess = self.check_ort(filename) |
| 140 | + input_names = [i.name for i in sess.get_inputs()] |
| 141 | + for expected, args in zip(expecteds, inputs): |
| 142 | + feeds = make_feeds(input_names, args, use_numpy=True) |
| 143 | + got = sess.run(None, feeds) |
| 144 | + self.assertEqualArray(expected, got[0]) |
| 145 | + self.clean_dump() |
| 146 | + |
| 147 | + @requires_experimental_experiment("0.1") |
| 148 | + def test_method_to_onnx_kwargs(self): |
| 149 | + class Model(torch.nn.Module): |
| 150 | + def forward(self, x=None, y=None): |
| 151 | + return x + y |
| 152 | + |
| 153 | + filename = self.get_dump_file("test_method_to_onnx_kwargs.onnx") |
| 154 | + inputs = [ |
| 155 | + dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))), |
| 156 | + dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))), |
| 157 | + ] |
| 158 | + model = Model() |
| 159 | + method_to_call = method_to_onnx(model, exporter="custom", filename=filename) |
| 160 | + expecteds = [] |
| 161 | + for kwargs in inputs: |
| 162 | + expecteds.append(method_to_call(**kwargs)) |
| 163 | + self.assertExists(filename) |
| 164 | + src = method_to_call._method_src |
| 165 | + self.assertIn("f(self, x=None, y=None):", src) |
| 166 | + self.assertIn("return self._method_call(x=x, y=y)", src) |
| 167 | + self.assertEqual(len(list(method_to_call.named_modules())), 2) |
| 168 | + sess = self.check_ort(filename) |
| 169 | + input_names = [i.name for i in sess.get_inputs()] |
| 170 | + for expected, kwargs in zip(expecteds, inputs): |
| 171 | + feeds = make_feeds(input_names, kwargs, use_numpy=True) |
| 172 | + got = sess.run(None, feeds) |
| 173 | + self.assertEqualArray(expected, got[0]) |
| 174 | + self.clean_dump() |
| 175 | + |
| 176 | + @requires_experimental_experiment("0.1") |
| 177 | + def test_method_to_onnx_kwargs_patch(self): |
| 178 | + class Model(torch.nn.Module): |
| 179 | + def forward(self, x=None, y=None): |
| 180 | + return x + y |
| 181 | + |
| 182 | + filename = self.get_dump_file("test_method_to_onnx_kwargs_patch.onnx") |
| 183 | + inputs = [ |
| 184 | + dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))), |
| 185 | + dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))), |
| 186 | + ] |
| 187 | + model = Model() |
| 188 | + method_to_call = method_to_onnx( |
| 189 | + model, |
| 190 | + exporter="custom", |
| 191 | + filename=filename, |
| 192 | + patch_kwargs=dict(patch_transformers=True), |
| 193 | + ) |
| 194 | + expecteds = [] |
| 195 | + for kwargs in inputs: |
| 196 | + expecteds.append(method_to_call(**kwargs)) |
| 197 | + self.assertExists(filename) |
| 198 | + src = method_to_call._method_src |
| 199 | + self.assertIn("f(self, x=None, y=None):", src) |
| 200 | + self.assertIn("return self._method_call(x=x, y=y)", src) |
| 201 | + self.assertEqual(len(list(method_to_call.named_modules())), 2) |
| 202 | + sess = self.check_ort(filename) |
| 203 | + input_names = [i.name for i in sess.get_inputs()] |
| 204 | + for expected, kwargs in zip(expecteds, inputs): |
| 205 | + feeds = make_feeds(input_names, kwargs, use_numpy=True) |
| 206 | + got = sess.run(None, feeds) |
| 207 | + self.assertEqualArray(expected, got[0]) |
| 208 | + self.clean_dump() |
| 209 | + |
| 210 | + @requires_experimental_experiment("0.1") |
| 211 | + @hide_stdout() |
| 212 | + def test_method_to_onnx_mixed(self): |
| 213 | + from experimental_experiment.torch_interpreter import ExportOptions |
| 214 | + |
| 215 | + class Model(torch.nn.Module): |
| 216 | + def forward(self, x, y=None): |
| 217 | + return x + y |
| 218 | + |
| 219 | + filename = self.get_dump_file("test_method_to_onnx_mixed.onnx") |
| 220 | + inputs = [ |
| 221 | + ((torch.randn((5, 6)),), dict(y=torch.randn((1, 6)))), |
| 222 | + ((torch.randn((7, 7)),), dict(y=torch.randn((1, 7)))), |
| 223 | + ] |
| 224 | + model = Model() |
| 225 | + method_to_call = method_to_onnx( |
| 226 | + model, |
| 227 | + exporter="custom", |
| 228 | + filename=filename, |
| 229 | + verbose=10, |
| 230 | + exporter_kwargs=dict(export_options=ExportOptions(backed_size_oblivious=False)), |
| 231 | + ) |
| 232 | + expecteds = [] |
| 233 | + for args, kwargs in inputs: |
| 234 | + expecteds.append(method_to_call(*args, **kwargs)) |
| 235 | + self.assertExists(filename) |
| 236 | + src = method_to_call._method_src |
| 237 | + self.assertIn("f(self, x, y=None):", src) |
| 238 | + self.assertIn("return self._method_call(x=x, y=y)", src) |
| 239 | + self.assertEqual(len(list(method_to_call.named_modules())), 2) |
| 240 | + sess = self.check_ort(filename) |
| 241 | + input_names = [i.name for i in sess.get_inputs()] |
| 242 | + for expected, (args, kwargs) in zip(expecteds, inputs): |
| 243 | + feeds = make_feeds(input_names, (args, kwargs), use_numpy=True) |
| 244 | + got = sess.run(None, feeds) |
| 245 | + self.assertEqualArray(expected, got[0]) |
| 246 | + self.clean_dump() |
| 247 | + |
117 | 248 |
|
118 | 249 | if __name__ == "__main__": |
119 | 250 | unittest.main(verbosity=2) |
0 commit comments