Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions nanochat/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
- Flash Attention 3 integration
"""

from functools import partial
from dataclasses import dataclass

import torch
Expand All @@ -22,6 +21,7 @@
from nanochat.common import get_dist_info, print0
from nanochat.optim import MuonAdamW, DistMuonAdamW


# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn

Expand Down Expand Up @@ -176,15 +176,60 @@ def __init__(self, config, pad_vocab_size_to=64):
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
# Precompute a reasonably large RoPE cache up front (cheap relative to model weights).
# The cache may also grow lazily in forward() if generation exceeds this length.
self.rotary_seq_len = config.sequence_len * 10
self.max_rotary_seq_len = self.rotary_seq_len

head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)

def _ensure_rope_cache(self, needed_seq_len: int):
    """
    Ensure the rotary embedding cache (the cos/sin buffers) covers absolute
    positions [0, needed_seq_len).

    The cache grows lazily so that long prompts / long KV-cache generation do
    not crash. Growth is amortized by rounding the new length up to the next
    power of two, and is bounded by self.max_rotary_seq_len to avoid unbounded
    memory usage.

    Raises:
        RuntimeError: if the request exceeds self.max_rotary_seq_len, or if
            growth is attempted while torch.compile is tracing (mutating
            module buffers mid-trace is unsafe).
    """
    cur_len = self.cos.size(1)
    if needed_seq_len <= cur_len:
        return  # cache is already long enough

    if needed_seq_len > self.max_rotary_seq_len:
        raise RuntimeError(
            f"RoPE cache request exceeds max_rotary_seq_len: need {needed_seq_len}, "
            f"have {cur_len}, cap {self.max_rotary_seq_len}. "
            "Increase max_rotary_seq_len for longer-context generation."
        )

    # Safety: mutating buffers during torch.compile tracing is unsafe.
    # Use the public torch.compiler API (torch >= 2.1) rather than the
    # private torch._dynamo module the original reached into.
    if torch.compiler.is_compiling():
        raise RuntimeError(
            f"RoPE cache too small during torch.compile (need {needed_seq_len}, have {cur_len}). "
            "Increase initial rotary_seq_len/max_rotary_seq_len or avoid compiled generation."
        )

    # Next power-of-two >= needed_seq_len (amortized growth), bounded by the cap.
    new_len = min(self.max_rotary_seq_len, 1 << (needed_seq_len - 1).bit_length())

    head_dim = self.config.n_embd // self.config.n_head
    device = self.cos.device
    # NOTE(review): assumes _precompute_rotary_embeddings accepts
    # seq_len/head_dim/device keyword arguments — confirm against its definition.
    cos, sin = self._precompute_rotary_embeddings(seq_len=new_len, head_dim=head_dim, device=device)

    # Preserve dtype/device invariants (precompute is expected to already
    # return matching tensors, but keep this explicit and cheap).
    cos = cos.to(dtype=self.cos.dtype, device=device)
    sin = sin.to(dtype=self.sin.dtype, device=device)

    # Assigning a tensor to an existing buffer name updates nn.Module._buffers
    # in place; the buffers keep persistent=False from the initial registration.
    self.cos = cos
    self.sin = sin
    self.rotary_seq_len = new_len

@torch.no_grad()
def init_weights(self):
"""
Expand Down Expand Up @@ -387,14 +432,16 @@ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02

def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()

# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
T0 = 0 if kv_cache is None else kv_cache.get_pos()

# Ensure RoPE buffers cover absolute positions [T0, T0+T)
self._ensure_rope_cache(T0 + T)

assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

# If kv cache exists, offset RoPE by current absolute position
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]

# Forward the trunk of the Transformer
x = self.transformer.wte(idx) # embed current token
Expand Down