The RMSNorm implementation in this codebase is wrong: it computes the RMS over the (T, D) dimensions instead of over the D dimension alone. Assume input x is of shape (B, T, D).
The current code does this:

```python
# x is (B, T, D).
ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -0.5  # (B,).
raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)  # divisor is (B, 1, 1).
```

This produces a single RMS value per sequence. The original RMSNorm is here: https://github.com/meta-llama/llama/blob/main/llama/model.py#L34-L77

```python
x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
```

The correct version using the Frobenius norm would be (note `keepdim`, not `keepdims`, for `torch.linalg.norm`):

```python
ff_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1).
raw = x / (ff_rms + eps)
```

Normalization should be per-token, not per-sequence.
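A minimal runnable sketch of the difference (shapes B, T, D chosen arbitrarily for illustration): it computes both versions and checks the per-token one against the reference `rsqrt` formulation from the Llama repo, while the per-sequence one does not match.

```python
import math
import torch

torch.manual_seed(0)
B, T, D = 2, 4, 8
eps = 1e-6
x = torch.randn(B, T, D)

# Buggy: one RMS per sequence, computed over the (T, D) dimensions.
ff_rms_seq = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -0.5  # (B,)
buggy = x / ff_rms_seq.unsqueeze(-1).unsqueeze(-1)  # divisor is (B, 1, 1)

# Fixed: one RMS per token, computed over D only.
ff_rms_tok = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(D)  # (B, T, 1)
fixed = x / (ff_rms_tok + eps)

# Reference formulation from the Llama implementation.
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

print(torch.allclose(fixed, ref, atol=1e-5))   # True: per-token versions agree
print(torch.allclose(buggy, ref, atol=1e-5))   # False: per-sequence RMS differs
```

The per-token and reference versions differ only in where `eps` enters (outside vs. inside the square root), which is negligible for typical activations.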