## Context
Sliding window attention limits each token to attending to the last W positions instead of the full sequence. Primary benefits: attention compute drops from O(S²) to O(S·W), and the KV cache during inference shrinks from O(S) to O(W). This is useful for long-context training (32K+ sequences) where full attention is compute-prohibitive.
Used by Mistral, Mixtral, Gemma-2, and similar architectures — some apply sliding window on every layer, others alternate with full attention.
## Current state
Attention in `kempnerforge/model/attention.py` always uses full causal attention via PyTorch SDPA. `ModelConfig` has no sliding window field. `KVCache` pre-allocates `max_seq_len` slots and grows linearly with the full sequence.
PyTorch SDPA does not have a native `window_size` parameter, so sliding window requires an explicit banded-causal `attn_mask`.
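A minimal sketch of such a mask, passed straight to SDPA (the helper name and shapes are illustrative, not the repo's existing code):

```python
import torch
import torch.nn.functional as F

def banded_causal_mask(seq_len: int, window: int, device=None) -> torch.Tensor:
    # True = may attend. Causal (no future positions) and within the last
    # `window` positions, counting the current token as part of the window.
    pos = torch.arange(seq_len, device=device)
    rel = pos[:, None] - pos[None, :]   # query position minus key position
    return (rel >= 0) & (rel < window)  # (seq_len, seq_len), bool

# SDPA accepts a boolean attn_mask where True means "attend":
q = k = v = torch.randn(1, 8, 1024, 64)  # (batch, heads, seq, head_dim)
mask = banded_causal_mask(1024, window=256, device=q.device)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```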
## What needs to change
- Config: Add `sliding_window: int | None = None` to `ModelConfig` in `kempnerforge/config/model.py`. `None` keeps current behavior (full attention).
- Threading: Pass `sliding_window` through `TransformerBlock` into `Attention.__init__()`; store it as `self.sliding_window`.
- Masking in `Attention.forward()`: When `sliding_window` is set, construct a banded causal mask and pass it as `attn_mask`. When `doc_ids` is also present, intersect the band with the existing block-diagonal document mask so that each token attends to the last W tokens within its own document only (see the first sketch after this list).
- KV cache: When `sliding_window` is set, `KVCache` only needs the last W positions. Convert it to a circular buffer of size W instead of pre-allocating `max_seq_len`; this reduces inference memory from O(max_seq_len) to O(W) (see the second sketch after this list).
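A sketch of the mask intersection for packed sequences; it assumes `doc_ids` is an integer tensor of shape `(batch, seq)`, which is an assumption about the repo's layout rather than confirmed API:

```python
import torch

def sliding_window_doc_mask(doc_ids: torch.Tensor, window: int) -> torch.Tensor:
    # doc_ids: (batch, seq) integer document id per token (assumed layout).
    batch, seq = doc_ids.shape
    pos = torch.arange(seq, device=doc_ids.device)
    rel = pos[:, None] - pos[None, :]
    band = (rel >= 0) & (rel < window)                     # (seq, seq)
    same_doc = doc_ids[:, :, None] == doc_ids[:, None, :]  # (batch, seq, seq)
    # Intersect: attend only within the window AND within the same document.
    return (band[None] & same_doc).unsqueeze(1)            # (batch, 1, seq, seq)
```

The trailing `unsqueeze` gives the mask a head dimension of 1 so it broadcasts over heads inside SDPA.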
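And a sketch of the circular-buffer cache for single-token decode steps (class name, shapes, and methods are hypothetical, not the existing `KVCache` interface):

```python
import torch

class SlidingKVCache:
    """Keeps only the last `window` key/value positions in a circular buffer."""

    def __init__(self, batch: int, heads: int, window: int, head_dim: int,
                 dtype: torch.dtype = torch.float32):
        self.window = window
        self.k = torch.zeros(batch, heads, window, head_dim, dtype=dtype)
        self.v = torch.zeros(batch, heads, window, head_dim, dtype=dtype)
        self.pos = 0  # total tokens written so far

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor) -> None:
        # One decode step: k_new/v_new are (batch, heads, 1, head_dim).
        slot = self.pos % self.window  # overwrite the oldest entry once full
        self.k[:, :, slot] = k_new[:, :, 0]
        self.v[:, :, slot] = v_new[:, :, 0]
        self.pos += 1

    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
        # Return the valid entries in temporal order, oldest first.
        n = min(self.pos, self.window)
        idx = torch.arange(self.pos - n, self.pos) % self.window
        return self.k[:, :, idx], self.v[:, :, idx]
```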
## Config example
```toml
[model]
sliding_window = 4096  # each token attends to the last 4096 positions
```
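The corresponding field on the Python side might look like the following (a sketch that assumes `ModelConfig` is a dataclass, which is not confirmed here):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # ... existing fields ...
    sliding_window: int | None = None  # None = full causal attention (current behavior)
```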
## MFU impact
`kempnerforge/metrics/mfu.py` uses `12 * L * D * S` for the attention term in both `_dense_flops_per_token` and `_moe_flops_per_token`. When `sliding_window` is set, this term should use `min(S, W)`; otherwise the attention FLOPs are overcounted and MFU is over-reported for sliding-window models.
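The adjustment could be as small as the following (a sketch; variable names mirror the formula above, not necessarily the file's actual code):

```python
def attn_flops_per_token(L: int, D: int, S: int, sliding_window: int | None = None) -> int:
    # Each token's attention span is capped at the window when sliding window is on.
    effective_s = min(S, sliding_window) if sliding_window is not None else S
    return 12 * L * D * effective_s
```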
## Testing
- Output matches full attention when `W >= seq_len`
- Tokens beyond the window distance receive zero attention weight
- Works with GQA, packed sequences (`doc_ids`), and the KV cache
- MFU calculation uses `min(S, W)` when sliding window is set
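A sketch of the first check, written against plain SDPA rather than the repo's `Attention` class (whose interface is not assumed here):

```python
import torch
import torch.nn.functional as F

def test_full_window_matches_causal_attention():
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 4, 128, 32) for _ in range(3))

    full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # A band with W >= seq_len degenerates to the ordinary causal mask.
    pos = torch.arange(128)
    rel = pos[:, None] - pos[None, :]
    mask = (rel >= 0) & (rel < 128)
    banded = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    torch.testing.assert_close(full, banded)
```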
## Priority
Medium. Unlocks long-context training on memory-constrained hardware. Independent of FA3 upstream tracking.