
Add embedding_scale and final_logit_softcap (Gemma 4 prep) #492

Open
jlamypoirier wants to merge 1 commit into main from worktree-gemma

Conversation

@jlamypoirier (Collaborator)

Summary

  • embedding_scale (LanguageModelEmbeddingsConfig): multiplicative scale applied to word embeddings after lookup; Gemma 4 uses sqrt(hidden_size). Zero overhead at the default value of 1.0 via a compile-time branch inside the @torch.compile-decorated _forward. This has to be a runtime op rather than a weight-init change: tied embeddings share their weight with the LM head, so baking the scale into the weights would also scale the logits. (See the first sketch after this list.)
  • final_logit_softcap (LanguageModelHeadConfig): applies tanh(logits / cap) * cap before the loss; Gemma 4 uses cap=30. Forward and backward are each @torch.compile-decorated for op fusion, and the gradient propagates through the Jacobian (1 - (softcapped / cap)²) before the output-linear backward. (See the second sketch after this list.)
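
A minimal sketch of the compile-time branch, assuming PyTorch 2.x Dynamo semantics; the class, attribute, and method names (`Embeddings`, `embedding_scale`, `_forward`) are illustrative stand-ins, not the PR's actual code:

```python
import torch

class Embeddings(torch.nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, embedding_scale: float = 1.0):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.embedding_scale = embedding_scale  # hypothetical config value

    @torch.compile
    def _forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.word_embeddings(input_ids)
        # `embedding_scale` is a plain Python float, so Dynamo guards on its
        # value and specializes the graph: at the default 1.0 the branch below
        # is dropped entirely, making the feature zero-overhead when unused.
        if self.embedding_scale != 1.0:
            hidden = hidden * self.embedding_scale
        return hidden
```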
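And a hedged sketch of the soft-cap forward/backward pairing; the names are illustrative, and wrapping the pair in an autograd.Function is this sketch's framing, not necessarily how the PR wires the backward into the head:

```python
import torch

@torch.compile
def _softcap_forward(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds logits to (-cap, cap).
    return torch.tanh(logits / cap) * cap

@torch.compile
def _softcap_backward(grad: torch.Tensor, capped: torch.Tensor, cap: float) -> torch.Tensor:
    # Jacobian of cap * tanh(x / cap), written in terms of the saved output:
    # 1 - tanh(x / cap)^2 == 1 - (capped / cap)^2.
    return grad * (1 - (capped / cap) ** 2)

class SoftCap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits: torch.Tensor, cap: float) -> torch.Tensor:
        capped = _softcap_forward(logits, cap)
        ctx.save_for_backward(capped)
        ctx.cap = cap
        return capped

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (capped,) = ctx.saved_tensors
        return _softcap_backward(grad_output, capped, ctx.cap), None
```

Usage would be `logits = SoftCap.apply(raw_logits, 30.0)` with Gemma's cap of 30; saving the capped output instead of the raw logits keeps the backward free of a second tanh.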

Tests

  • New tests/layers/test_embedding.py: generic parametrized test for the embedding layer covering 3 base cases (default, with_padding, with_position_embeddings) × 4 variants (default, bfloat16, full_precision_residual, embedding_scale=2.0) = 12 cases. The parametrization shape is sketched after this list.
  • Adds a final_logit_softcap=2.0 case to test_lm_head.py (4 cases: plain, split, masked, masked+split).
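
Roughly the parametrization shape described above (case names copied from the summary; the test body is elided as `...`):

```python
import itertools
import pytest

BASE_CASES = ("default", "with_padding", "with_position_embeddings")
VARIANTS = ("default", "bfloat16", "full_precision_residual", "embedding_scale")

# 3 base cases x 4 variants = 12 generated test cases.
@pytest.mark.parametrize(("base_case", "variant"), list(itertools.product(BASE_CASES, VARIANTS)))
def test_embedding(base_case, variant):
    ...  # build the layer for this combination and compare against a reference
```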

Test plan

  • pytest -v tests/layers/test_embedding.py — 12 passed
  • pytest -v tests/layers/test_lm_head.py — 56 passed

🤖 Generated with Claude Code

