Skip to content

Commit 3c361b1

Browse files
authored
Support options and SlidingWindowCache (#65)
* better design * add missing doc * fix mypy * more refactoring * end refactoring * last issue * tests * fix unit tests
1 parent 54245d1 commit 3c361b1

14 files changed

Lines changed: 448 additions & 213 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
os: [ubuntu-latest]
1818
python: ['3.11', '3.12']
19-
transformers: ['4.48.3', '4.51.2', 'main']
19+
transformers: ['4.48.3', '4.51.3', 'main']
2020
torch: ['2.6', 'main']
2121

2222
steps:

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.4.0
55
+++++
66

7+
* :pr:`65`: support SlidingWindowCache
78
* :pr:`63`: support option ``--trained``
89
* :pr:`61`: improves dynamic shapes for EncoderDecoderCache
910
* :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``,

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
("py:class", "transformers.cache_utils.DynamicCache"),
124124
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
125125
("py:class", "transformers.cache_utils.MambaCache"),
126+
("py:class", "transformers.cache_utils.SlidingWindowCache"),
126127
("py:class", "transformers.configuration_utils.PretrainedConfig"),
127128
("py:func", "torch.export._draft_export.draft_export"),
128129
("py:func", "torch._export.tools.report_exportability"),

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import unittest
22
import torch
33
import transformers
4-
from onnx_diagnostic.ext_test_case import ExtTestCase
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
55
from onnx_diagnostic.helpers import string_type
66
from onnx_diagnostic.helpers.cache_helper import (
7+
flatten_unflatten_for_dynamic_shapes,
78
make_dynamic_cache,
89
make_encoder_decoder_cache,
9-
flatten_unflatten_for_dynamic_shapes,
10+
make_mamba_cache,
11+
make_sliding_window_cache,
1012
)
1113
from onnx_diagnostic.export import CoupleInputsDynamicShapes
1214
from onnx_diagnostic.torch_export_patches.patch_inputs import (
@@ -132,6 +134,37 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
132134
self.string_type(c2, with_shape=True),
133135
)
134136

137+
@requires_transformers("4.51") # the structure changes
138+
def test_make_mamba_cache(self):
139+
cache = make_mamba_cache(
140+
[
141+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
142+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
143+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
144+
]
145+
)
146+
text = self.string_type(cache, with_shape=True)
147+
self.assertEqual(
148+
"MambaCache(conv_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4], "
149+
"ssm_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4])",
150+
text,
151+
)
152+
153+
def test_make_sliding_window_cache(self):
154+
cache = make_sliding_window_cache(
155+
[
156+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
157+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
158+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
159+
]
160+
)
161+
text = self.string_type(cache, with_shape=True)
162+
self.assertEqual(
163+
"SlidingWindowCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
164+
"value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
165+
text,
166+
)
167+
135168

136169
if __name__ == "__main__":
137170
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import unittest
22
import torch
33
from transformers.modeling_outputs import BaseModelOutput
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
55
from onnx_diagnostic.helpers.cache_helper import (
66
make_encoder_decoder_cache,
77
make_dynamic_cache,
8+
make_sliding_window_cache,
89
flatten_unflatten_for_dynamic_shapes,
910
)
1011
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -164,6 +165,53 @@ def test_base_model_output_unflatten_flatten(self):
164165
self.assertIsInstance(unflat, dict)
165166
self.assertEqual(list(unflat), ["last_hidden_state"])
166167

168+
@ignore_warnings(UserWarning)
169+
def test_base_sliding_window_cache_unflatten_flatten(self):
170+
cache = make_sliding_window_cache(
171+
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
172+
)
173+
with bypass_export_some_errors():
174+
cache2 = torch_deepcopy([cache])
175+
self.assertEqualAny([cache], cache2)
176+
177+
@ignore_warnings(UserWarning)
178+
@requires_torch("2.7")
179+
def test_sliding_window_cache_export(self):
180+
class Model(torch.nn.Module):
181+
def forward(self, cache):
182+
return cache.key_cache[0]
183+
184+
cache = make_sliding_window_cache(
185+
[
186+
(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
187+
(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
188+
]
189+
)
190+
model = Model()
191+
model(cache)
192+
DYN = torch.export.Dim.DYNAMIC
193+
ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]
194+
195+
with bypass_export_some_errors(patch_transformers=True):
196+
torch.export.export(model, (cache,), dynamic_shapes=(ds,))
197+
198+
@ignore_warnings(UserWarning)
199+
def test_sliding_window_cache_flatten(self):
200+
cache = make_sliding_window_cache(
201+
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))]
202+
)
203+
with bypass_export_some_errors():
204+
flat, _spec = torch.utils._pytree.tree_flatten(cache)
205+
self.assertEqual(
206+
"#2[T1s4x4x4x4,T1s4x4x4x4]",
207+
self.string_type(flat, with_shape=True),
208+
)
209+
cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
210+
self.assertEqual(
211+
self.string_type(cache, with_shape=True, with_min_max=True),
212+
self.string_type(cache2, with_shape=True, with_min_max=True),
213+
)
214+
167215

168216
if __name__ == "__main__":
169217
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
import json
23
import sys
34
import textwrap
@@ -227,6 +228,21 @@ def _cmd_config(argv: List[Any]):
227228
print(f"task: {task_from_id(args.mid)}")
228229

229230

231+
class _ParseDict(argparse.Action):
232+
def __call__(self, parser, namespace, values, option_string=None):
233+
d = getattr(namespace, self.dest) or {}
234+
235+
if values:
236+
for item in values:
237+
split_items = item.split("=", 1)
238+
key = split_items[0].strip()  # strip surrounding whitespace from the key
239+
value = split_items[1]
240+
241+
d[key] = value
242+
243+
setattr(namespace, self.dest, d)
244+
245+
230246
def get_parser_validate() -> ArgumentParser:
231247
parser = ArgumentParser(
232248
prog="test",
@@ -297,6 +313,14 @@ def get_parser_validate() -> ArgumentParser:
297313
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
298314
parser.add_argument("--dtype", help="changes dtype if necessary")
299315
parser.add_argument("--device", help="changes the device if necessary")
316+
parser.add_argument(
317+
"--iop",
318+
metavar="KEY=VALUE",
319+
nargs="*",
320+
help="Additional input options, use to change the default "
321+
"inputs use to export, example: --iop cls_cache=SlidingWindowCache",
322+
action=_ParseDict,
323+
)
300324
return parser
301325

302326

@@ -346,6 +370,7 @@ def _cmd_validate(argv: List[Any]):
346370
dump_folder=args.dump_folder,
347371
drop_inputs=None if not args.drop else args.drop.split(","),
348372
ortfusiontype=args.ortfusiontype,
373+
input_options=args.iop,
349374
)
350375
print("")
351376
print("-- summary --")

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def assertEqualAny(
939939
else:
940940
for e, g in zip(expected, value):
941941
self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
942-
elif expected.__class__.__name__ == "DynamicCache":
942+
elif expected.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
943943
self.assertEqual(type(expected), type(value), msg=msg)
944944
atts = ["key_cache", "value_cache"]
945945
self.assertEqualAny(

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,8 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
2626
subtrees = []
2727
for subspec in spec.children_specs:
2828
end += subspec.num_leaves
29-
if use_dict and (subspec.type is dict or subspec.context):
30-
value = subspec.unflatten(flat[start:end])
31-
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
32-
else:
33-
value = subspec.unflatten(flat[start:end])
34-
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
29+
value = subspec.unflatten(flat[start:end])
30+
value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
3531
subtrees.append(value)
3632
start = end
3733
if use_dict and (spec.type is dict or spec.context):
@@ -185,3 +181,36 @@ def __init__(self):
185181
)
186182
cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
187183
return cache
184+
185+
186+
def make_sliding_window_cache(
187+
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
188+
) -> transformers.cache_utils.SlidingWindowCache:
189+
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
190+
191+
class _config:
192+
def __init__(self):
193+
self.head_dim = key_value_pairs[0][0].shape[-1]
194+
self.num_attention_heads = key_value_pairs[0][0].shape[1]
195+
self.num_hidden_layers = len(key_value_pairs)
196+
self.sliding_window = key_value_pairs[0][0].shape[2]
197+
198+
cache = transformers.cache_utils.SlidingWindowCache(
199+
_config(),
200+
max_batch_size=key_value_pairs[0][0].shape[0],
201+
max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window
202+
device=key_value_pairs[0][0].device,
203+
dtype=key_value_pairs[0][0].dtype,
204+
)
205+
for i in range(len(key_value_pairs)):
206+
assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
207+
f"Shape mismatch, expected {cache.key_cache[i].shape}, "
208+
f"got {key_value_pairs[i][0].shape}"
209+
)
210+
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
211+
assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
212+
f"Shape mismatch, expected {cache.value_cache[i].shape}, "
213+
f"got {key_value_pairs[i][1].shape}"
214+
)
215+
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
216+
return cache

onnx_diagnostic/helpers/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def string_type(
534534
print(f"[string_type] CACHE1:{type(obj)}")
535535
return f"MambaCache(conv_states={c}, ssm_states={d})"
536536

537-
if obj.__class__.__name__ == "DynamicCache":
537+
if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
538538
kc = string_type(
539539
obj.key_cache,
540540
with_shape=with_shape,

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import numpy as np
55
import torch
66
from .helper import string_type
7-
from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache
7+
from .cache_helper import (
8+
make_dynamic_cache,
9+
make_encoder_decoder_cache,
10+
make_sliding_window_cache,
11+
)
812

913

1014
def _forward_(*args, _f=None, _context=None, **kwargs):
@@ -363,6 +367,10 @@ def torch_deepcopy(value: Any) -> Any:
363367
return make_dynamic_cache(
364368
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
365369
)
370+
if value.__class__.__name__ == "SlidingWindowCache":
371+
return make_sliding_window_cache(
372+
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
373+
)
366374
if value.__class__.__name__ == "EncoderDecoderCache":
367375
return make_encoder_decoder_cache(
368376
torch_deepcopy(value.self_attention_cache),

0 commit comments

Comments
 (0)