Commit 0909d36

Merge pull request #38 from curt-tigges/feature/tied-decoders
Feature/tied decoders
2 parents 3ed14c7 + f02bde6 commit 0909d36

18 files changed: +2445, -68 lines

README.md

Lines changed: 46 additions & 0 deletions
@@ -11,6 +11,10 @@ This library is intended for the training and analysis of cross-layer sparse cod
 
 A Cross-Layer Transcoder (CLT) is a multi-layer dictionary learning model designed to extract sparse, interpretable features from transformers, using an encoder for each layer and a decoder for each (source layer, destination layer) pair (e.g., 12 encoders and 78 decoders for `gpt2-small`). This implementation focuses on the core functionality needed to train and use CLTs, leveraging `nnsight` for model introspection and `datasets` for data handling.
 
+The library now supports **tied decoders**, which can significantly reduce the number of parameters by sharing decoder weights across layers. Instead of training separate decoders for each (source, destination) pair, tied decoders use either:
+- **Per-source tying**: One decoder per source layer, shared across all destination layers
+- **Per-target tying**: One decoder per destination layer, shared across all source layers
+
 Training a CLT involves the following steps:
 1. Pre-generate activations with `scripts/generate_activations` (though an implementation of `StreamingActivationStore` is on the way).
 2. Train a CLT (start with an expansion factor of at least `32`) using this data. Metrics can be logged to WandB. NMSE should get below `0.25`, or ideally even below `0.10`. As mentioned above, I recommend `BatchTopK` training, and suggest keeping `K` low; `200` is a good place to start.
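The parameter savings from tying can be sanity-checked with a small back-of-the-envelope helper (illustrative only, not part of the library), assuming the untied layout uses one decoder per (source, destination) pair with source <= destination:

```python
def num_decoders(num_layers: int, tying: str) -> int:
    """Count decoder matrices for a given tying mode.

    Untied ("none"): one decoder per (src, tgt) pair with src <= tgt,
    i.e. n * (n + 1) / 2. Tied modes keep one decoder per layer.
    """
    if tying == "none":
        return num_layers * (num_layers + 1) // 2
    if tying in ("per_source", "per_target"):
        return num_layers
    raise ValueError(f"unknown tying mode: {tying}")

# gpt2-small has 12 layers: 78 untied decoders vs. 12 tied ones
print(num_decoders(12, "none"), num_decoders(12, "per_source"))  # 78 12
```

This matches the counts quoted above: 78 decoders untied, 12 with per-source tying for `gpt2-small`.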
@@ -85,6 +89,16 @@ Key configuration parameters are mapped to config classes via script arguments:
 - `relu`: Standard ReLU activation.
 - `batchtopk`: Selects a global top K features across all tokens in a batch, based on pre-activation values. The `k` can be an absolute number or a fraction. This is often used as a training-time differentiable approximation that can later be converted to `jumprelu`.
 - `topk`: Selects the top K features per token (row-wise top-k).
+
+**Decoder Tying Options** (`--decoder-tying`):
+- `none` (default): Traditional untied decoders, with a separate decoder for each (source, destination) layer pair.
+- `per_source`: Share decoder weights per source layer; each source layer has one decoder used for all destinations.
+- `per_target`: Share decoder weights per destination layer; each destination layer has one decoder that combines features from all source layers.
+
+**Additional Tied Decoder Features**:
+- `--enable-feature-offset`: Add learnable per-feature bias terms.
+- `--enable-feature-scale`: Add learnable per-feature scaling.
+- `--skip-connection`: Enable skip connections from source inputs to decoder outputs.
 - **TrainingConfig**: `--learning-rate`, `--training-steps`, `--train-batch-size-tokens`, `--activation-source`, `--activation-path` (for `local_manifest`), remote config fields (for `remote`, e.g. `--server-url`, `--dataset-id`), `--normalization-method`, `--sparsity-lambda`, `--preactivation-coef`, `--optimizer`, `--lr-scheduler`, `--log-interval`, `--eval-interval`, `--checkpoint-interval`, `--dead-feature-window`, WandB settings (`--enable-wandb`, `--wandb-project`, etc.).
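The difference between `batchtopk` (global across the batch) and `topk` (row-wise per token) can be made concrete with a minimal pure-Python sketch (toy code, not the library's tensor implementation):

```python
import heapq

def batch_topk(pre, k):
    """Keep the k largest pre-activations across the whole batch; zero the rest."""
    flat = [(v, i, j) for i, row in enumerate(pre) for j, v in enumerate(row)]
    keep = {(i, j) for _, i, j in heapq.nlargest(k, flat)}
    return [[v if (i, j) in keep else 0.0 for j, v in enumerate(row)]
            for i, row in enumerate(pre)]

def per_token_topk(pre, k):
    """Keep the k largest pre-activations in each row (token) independently."""
    out = []
    for row in pre:
        keep = {j for _, j in heapq.nlargest(k, ((v, j) for j, v in enumerate(row)))}
        out.append([v if j in keep else 0.0 for j, v in enumerate(row)])
    return out

pre = [[3.0, 1.0, 0.5], [2.0, 0.1, 0.2]]
print(batch_topk(pre, 2))      # global: keeps only 3.0 and 2.0
print(per_token_topk(pre, 1))  # per token: keeps the max of each row
```

Note that with `batchtopk` a token with uniformly weak pre-activations can end up with no active features at all, which is exactly the global-budget behavior described above.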
### Single GPU Training Examples
@@ -139,6 +153,38 @@ python scripts/train_clt.py \
 # Add other arguments as needed
 ```
 
+**Example: Training with Tied Decoders**
+
+Tied decoders can significantly reduce the parameter count while maintaining performance. Here is an example using per-source tying:
+
+```bash
+python scripts/train_clt.py \
+    --activation-source local_manifest \
+    --activation-path ./tutorial_activations/gpt2/pile-uncopyrighted_train \
+    --output-dir ./clt_output_tied \
+    --model-name gpt2 \
+    --num-features 6144 \
+    --decoder-tying per_source \
+    --enable-feature-scale \
+    --skip-connection \
+    --activation-fn batchtopk \
+    --batchtopk-k 256 \
+    --learning-rate 3e-4 \
+    --training-steps 100000 \
+    --train-batch-size-tokens 8192 \
+    --sparsity-lambda 1e-3 \
+    --log-interval 100 \
+    --eval-interval 1000 \
+    --checkpoint-interval 5000 \
+    --enable-wandb --wandb-project clt_tied_training
+```
+
+This configuration:
+- Uses `per_source` tying: 12 decoders instead of 78 for gpt2-small
+- Enables feature scaling for better expressiveness
+- Includes skip connections to preserve input information
+- Uses BatchTopK with `k=256` for training (can be converted to JumpReLU later)
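The BatchTopK-to-JumpReLU conversion mentioned above can be illustrated with a toy sketch (hypothetical helper names; the library's actual conversion procedure may differ). One natural heuristic: after training, estimate a per-feature threshold as the smallest pre-activation value that BatchTopK kept for that feature, then gate with JumpReLU at inference:

```python
def estimate_thresholds(kept_preacts):
    """kept_preacts: feature index -> pre-activation values BatchTopK kept.

    The smallest surviving value per feature is a natural threshold estimate.
    """
    return {f: min(vals) for f, vals in kept_preacts.items() if vals}

def jumprelu(x, theta):
    """JumpReLU: pass x through unchanged if it clears theta, else zero."""
    return x if x > theta else 0.0

thetas = estimate_thresholds({0: [0.9, 0.5, 0.7], 1: [1.2]})
print(thetas)                    # {0: 0.5, 1: 1.2}
print(jumprelu(0.6, thetas[0]))  # 0.6 (clears the threshold)
print(jumprelu(0.4, thetas[0]))  # 0.0 (gated off)
```

The appeal of this conversion is that inference becomes purely per-token: no batch-level sort is needed once the thresholds are frozen.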
### Multi-GPU Training (Tensor Parallelism)

This library supports feature-wise tensor parallelism via PyTorch's distributed communication package (`torch.distributed`). This shards the model's parameters (encoders, decoders) across multiple GPUs, reducing memory usage per GPU and potentially speeding up computation.
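Feature-wise sharding can be pictured with a small helper (illustrative only; the function and parameter names are assumptions, not the library's API) that splits the feature dimension into contiguous per-rank ranges:

```python
def feature_shard(num_features: int, world_size: int, rank: int):
    """Return the [start, end) feature range owned by a given rank.

    Distributes any remainder one extra feature at a time to the lowest ranks.
    """
    base, rem = divmod(num_features, world_size)
    start = rank * base + min(rank, rem)
    size = base + (1 if rank < rem else 0)
    return start, start + size

# 6144 features over 4 GPUs: four contiguous shards of 1536 features each
print([feature_shard(6144, 4, r) for r in range(4)])
```

Each rank then holds only the encoder/decoder columns for its own feature range, which is why per-GPU memory drops roughly in proportion to the world size.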

clt/config/clt_config.py

Lines changed: 51 additions & 7 deletions
@@ -15,7 +15,7 @@ class CLTConfig:
     num_layers: int  # Number of transformer layers
     d_model: int  # Dimension of model's hidden state
     model_name: Optional[str] = None  # Optional name for the underlying model
-    normalization_method: Literal["auto", "estimated_mean_std", "none"] = (
+    normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = (
         "none"  # How activations were normalized during training
     )
     activation_fn: Literal["jumprelu", "relu", "batchtopk", "topk"] = "jumprelu"
@@ -27,20 +27,28 @@ class CLTConfig:
     topk_k: Optional[float] = None  # Number or fraction of features to keep per token for TopK.
     # If < 1, treated as fraction. If >= 1, treated as int count.
     topk_straight_through: bool = True  # Whether to use straight-through estimator for TopK.
+    # Top-K mode selection
+    topk_mode: Literal["global", "per_layer"] = "global"  # How to apply top-k selection
     clt_dtype: Optional[str] = None  # Optional dtype for the CLT model itself (e.g., "float16")
     expected_input_dtype: Optional[str] = None  # Expected dtype of input activations
     mlp_input_template: Optional[str] = None  # Module path template for MLP input activations
     mlp_output_template: Optional[str] = None  # Module path template for MLP output activations
     tl_input_template: Optional[str] = None  # TransformerLens hook point pattern before MLP
     tl_output_template: Optional[str] = None  # TransformerLens hook point pattern after MLP
     # context_size: Optional[int] = None
+
+    # Tied decoder configuration
+    decoder_tying: Literal["none", "per_source", "per_target"] = "none"  # Decoder weight sharing strategy
+    enable_feature_offset: bool = False  # Enable per-feature bias (feature_offset)
+    enable_feature_scale: bool = False  # Enable per-feature scale (feature_scale)
+    skip_connection: bool = False  # Enable skip connection from input to output
 
     def __post_init__(self):
         """Validate configuration parameters."""
         assert self.num_features > 0, "Number of features must be positive"
         assert self.num_layers > 0, "Number of layers must be positive"
         assert self.d_model > 0, "Model dimension must be positive"
-        valid_norm_methods = ["auto", "estimated_mean_std", "none"]
+        valid_norm_methods = ["none", "mean_std", "sqrt_d_model"]
         assert (
             self.normalization_method in valid_norm_methods
         ), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}"
@@ -60,6 +68,12 @@ def __post_init__(self):
             raise ValueError("topk_k must be specified for TopK activation function.")
         if self.topk_k is not None and self.topk_k <= 0:
             raise ValueError("topk_k must be positive if specified.")
+
+        # Validate decoder tying configuration
+        valid_decoder_tying = ["none", "per_source", "per_target"]
+        assert (
+            self.decoder_tying in valid_decoder_tying
+        ), f"Invalid decoder_tying: {self.decoder_tying}. Must be one of {valid_decoder_tying}"
 
     @classmethod
     def from_json(cls: Type[C], json_path: str) -> C:
@@ -73,6 +87,30 @@ def from_json(cls: Type[C], json_path: str) -> C:
         """
         with open(json_path, "r") as f:
             config_dict = json.load(f)
+
+        # Handle backward compatibility for old configs
+        if "decoder_tying" not in config_dict:
+            config_dict["decoder_tying"] = "none"  # Default to original behavior
+        if "enable_feature_offset" not in config_dict:
+            config_dict["enable_feature_offset"] = False
+        if "enable_feature_scale" not in config_dict:
+            config_dict["enable_feature_scale"] = False
+
+        # Handle backward compatibility for old normalization methods
+        if "normalization_method" in config_dict:
+            old_method = config_dict["normalization_method"]
+            # Map old values to new ones
+            if old_method in ["auto", "estimated_mean_std"]:
+                config_dict["normalization_method"] = "mean_std"
+            elif old_method in ["auto_sqrt_d_model", "estimated_mean_std_sqrt_d_model"]:
+                config_dict["normalization_method"] = "sqrt_d_model"
+
+        # Handle the old sqrt_d_model_normalize flag
+        if "sqrt_d_model_normalize" in config_dict:
+            sqrt_normalize = config_dict.pop("sqrt_d_model_normalize")
+            if sqrt_normalize:
+                config_dict["normalization_method"] = "sqrt_d_model"
+
         return cls(**config_dict)
 
     def to_json(self, json_path: str) -> None:
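The backward-compatibility logic in `from_json` can be exercised standalone with a sketch like the following (the helper name `migrate_config` is hypothetical; the mapping itself mirrors the diff above):

```python
def migrate_config(config_dict: dict) -> dict:
    """Upgrade an old CLT config dict to the new field names and values."""
    cfg = dict(config_dict)  # don't mutate the caller's dict
    # New tied-decoder fields default to the original (untied) behavior
    cfg.setdefault("decoder_tying", "none")
    cfg.setdefault("enable_feature_offset", False)
    cfg.setdefault("enable_feature_scale", False)
    # Map old normalization method names to the new ones
    old = cfg.get("normalization_method")
    if old in ("auto", "estimated_mean_std"):
        cfg["normalization_method"] = "mean_std"
    elif old in ("auto_sqrt_d_model", "estimated_mean_std_sqrt_d_model"):
        cfg["normalization_method"] = "sqrt_d_model"
    # The old boolean flag overrides the method when set
    if cfg.pop("sqrt_d_model_normalize", False):
        cfg["normalization_method"] = "sqrt_d_model"
    return cfg

print(migrate_config({"normalization_method": "auto"})["normalization_method"])  # mean_std
```

This keeps old checkpoider configs loadable without any manual editing.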
@@ -108,11 +146,11 @@ class TrainingConfig:
     debug_anomaly: bool = False
 
     # Normalization parameters
-    normalization_method: Literal["auto", "estimated_mean_std", "none"] = "auto"
-    # 'auto': Use pre-calculated from mapped store, or estimate for streaming store.
-    # 'estimated_mean_std': Always estimate for streaming store (ignored for mapped).
-    # 'none': Disable normalization.
-    normalization_estimation_batches: int = 50  # Batches for normalization estimation
+    normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = "mean_std"
+    # 'none': No normalization.
+    # 'mean_std': Standard (x - mean) / std normalization using pre-calculated stats.
+    # 'sqrt_d_model': EleutherAI-style x * sqrt(d_model) normalization.
+    normalization_estimation_batches: int = 50  # Batches for normalization estimation (if needed)
 
     # --- Activation Store Source --- #
     activation_source: Literal["local_manifest", "remote"] = "local_manifest"
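The three normalization methods can be sketched in plain Python, taking the descriptions in the config comments above at face value (illustrative, stdlib-only; the library operates on tensors):

```python
import math

def normalize(x, method, mean=None, std=None, d_model=None):
    """Apply one of the three normalization methods to a list of floats.

    'mean_std' standardizes element-wise with pre-calculated stats;
    'sqrt_d_model' rescales by sqrt(d_model), per the config comment.
    """
    if method == "none":
        return list(x)
    if method == "mean_std":
        return [(xi - m) / s for xi, m, s in zip(x, mean, std)]
    if method == "sqrt_d_model":
        return [xi * math.sqrt(d_model) for xi in x]
    raise ValueError(f"unknown normalization method: {method}")

print(normalize([2.0, 4.0], "mean_std", mean=[1.0, 0.0], std=[1.0, 2.0]))  # [1.0, 2.0]
print(normalize([1.0, 2.0], "sqrt_d_model", d_model=4))                    # [2.0, 4.0]
```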
@@ -221,6 +259,12 @@ def __post_init__(self):
         assert (
             0.0 <= self.sparsity_lambda_delay_frac < 1.0
         ), "sparsity_lambda_delay_frac must be between 0.0 (inclusive) and 1.0 (exclusive)"
+
+        # Validate normalization method
+        valid_norm_methods = ["none", "mean_std", "sqrt_d_model"]
+        assert (
+            self.normalization_method in valid_norm_methods
+        ), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}"
 
 
 @dataclass

clt/models/clt.py

Lines changed: 101 additions & 3 deletions
@@ -179,17 +179,24 @@ def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
             )
         return torch.zeros((expected_batch_dim, self.config.num_features), device=self.device, dtype=self.dtype)
 
-    def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor:
-        return self.decoder_module.decode(a, layer_idx)
+
+    def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor:
+        return self.decoder_module.decode(a, layer_idx, source_inputs)
 
     def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
         activations = self.get_feature_activations(inputs)
+
+        # Note: feature affine transformations are now applied in the decoder
 
         reconstructions = {}
         for layer_idx in range(self.config.num_layers):
             relevant_activations = {k: v for k, v in activations.items() if k <= layer_idx and v.numel() > 0}
             if layer_idx in inputs and relevant_activations:
-                reconstructions[layer_idx] = self.decode(relevant_activations, layer_idx)
+                # Pass source inputs for EleutherAI-style skip connections
+                source_inputs = {k: inputs[k] for k in range(layer_idx + 1) if k in inputs} if self.config.skip_connection else None
+                reconstruction = self.decode(relevant_activations, layer_idx, source_inputs)
+
+                reconstructions[layer_idx] = reconstruction
             elif layer_idx in inputs:
                 batch_size = 0
                 input_tensor = inputs[layer_idx]
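The decoding pattern in `forward` — each layer's reconstruction draws on feature activations from every source layer at or below it, plus an optional skip term — reduces to a scalar toy model (illustrative only, not the library's tensor code):

```python
def decode_layer(activations, layer_idx, decoder_weight, skip_input=None, skip_weight=0.0):
    """Toy scalar version of cross-layer decoding.

    activations: source layer index -> scalar feature activation
    decoder_weight: (src, tgt) pair -> scalar decoder weight
    The reconstruction sums contributions from all sources k <= layer_idx,
    plus an optional skip-connection term from the layer's own input.
    """
    recon = sum(decoder_weight[(k, layer_idx)] * a
                for k, a in activations.items() if k <= layer_idx)
    if skip_input is not None:
        recon += skip_weight * skip_input  # EleutherAI-style skip connection
    return recon

weights = {(0, 1): 0.5, (1, 1): 1.0}
print(decode_layer({0: 1.0, 1: 2.0}, 1, weights))                                    # 2.5
print(decode_layer({0: 1.0, 1: 2.0}, 1, weights, skip_input=2.0, skip_weight=0.5))   # 3.5
```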
@@ -216,6 +223,17 @@ def get_feature_activations(self, inputs: Dict[int, torch.Tensor]) -> Dict[int,
             processed_inputs[layer_idx] = x_orig.to(device=self.device, dtype=self.dtype)
 
         if self.config.activation_fn == "batchtopk" or self.config.activation_fn == "topk":
+            # Check if we should use per-layer mode
+            if self.config.topk_mode == "per_layer":
+                # Use per-layer top-k by calling encode on each layer
+                activations = {}
+                for layer_idx in sorted(processed_inputs.keys()):
+                    x_input = processed_inputs[layer_idx]
+                    act = self.encode(x_input, layer_idx)
+                    activations[layer_idx] = act
+                return activations
+
+            # Otherwise use global top-k
             preactivations_dict, _ = self._encode_all_layers(processed_inputs)
             if not preactivations_dict:
                 activations = {}
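The `topk_mode` distinction — a single global budget across all layers versus an independent budget per layer — can be sketched with stdlib Python (toy code; the model works on tensors):

```python
import heapq

def keep_top_k(vals, k):
    """Zero out everything but the k largest entries of a list."""
    kept = {i for _, i in heapq.nlargest(k, ((v, i) for i, v in enumerate(vals)))}
    return [v if i in kept else 0.0 for i, v in enumerate(vals)]

def topk_select(preacts_by_layer, k, mode="global"):
    """mode='per_layer': top-k within each layer; 'global': top-k across all layers."""
    if mode == "per_layer":
        return {l: keep_top_k(vals, k) for l, vals in preacts_by_layer.items()}
    flat = [(v, l, i) for l, vals in preacts_by_layer.items() for i, v in enumerate(vals)]
    kept = {(l, i) for _, l, i in heapq.nlargest(k, flat)}
    return {l: [v if (l, i) in kept else 0.0 for i, v in enumerate(vals)]
            for l, vals in preacts_by_layer.items()}

pre = {0: [3.0, 0.2], 1: [2.0, 0.1]}
print(topk_select(pre, 2, "global"))     # keeps 3.0 and 2.0, wherever they live
print(topk_select(pre, 1, "per_layer"))  # keeps the max of each layer
```

In global mode a layer with uniformly weak pre-activations can receive zero active features, whereas per-layer mode guarantees each layer its own `k`.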
@@ -325,3 +343,83 @@ def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None:
         if not hasattr(self, "theta_manager") or self.theta_manager is None:
             raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.")
         self.theta_manager.log_threshold = new_param
+
+    def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
+        """Load a state dict with backward compatibility for old checkpoints.
+
+        Handles:
+        1. Old untied decoder format -> new tied/untied format
+        2. Missing theta_bias/theta_scale parameters
+        3. Missing per_target_scale/per_target_bias parameters
+        """
+        # Check if this is an old checkpoint by looking for decoder keys
+        old_format_decoder_keys = [k for k in state_dict.keys() if "decoders." in k and "->" in k]
+        is_old_checkpoint = len(old_format_decoder_keys) > 0
+
+        if is_old_checkpoint and self.config.decoder_tying == "per_source":
+            logger.warning(
+                "Loading old untied decoder checkpoint into tied decoder model. "
+                "This will use weights from the first target layer for each source layer."
+            )
+
+            # Convert old decoder weights to the tied format.
+            # For each source layer, use the weights from the src->src decoder.
+            new_state_dict = {}
+            for key, value in state_dict.items():
+                if "decoders." in key and "->" in key:
+                    # Extract source and target layer indices.
+                    # Key format: "decoder_module.decoders.{src}->{tgt}.weight" or ".bias"
+                    parts = key.split(".")
+                    decoder_key_idx = parts.index("decoders") + 1
+                    src_tgt = parts[decoder_key_idx].split("->")
+                    src_layer = int(src_tgt[0])
+                    tgt_layer = int(src_tgt[1])
+                    param_type = parts[-1]  # 'weight' or 'bias'
+
+                    # Only use diagonal decoders (src->src) for the tied architecture
+                    if src_layer == tgt_layer:
+                        new_key = ".".join(parts[:decoder_key_idx] + [str(src_layer), param_type])
+                        new_state_dict[new_key] = value
+                else:
+                    new_state_dict[key] = value
+            state_dict = new_state_dict
+
+        # Migrate feature affine parameters from the encoder to the decoder module
+        # (for backward compatibility with old checkpoints)
+        for i in range(self.config.num_layers):
+            old_offset_key = f"encoder_module.feature_offset.{i}"
+            new_offset_key = f"decoder_module.feature_offset.{i}"
+            if old_offset_key in state_dict and new_offset_key not in state_dict:
+                logger.info(f"Migrating {old_offset_key} to {new_offset_key}")
+                state_dict[new_offset_key] = state_dict.pop(old_offset_key)
+
+            old_scale_key = f"encoder_module.feature_scale.{i}"
+            new_scale_key = f"decoder_module.feature_scale.{i}"
+            if old_scale_key in state_dict and new_scale_key not in state_dict:
+                logger.info(f"Migrating {old_scale_key} to {new_scale_key}")
+                state_dict[new_scale_key] = state_dict.pop(old_scale_key)
+
+        # Handle missing feature affine parameters (now in the decoder module)
+        if self.config.enable_feature_offset and hasattr(self.decoder_module, "feature_offset") and self.decoder_module.feature_offset is not None:
+            for i in range(self.config.num_layers):
+                key = f"decoder_module.feature_offset.{i}"
+                if key not in state_dict:
+                    logger.info(f"Initializing missing {key} to zeros")
+                    # Don't add to state_dict; let the module initialize it
+
+        if self.config.enable_feature_scale and hasattr(self.decoder_module, "feature_scale") and self.decoder_module.feature_scale is not None:
+            for i in range(self.config.num_layers):
+                key = f"decoder_module.feature_scale.{i}"
+                if key not in state_dict:
+                    logger.info(f"Initializing missing {key} (first target layer to ones, rest to zeros)")
+                    # Don't add to state_dict; let the module initialize it
+
+        # Handle missing skip weights
+        if self.config.skip_connection and hasattr(self.decoder_module, "skip_weights"):
+            for i in range(self.config.num_layers):
+                key = f"decoder_module.skip_weights.{i}"
+                if key not in state_dict:
+                    logger.info(f"Initializing missing {key} to identity matrix")
+
+        # Call the parent's load_state_dict
+        return super().load_state_dict(state_dict, strict=strict)
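The checkpoint key conversion in `load_state_dict` can be exercised standalone (the helper name is hypothetical; the logic is extracted from the method above for illustration):

```python
def convert_old_decoder_keys(state_dict):
    """Map untied 'decoders.{src}->{tgt}' keys to tied per-source 'decoders.{src}' keys.

    Only the diagonal (src == tgt) decoders are kept, as in the tied-decoder
    loading path; all other keys pass through unchanged.
    """
    new_sd = {}
    for key, value in state_dict.items():
        if "decoders." in key and "->" in key:
            parts = key.split(".")
            idx = parts.index("decoders") + 1
            src, tgt = parts[idx].split("->")
            if src == tgt:
                new_sd[".".join(parts[:idx] + [src, parts[-1]])] = value
        else:
            new_sd[key] = value
    return new_sd

old = {
    "decoder_module.decoders.0->0.weight": "W00",
    "decoder_module.decoders.0->1.weight": "W01",  # off-diagonal: dropped
    "encoder_module.encoders.0.weight": "E0",
}
print(convert_old_decoder_keys(old))
```

Note that off-diagonal weights (e.g., `0->1`) are discarded, which is why the method logs a warning: the tied model only inherits the `src->src` decoders from an old untied checkpoint.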
