
Incorrect RMSNorm #4

@arunmallya

Description


The RMSNorm implementation in this codebase is wrong: it computes the RMS over the (T, D) dimensions jointly instead of over the D dimension alone. Assume the input x has shape (B, T, D).

The current code does this:

# x is (B, T, D).
ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5  # (B,): one RMS per sequence.
raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)  # ff_rms broadcast as (B, 1, 1).

The original RMSNorm is here - https://github.com/meta-llama/llama/blob/main/llama/model.py#L34-L77

x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
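For reference, a minimal self-contained module in the style of the Llama implementation (the class name and default `eps` here are illustrative, not taken from this codebase) looks like this — note that the mean is taken over the last dimension only, so each token is normalized independently:

```python
import torch

class RMSNorm(torch.nn.Module):
    """Per-token RMS normalization: each (D,)-vector is scaled by the
    reciprocal of its own root-mean-square over the feature dimension."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # mean of squares over D only, keepdim=True -> shape (B, T, 1)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._norm(x.float()).type_as(x) * self.weight
```

With the default `weight` of ones, every token in the output has RMS approximately 1 regardless of how the other tokens in the sequence are scaled.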

The corrected version, still written with torch.linalg.norm, would be:

ff_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1).
raw = x / (ff_rms + eps)

Normalization should be per-token, not per-sequence.
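A quick numeric sketch of why this matters (the scale factor and eps value are illustrative): scale a single token up and the per-sequence statistic crushes every other token toward zero, while the per-token version leaves each token at unit RMS.

```python
import torch

torch.manual_seed(0)
B, T, D = 2, 4, 8
x = torch.randn(B, T, D)
x[:, 0] *= 100.0  # one large-magnitude token per sequence

# Buggy: one RMS per sequence, computed over the (T, D) dimensions.
seq_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5  # (B,)
buggy = x / seq_rms.view(B, 1, 1)

# Fixed: one RMS per token, computed over D only.
tok_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / D ** 0.5  # (B, T, 1)
fixed = x / (tok_rms + 1e-6)

print(fixed.pow(2).mean(-1).sqrt())  # every token ~1
print(buggy.pow(2).mean(-1).sqrt())  # token 0 large, the rest near zero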
