Skip to content

Commit a8235c8

Browse files
committed
add ability to do a random rotation before scalar quantization, inspired by recent works, but cite the original paper by Chee et al. from Cornell
1 parent f4ff5d1 commit a8235c8

9 files changed

Lines changed: 66 additions & 78 deletions

File tree

.github/workflows/build.yml

Lines changed: 0 additions & 14 deletions
This file was deleted.

.github/workflows/python-publish.yml

Lines changed: 0 additions & 36 deletions
This file was deleted.

.github/workflows/test.yml

Lines changed: 0 additions & 19 deletions
This file was deleted.

README.md

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -815,3 +815,15 @@ assert loss.item() >= 0
815815
url = {https://arxiv.org/abs/2509.10140},
816816
}
817817
```
818+
819+
```bibtex
820+
@misc{chee2024quip2bitquantizationlarge,
821+
title = {QuIP: 2-Bit Quantization of Large Language Models With Guarantees},
822+
author = {Jerry Chee and Yaohui Cai and Volodymyr Kuleshov and Christopher De Sa},
823+
year = {2024},
824+
eprint = {2307.13304},
825+
archivePrefix = {arXiv},
826+
primaryClass = {cs.LG},
827+
url = {https://arxiv.org/abs/2307.13304},
828+
}
829+
```

examples/autoencoder_fsq.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,15 +32,15 @@ def default(val, d):
3232

3333
# classes
3434

35-
def SimpleFSQAutoEncoder(levels: list[int]):
35+
def SimpleFSQAutoEncoder(levels: list[int], orthogonal_rotation: bool = False):
3636
return Sequential(
3737
nn.Conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1),
3838
nn.MaxPool2d(kernel_size = 2, stride = 2),
3939
nn.GELU(),
4040
nn.Conv2d(16, 32, kernel_size = 3, stride = 1, padding = 1),
4141
nn.MaxPool2d(kernel_size = 2, stride = 2),
4242
nn.Conv2d(32, len(levels), kernel_size = 1),
43-
FSQ(levels),
43+
FSQ(levels, orthogonal_rotation = orthogonal_rotation),
4444
nn.Conv2d(len(levels), 32, kernel_size = 3, stride = 1, padding = 1),
4545
nn.Upsample(scale_factor = 2, mode = "nearest"),
4646
nn.Conv2d(32, 16, kernel_size = 3, stride = 1, padding = 1),
@@ -54,14 +54,15 @@ def train(
5454
lr = 3e-4,
5555
levels = [8, 6, 5],
5656
seed = 1234,
57-
batch_size = 256
57+
batch_size = 256,
58+
orthogonal_rotation = False
5859
):
5960
torch.random.manual_seed(seed)
6061
device = "cuda" if torch.cuda.is_available() else "cpu"
6162

6263
num_codes = math.prod(levels)
6364

64-
model = SimpleFSQAutoEncoder(levels).to(device)
65+
model = SimpleFSQAutoEncoder(levels, orthogonal_rotation = orthogonal_rotation).to(device)
6566

6667
opt = AdamW(model.parameters(), lr = lr)
6768

examples/autoencoder_lfq.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,8 @@ def train(
6565
entropy_loss_weight = 0.02,
6666
diversity_gamma = 1.,
6767
spherical = True,
68-
batch_size = 256
68+
batch_size = 256,
69+
orthogonal_rotation = False
6970
):
7071
torch.random.manual_seed(seed)
7172
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -74,7 +75,8 @@ def train(
7475
codebook_size = codebook_size,
7576
entropy_loss_weight = entropy_loss_weight,
7677
diversity_gamma = diversity_gamma,
77-
spherical = spherical
78+
spherical = spherical,
79+
orthogonal_rotation = orthogonal_rotation
7880
).to(device)
7981

8082
opt = AdamW(model.parameters(), lr = lr)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.28.0"
3+
version = "1.28.1"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,8 @@ def __init__(
7676
force_quantization_f32 = True,
7777
preserve_symmetry = False,
7878
noise_dropout = 0.,
79-
bound_hard_clamp = False # for residual fsq, if input is pre-softclamped to the right range
79+
bound_hard_clamp = False, # for residual fsq, if input is pre-softclamped to the right range
80+
orthogonal_rotation = False # increase codebook utilization. ensure levels are symmetric! https://arxiv.org/abs/2307.13304v2
8081
):
8182
super().__init__()
8283

@@ -132,6 +133,17 @@ def __init__(
132133

133134
self.bound_hard_clamp = bound_hard_clamp
134135

136+
self.orthogonal_rotation = orthogonal_rotation
137+
138+
if orthogonal_rotation:
139+
is_symmetric = len(set(levels)) == 1
140+
if not is_symmetric:
141+
print('orthogonal_rotation is not recommended for FSQ with asymmetric levels (i.e. where the number of bins differ across dimensions)')
142+
143+
orthogonal_rot = torch.empty(codebook_dim, codebook_dim)
144+
nn.init.orthogonal_(orthogonal_rot)
145+
self.register_buffer('orthogonal_rot', orthogonal_rot)
146+
135147
def bound(self, z, eps = 1e-3, hard_clamp = False):
136148
""" Bound `z`, an array of shape (..., d). """
137149
maybe_tanh = tanh if not hard_clamp else partial(clamp, min = -1., max = 1.)
@@ -219,6 +231,9 @@ def indices_to_codes(self, indices):
219231

220232
codes = self._indices_to_codes(indices)
221233

234+
if self.orthogonal_rotation:
235+
codes = codes @ self.orthogonal_rot.t()
236+
222237
if self.keep_num_codebooks_dim:
223238
codes = rearrange(codes, '... c d -> ... (c d)')
224239

@@ -253,6 +268,9 @@ def forward(self, z):
253268

254269
z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
255270

271+
if self.orthogonal_rotation:
272+
z = z @ self.orthogonal_rot
273+
256274
# whether to force quantization step to be full precision or not
257275

258276
force_f32 = self.force_quantization_f32
@@ -275,6 +293,9 @@ def forward(self, z):
275293

276294
codes = self.maybe_apply_noise(codes)
277295

296+
if self.orthogonal_rotation:
297+
codes = codes @ self.orthogonal_rot.t()
298+
278299
codes = rearrange(codes, 'b n c d -> b n (c d)')
279300

280301
codes = codes.to(orig_dtype)

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,8 @@ def __init__(
116116
experimental_softplus_entropy_loss = False,
117117
entropy_loss_offset = 5., # how much to shift the loss before softplus
118118
spherical = False, # from https://arxiv.org/abs/2406.07548
119-
force_quantization_f32 = True # will force the quantization step to be full precision
119+
force_quantization_f32 = True, # will force the quantization step to be full precision
120+
orthogonal_rotation = False # increase codebook utilization without aux losses, inspired by https://arxiv.org/abs/2307.13304v2
120121
):
121122
super().__init__()
122123

@@ -165,6 +166,15 @@ def __init__(
165166
self.spherical = spherical
166167
self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
167168

169+
# orthogonal rotation
170+
171+
self.orthogonal_rotation = orthogonal_rotation
172+
173+
if orthogonal_rotation:
174+
orthogonal_rot = torch.empty(codebook_dim, codebook_dim)
175+
nn.init.orthogonal_(orthogonal_rot)
176+
self.register_buffer('orthogonal_rot', orthogonal_rot)
177+
168178
# entropy aux loss related weights
169179

170180
assert 0 < frac_per_sample_entropy <= 1.
@@ -234,6 +244,9 @@ def indices_to_codes(
234244

235245
codes = self.maybe_l2norm(codes)
236246

247+
if self.orthogonal_rotation:
248+
codes = codes @ self.orthogonal_rot.t()
249+
237250
codes = rearrange(codes, '... c d -> ... (c d)')
238251

239252
# whether to project codes out to original dimensions
@@ -287,6 +300,9 @@ def forward(
287300

288301
x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
289302

303+
if self.orthogonal_rotation:
304+
x = x @ self.orthogonal_rot
305+
290306
# maybe l2norm
291307

292308
x = self.maybe_l2norm(x)
@@ -412,6 +428,11 @@ def forward(
412428
if force_f32:
413429
x = x.type(orig_dtype)
414430

431+
# rotate back if needed
432+
433+
if self.orthogonal_rotation:
434+
x = x @ self.orthogonal_rot.t()
435+
415436
# merge back codebook dim
416437

417438
x = rearrange(x, 'b n c d -> b n (c d)')

0 commit comments

Comments (0)