Compress And Attend Transformers (CATs)

This repository provides a single-file, hackable, scalable, and efficient 🚀 pure PyTorch implementation of CATs.

Attention and Compression is all you need for Controllably Efficient Language Models
Jatin Prakash, Aahlad Puli, Rajesh Ranganath
New York University


  • a simple architecture that employs two well-known ingredients: dense attention and compression.
  • provides a controllable knob at test time to trade off quality for efficiency, interpolating between a dense transformer and efficient alternatives, all without any retraining.
  • can be used as a drop-in replacement for dense/linear attention layers in any architecture to create controllably 🕹️ efficient architectures.

[Figure: CAT architecture diagram]

Overview

  • CATs model chunks of tokens given compressed representations of past chunks in the sequence 😸.

  • No need to heuristically define attention masks; no need for handcrafted and complex recurrent state update rules; no need to carefully compose with attention at specific layers to have a capable architecture 💆‍♀️😌.

The troubled cat 😿 below captures the overwhelming feeling of designing an efficient architecture:

[Figure: the troubled cat]

  • Because compression reduces the effective sequence length, compute FLOPs and KV-cache memory shrink by a factor of the chunk size (up to 3x faster and 9x more memory-efficient 🚀); see the short sketch after this list.

[Figure: throughput]

  • Choosing the chunk size (i.e., how much to compress) allows CATs to interpolate between a compressed (fast) and a dense (slow) transformer directly at test time ⏰, trading off quality for efficiency.

  • We take the core concepts and instantiate CAT as a layer that can be swapped into any sequence model as a drop-in replacement for dense attention. This unlocks many interesting possibilities, starting with hybrid and adaptive architectures that mix CAT layers alongside dense attention, or even linear attention.
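
To make these savings concrete, here is a back-of-the-envelope sketch (not code from this repo): compressing each chunk of chunk_size tokens into a single representation shrinks the number of cached entries the decoder attends to by that same factor.

def num_cached_entries(seq_len: int, chunk_size: int) -> int:
    # one compressed representation per past chunk instead of one KV entry per past token
    return seq_len // chunk_size

seq_len = 2048
for chunk_size in (1, 8, 32):
    n = num_cached_entries(seq_len, chunk_size)
    print(f"chunk_size={chunk_size:>2}: {n:>4} cached entries (a {chunk_size}x reduction)")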

Usage

โš ๏ธ We will be releasing our pre-trained CAT models soon! ๐Ÿ™‚

Here are some things to keep in mind:

  • transformer.py contains a fast implementation of transformer++, heavily inspired by the Lightning-AI/litgpt repo. To make this implementation efficient, it uses Triton kernels from the linkedin/Liger-Kernel repo. CAT's implementation imports components directly from here since it builds on vanilla transformer abstractions.
  • cat_transformer.py contains a scalable implementation of CATs. We provide a simple usage example that can be dropped directly into most training scripts. This version supports fixed chunk sizes only.
  • cat_transformer_adaptive.py contains an implementation of adaptive CATs that can work with multiple chunk sizes, thereby unlocking controllable efficiency.

Please refer to the usage examples below for more details.

โš ๏ธ Note that according to the paper, the decoder in CAT should be made more expressive (contain more parameters) in order to accurately decode from the compressed chunk representations (refer to below usage to correctly instantiate a CAT). This does not mean CATs are inefficient; in fact, due to compression, CATs are much more efficient than vanilla transformers in terms of throughput and total memory usage.

Usage for CATs with fixed chunk size

Refer to cat_transformer.py

device = "cuda" if torch.cuda.is_available() else "cpu"

# below assumes that one wishes to instantiate a CAT that matches
# a vanilla transformer containing 12 layers, and hidden size of 768
dim = 768
n_head = 12
num_layers = 12

# this is the hidden size of decoder, which is recommended to be 2*dim
# however, it can be 1.5*dim, or 1.25*dim depending on the task
# dim_fx means the size of the compressed chunk representations (f(c)'s), which
# is same as hidden size of the decoder
decoder_dim = 2 * dim # hidden size of the decoder
dim_fx = decoder_dim # size of compressed chunk representations
n_head_decoder = 2 * n_head # increase heads too proportionally

block_size = 2048 # context length
chunk_size = 8 # chunk size

# instantiate the model
compressor_config = CAT_Config(dim=dim, n_head=n_head, dim_fx=dim_fx, block_size=block_size, chunk_size=chunk_size, n_layer=(num_layers // 4)) # layers are defined according to the paper, but one may use a lower number of layers in the compressor
decoder_config = CAT_Config(dim=decoder_dim, n_head=n_head_decoder, block_size=block_size, chunk_size=chunk_size, n_layer=num_layers)
model = CAT_Transformer(decoder_config, compressor_config)
model = model.to(device=device)
model.setup_cache(device=device)

# do forward pass
input_ids = torch.randint(0, decoder_config.vocab_size, (4, block_size), device=device)
logits = model(input_ids)
# do stuff with logits ...
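
As a minimal example of "doing stuff with logits" (a sketch, assuming the logits have shape (batch, block_size, vocab_size)), a standard next-token cross-entropy loss can be computed as follows:

import torch.nn.functional as F

# shift so that position t predicts token t + 1
targets = input_ids[:, 1:]                      # (batch, block_size - 1)
preds = logits[:, :-1, :]                       # (batch, block_size - 1, vocab_size)
loss = F.cross_entropy(preds.reshape(-1, preds.size(-1)), targets.reshape(-1))
print("next-token loss:", loss.item())
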
Benchmark CATs

Refer to benchmark.py to measure generation throughput and memory usage of CATs.
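
For a quick sanity check before running the full script, a rough forward-pass measurement (a sketch, not the methodology of benchmark.py; it assumes a CUDA device and reuses the model and input_ids from the snippet above) could look like this:

import time
import torch

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.synchronize(device)
start = time.perf_counter()

with torch.no_grad():
    _ = model(input_ids)

torch.cuda.synchronize(device)
elapsed = time.perf_counter() - start
print(f"throughput: {input_ids.numel() / elapsed:,.0f} tokens/s")
print(f"peak memory: {torch.cuda.max_memory_allocated(device) / 1e9:.2f} GB")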

Usage for adaptive CATs

Refer to cat_transformer_adaptive.py

device = "cuda" if torch.cuda.is_available() else "cpu"

# below assumes that one wishes to instantiate a CAT that matches
# a vanilla transformer containing 4 layers, and hidden size of 768
dim = 768
num_layers = 4
n_head = 12

# this is the hidden size of decoder, which is recommended to be 2*dim
# however, it can be 1.5*dim, or 1.25*dim depending on the task
# dim_fx means the size of the compressed chunk representations (f(c)'s), which
# is same as hidden size of the decoder
decoder_dim = 2 * dim # hidden size of the decoder
dim_fx = decoder_dim # size of compressed chunk representations
n_head_decoder = 2 * n_head # increase heads too proportionally

block_size = 2048 # context length
chunk_size = 32 # chunk size

# instantiate the model
compressor_config = CAT_Config(dim=dim, n_head=n_head, dim_fx=dim_fx, block_size=block_size, chunk_size=chunk_size, n_layer=(num_layers // 4)) # layers are defined according to the paper, but one may use a lower number of layers in the compressor
decoder_config = CAT_Config(dim=decoder_dim, n_head=n_head_decoder, block_size=block_size, chunk_size=chunk_size, n_layer=num_layers)
model = CAT_Transformer(decoder_config, compressor_config)
model = model.to(device=device)
model.setup_cache(device=device)

# do forward pass
input_ids = torch.randint(0, decoder_config.vocab_size, (4, block_size), device=device)
print("input_ids shape:", input_ids.shape)

# choose which chunk size to use for this forward pass
# must be a power of 2 (only powers of two are supported for now)
# and less than or equal to chunk_size
cur_chunk_size_power = 4 # corresponds to chunk size of 16 (2^4)

logits = model(input_ids, chunk_size_power=cur_chunk_size_power)

print("logits shape:", logits.shape)
# do stuff with logits ...
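
Since the chunk size is chosen per forward pass, the same weights can be run at several compression levels; the small sweep below (a sketch reusing the model and input_ids above) makes the test-time knob explicit:

# larger powers = more compression = faster and cheaper;
# smaller powers = closer to a dense transformer
for power in (2, 3, 4, 5):  # chunk sizes 4, 8, 16, 32 (all <= chunk_size)
    out = model(input_ids, chunk_size_power=power)
    print(f"chunk size 2^{power} = {2 ** power}: logits shape {tuple(out.shape)}")
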
Usage for CAT as a drop-in replacement layer

Refer to cat_layer.py

device = "cuda" if torch.cuda.is_available() else "cpu"

# simple test
batch_size = 4
seq_len = 2048
chunk_size = 8

config = CAT_Config(
    dim=768,
    n_head=16,
    chunk_size=chunk_size,

    # again, needs 2*dim for accurate decoding from compressed chunk representations
    dim_fx=2 * 768, 

    block_size=seq_len,

    # right now, every layer is a CAT layer
    # but the implementation can be easily modified to create hybrid and adaptive architectures :)
    n_layer=12,
)

model = CAT_Layer_Transformer(config)
model.setup_cache(device=device)
model.to(device)

x = torch.randint(0, config.padded_vocab_size, (batch_size, seq_len), device=device)

logits = model(x)

# do stuff with logits ...
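
As a quick check after instantiation (a sketch using only standard PyTorch calls), counting parameters shows the effect of the wider decoding dimension (dim_fx) noted above:

# model is the CAT_Layer_Transformer instantiated above
n_params = sum(p.numel() for p in model.parameters())
print(f"total parameters: {n_params / 1e6:.1f}M")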

Installation

Here are the packages that we used to run our code:

torch==2.5.1+cu121
liger-kernel

Acknowledgements

This implementation borrows heavily from the following repositories:

  • Lightning-AI/litgpt
  • linkedin/Liger-Kernel

Support

Feel free to open an issue for any questions or clarifications regarding the code or the paper. Thanks!

Consider giving this repo a ⭐ if you found it useful 😊
