This repository provides a single-file, hackable, scalable, and efficient pure PyTorch implementation of CATs.
Attention and Compression is all you need for Controllably Efficient Language Models
Jatin Prakash, Aahlad Puli, Rajesh Ranganath
New York University
- a simple architecture that employs two well-known ingredients: dense attention and compression.
- provides a controllable knob at test time to trade off quality for efficiency, interpolating between a dense transformer and efficient alternatives, all without any retraining.
- can be used as a drop-in replacement for dense/linear attention layers in any model to create controllably efficient architectures.
CATs model chunks of tokens given compressed representations of past chunks in the sequence.
No need to heuristically define attention masks; no need for handcrafted and complex recurrent state update rules; no need to carefully compose with attention at specific layers to have a capable architecture.
The troubled cat below describes the overwhelming feeling of designing an efficient architecture.
- Because compression reduces the sequence length, compute FLOPs and KV-cache memory shrink by a factor of the chunk size (up to 3x faster and 9x more memory efficient).
Choosing the chunk size (i.e., how much to compress) lets CATs interpolate between a compressed (fast) and a dense (slow) transformer directly at test time, trading off quality for efficiency.
We take the core concepts and instantiate CAT as a layer that can be swapped into any sequence model as a drop-in replacement for dense attention. This unlocks many interesting possibilities, starting with hybrid and adaptive architectures that mix CAT layers alongside dense attention, or even linear attention.
⚠️ We will be releasing our pre-trained CAT models soon!
Here are some things to keep in mind:
- transformer.py contains a fast implementation of transformer++, heavily inspired by the Lightning-AI/litgpt repo. To make this implementation efficient, it uses Triton kernels from the linkedin/Liger-Kernel repo. CAT's implementation directly imports components from here since it builds on vanilla transformer abstractions.
- cat_transformer.py contains a scalable implementation of CATs. We provide a simple usage example that can be dropped directly into most training scripts. This supports fixed chunk sizes only.
- cat_transformer_adaptive.py contains an implementation of adaptive CATs that works with multiple chunk sizes, thereby unlocking controllable efficiency.
Please refer to usages below for more details.
⚠️ Note that, as described in the paper, the decoder in a CAT should be made more expressive (contain more parameters) in order to accurately decode from the compressed chunk representations (refer to the usage below to correctly instantiate a CAT). This does not mean CATs are inefficient; in fact, due to compression, CATs achieve much higher throughput and lower total memory usage than vanilla transformers.
Usage for CATs with fixed chunk size
Refer to cat_transformer.py
device = "cuda" if torch.cuda.is_available() else "cpu"
# below assumes that one wishes to instantiate a CAT that matches
# a vanilla transformer with 12 layers and a hidden size of 768
dim = 768
n_head = 12
num_layers = 12
# this is the hidden size of the decoder, which is recommended to be 2*dim;
# however, it can be 1.5*dim or 1.25*dim depending on the task.
# dim_fx is the size of the compressed chunk representations (the f(c)'s),
# which is the same as the hidden size of the decoder
decoder_dim = 2 * dim # hidden size of the decoder
dim_fx = decoder_dim # size of compressed chunk representations
n_head_decoder = 2 * n_head # increase heads too proportionally
block_size = 2048 # context length
chunk_size = 8 # chunk size
# instantiate the model
compressor_config = CAT_Config(dim=dim, n_head=n_head, dim_fx=dim_fx, block_size=block_size, chunk_size=chunk_size, n_layer=(num_layers // 4)) # compressor depth follows the paper, but one may use fewer layers in the compressor
decoder_config = CAT_Config(dim=decoder_dim, n_head=n_head_decoder, block_size=block_size, chunk_size=chunk_size, n_layer=num_layers)
model = CAT_Transformer(decoder_config, compressor_config)
model = model.to(device=device)
model.setup_cache(device=device)
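# (optional) sanity-check the parameter count -- this is a generic PyTorch
# one-liner, not part of this repo's API. As noted above, the decoder is
# intentionally wider, so expect more parameters than a 12-layer / 768-dim
# vanilla transformer.
num_params = sum(p.numel() for p in model.parameters())
print(f"total parameters: {num_params / 1e6:.1f}M")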
# do forward pass
input_ids = torch.randint(0, decoder_config.vocab_size, (4, block_size), device=device)
logits = model(input_ids)
# do stuff with logits ...
Benchmark CATs
Refer to benchmark.py to measure generation throughput and memory usage of CATs.
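If you just want a rough, back-of-envelope number without the full script, a minimal sketch like the one below times repeated forward passes and reports peak GPU memory. It assumes the model, input_ids, and device from the fixed-chunk-size usage above and a CUDA device; n_iters is an arbitrary choice, and benchmark.py remains the proper way to measure generation throughput and memory.
import time
import torch

# rough forward-pass timing; `model` and `input_ids` are the ones instantiated
# in the usage above. benchmark.py does the proper generation-throughput measurement.
n_iters = 10
torch.cuda.reset_peak_memory_stats(device)
torch.cuda.synchronize(device)
start = time.perf_counter()
with torch.no_grad():
    for _ in range(n_iters):
        _ = model(input_ids)
torch.cuda.synchronize(device)
elapsed = time.perf_counter() - start
tokens_per_s = (n_iters * input_ids.numel()) / elapsed
print(f"throughput: {tokens_per_s:.0f} tokens/s")
print(f"peak memory: {torch.cuda.max_memory_allocated(device) / 1e9:.2f} GB")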
Usage for adaptive CATs
Refer to cat_transformer_adaptive.py
device = "cuda" if torch.cuda.is_available() else "cpu"
# below assumes that one wishes to instantiate a CAT that matches
# a vanilla transformer with 12 layers and a hidden size of 768
dim = 768
num_layers = 12
n_head = 12
# this is the hidden size of the decoder, which is recommended to be 2*dim;
# however, it can be 1.5*dim or 1.25*dim depending on the task.
# dim_fx is the size of the compressed chunk representations (the f(c)'s),
# which is the same as the hidden size of the decoder
decoder_dim = 2 * dim # hidden size of the decoder
dim_fx = decoder_dim # size of compressed chunk representations
n_head_decoder = 2 * n_head # increase heads too proportionally
block_size = 2048 # context length
chunk_size = 32 # chunk size
# instantiate the model
compressor_config = CAT_Config(dim=dim, n_head=n_head, dim_fx=dim_fx, block_size=block_size, chunk_size=chunk_size, n_layer=(num_layers // 4)) # compressor depth follows the paper, but one may use fewer layers in the compressor
decoder_config = CAT_Config(dim=decoder_dim, n_head=n_head_decoder, block_size=block_size, chunk_size=chunk_size, n_layer=num_layers)
model = CAT_Transformer(decoder_config, compressor_config)
model = model.to(device=device)
model.setup_cache(device=device)
# do forward pass
input_ids = torch.randint(0, decoder_config.vocab_size, (4, block_size), device=device)
print("input_ids shape:", input_ids.shape)
# choose which chunk size to use for this forward pass
# must be a power of 2 and less than or equal to chunk_size
# only powers of two supported for now
cur_chunk_size_power = 4 # corresponds to chunk size of 16 (2^4)
logits = model(input_ids, chunk_size_power=cur_chunk_size_power)
print("logits shape:", logits.shape)
# do stuff with logits ...
Usage for CAT as a drop-in replacement layer
Refer to cat_layer.py
device = "cuda" if torch.cuda.is_available() else "cpu"
# simple test
batch_size = 4
seq_len = 2048
chunk_size = 8
config = CAT_Config(
dim=768,
n_head=16,
chunk_size=chunk_size,
# again, needs 2*dim for accurate decoding from compressed chunk representations
dim_fx=2 * 768,
block_size=seq_len,
# right now, every layer is a CAT layer
# but the implementation can be easily modified to create hybrid and adaptive architectures :)
n_layer=12,
)
model = CAT_Layer_Transformer(config)
model.setup_cache(device=device)
model.to(device)
x = torch.randint(0, config.padded_vocab_size, (batch_size, seq_len), device=device)
logits = model(x)
# do stuff with logits ...
Here are the packages that we used to run our code:
torch==2.5.1+cu121
liger-kernel
This implementation borrows heavily from the following repositories:
- Lightning-AI/litgpt
- linkedin/Liger-Kernel
Feel free to open issues for any questions or clarifications regarding the code or paper. Thanks!
Consider giving this repo a ⭐ if you found it useful!



