Sliding window attention for long-context training #42

@mmshad

Description

Context

Sliding window attention limits each token to attending to the last W positions instead of the full sequence. Primary benefits: attention compute drops from O(S²) to O(S·W), and the KV cache during inference shrinks from O(S) to O(W). Useful for long-context training (32K+ sequences) where full attention is compute-prohibitive.

Used by Mistral, Mixtral, Gemma-2, and similar architectures — some apply sliding window on every layer, others alternate with full attention.

Current state

Attention in kempnerforge/model/attention.py always uses full causal attention via PyTorch SDPA. ModelConfig has no sliding window field. KVCache pre-allocates max_seq_len slots and grows linearly with the full sequence.

PyTorch SDPA does not have a native window_size parameter, so sliding window requires an explicit banded-causal attn_mask.
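
For reference, a minimal sketch of such a mask as a boolean attn_mask (the helper name is illustrative, not an existing repo function):

```python
import torch
import torch.nn.functional as F

def banded_causal_mask(seq_len: int, window: int, device=None) -> torch.Tensor:
    # Boolean mask, True = attend. Position i may attend to position j
    # when 0 <= i - j < window, i.e. itself plus the window - 1 preceding tokens.
    idx = torch.arange(seq_len, device=device)
    rel = idx[:, None] - idx[None, :]  # i - j
    return (rel >= 0) & (rel < window)

# SDPA accepts boolean masks directly; is_causal stays False because
# causality is already encoded in the band.
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
mask = banded_causal_mask(16, window=4)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```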

What needs to change

  1. Config: Add sliding_window: int | None = None to ModelConfig in kempnerforge/config/model.py. None keeps current behavior (full attention).

  2. Threading: Pass sliding_window through TransformerBlock into Attention.__init__(); store as self.sliding_window.

  3. Masking in Attention.forward(): When sliding_window is set, construct a banded causal mask and pass it as attn_mask. When doc_ids is also present, intersect the band with the existing block-diagonal document mask so that each token attends only to the last W tokens within its own document (see the mask sketch after this list).

  4. KV cache: When sliding_window is set, KVCache only needs the last W positions. Convert it to a circular buffer of size W instead of pre-allocating max_seq_len, reducing inference memory from O(max_seq_len) to O(W) (see the cache sketch after this list).
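
For item 3, a sketch of the band intersected with a document mask, assuming doc_ids is a per-position integer tensor (the repo's actual block-diagonal construction may differ):

```python
import torch

def banded_doc_mask(doc_ids: torch.Tensor, window: int) -> torch.Tensor:
    # doc_ids: (S,) integer tensor mapping each packed position to its document.
    S = doc_ids.shape[0]
    idx = torch.arange(S, device=doc_ids.device)
    rel = idx[:, None] - idx[None, :]
    band = (rel >= 0) & (rel < window)               # banded causal
    same_doc = doc_ids[:, None] == doc_ids[None, :]  # block-diagonal document mask
    return band & same_doc                           # within-window AND within-document
```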
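
For item 4, one possible shape of the circular buffer; names and layout are illustrative, not the repo's KVCache API. Returning entries in chronological order keeps the attention code unaware of the rotation:

```python
import torch

class SlidingKVCache:
    # Ring buffer holding only the last `window` key/value positions.
    def __init__(self, batch: int, n_kv_heads: int, window: int, head_dim: int,
                 dtype=torch.float32, device="cpu"):
        self.window = window
        self.pos = 0  # total tokens written so far
        shape = (batch, n_kv_heads, window, head_dim)
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (batch, n_kv_heads, T, head_dim); write modulo the window.
        T = k_new.shape[2]
        slots = (self.pos + torch.arange(T, device=k_new.device)) % self.window
        self.k[:, :, slots] = k_new
        self.v[:, :, slots] = v_new
        self.pos += T
        # Gather the valid entries back into chronological order.
        n = min(self.pos, self.window)
        order = torch.arange(self.pos - n, self.pos, device=k_new.device) % self.window
        return self.k[:, :, order], self.v[:, :, order]
```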

Config example

[model]
sliding_window = 4096  # each token attends to last 4096 positions

MFU impact

kempnerforge/metrics/mfu.py uses 12 * L * D * S for the attention term in both _dense_flops_per_token and _moe_flops_per_token. When sliding_window is set, this term should use min(S, W); otherwise the attention FLOPs are overstated and MFU is over-reported for sliding-window models.
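
Roughly, with illustrative names (the real term lives in kempnerforge/metrics/mfu.py):

```python
def attn_flops_per_token(n_layers: int, d_model: int, seq_len: int,
                         sliding_window: int | None) -> int:
    # Cap the attention span at the window size when sliding window is enabled.
    span = min(seq_len, sliding_window) if sliding_window is not None else seq_len
    return 12 * n_layers * d_model * span
```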

Testing

  • Output matches full attention when W >= seq_len (sketched after this list)
  • Tokens beyond window distance receive zero attention weight
  • Works with GQA, packed sequences (doc_ids), and KV cache
  • MFU calculation uses min(S, W) when sliding window is set
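
As a self-contained example of the first check, at the SDPA level rather than through the repo's model classes:

```python
import torch
import torch.nn.functional as F

def test_band_matches_causal_when_window_covers_sequence():
    S, W = 64, 64  # W >= seq_len: the band degenerates to the full causal mask
    q = torch.randn(1, 4, S, 32)
    k, v = torch.randn_like(q), torch.randn_like(q)
    idx = torch.arange(S)
    rel = idx[:, None] - idx[None, :]
    band = (rel >= 0) & (rel < W)
    out_band = F.scaled_dot_product_attention(q, k, v, attn_mask=band)
    out_full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Loose tolerances: the two calls may dispatch to different SDPA backends.
    torch.testing.assert_close(out_band, out_full, rtol=1e-4, atol=1e-4)
```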

Priority

Medium. Unlocks long-context training on memory-constrained hardware. Independent of FA3 upstream tracking.
