A Cross-Layer Transcoder (CLT) is a multi-layer dictionary learning model designed to extract sparse, interpretable features from transformers, using an encoder for each layer and a decoder for each (source layer, destination layer) pair (e.g., 12 encoders and 78 decoders for `gpt2-small`). This implementation focuses on the core functionality needed to train and use CLTs, leveraging `nnsight` for model introspection and `datasets` for data handling.
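The decoder count follows directly from the layer pairing: with `n` layers there is one decoder for every (source, destination) pair with source ≤ destination, i.e. `n * (n + 1) / 2`. A quick sanity check (the helper name is illustrative, not part of the library):

```python
def clt_module_counts(n_layers: int) -> tuple[int, int]:
    """Return (num_encoders, num_decoders) for an untied CLT.

    Encoders: one per layer. Decoders: one per (source, destination)
    pair with source <= destination, i.e. n * (n + 1) / 2.
    """
    num_encoders = n_layers
    num_decoders = n_layers * (n_layers + 1) // 2
    return num_encoders, num_decoders

# gpt2-small has 12 transformer layers
print(clt_module_counts(12))  # -> (12, 78)
```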
The library now supports **tied decoders**, which can significantly reduce the number of parameters by sharing decoder weights across layers. Instead of training separate decoders for each (source, destination) pair, tied decoders use either:
- **Per-source tying**: One decoder per source layer, shared across all destination layers
- **Per-target tying**: One decoder per destination layer, shared across all source layers
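The savings are easy to quantify: both tying schemes replace the quadratic number of decoders with a linear one. A small sketch (function name is illustrative):

```python
def num_decoders(n_layers: int, tying: str) -> int:
    """Number of distinct decoder weight matrices per tying scheme."""
    if tying == "none":  # one per (source, destination) pair
        return n_layers * (n_layers + 1) // 2
    if tying in ("per_source", "per_target"):  # one per layer
        return n_layers
    raise ValueError(f"unknown tying scheme: {tying}")

for scheme in ("none", "per_source", "per_target"):
    print(scheme, num_decoders(12, scheme))  # 78, 12, 12 for gpt2-small
```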
Training a CLT involves the following steps:
1. Pre-generate activations with `scripts/generate_activations` (though an implementation of `StreamingActivationStore` is on the way).
2. Train a CLT (start with an expansion factor of at least `32`) using this data. Metrics can be logged to WandB. NMSE should get below `0.25`, or ideally even below `0.10`. As mentioned above, I recommend `BatchTopK` training, and suggest keeping `K` low; `200` is a good place to start.
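NMSE here is the usual normalized reconstruction error: squared error divided by the squared norm of the target (a sketch of one common definition; the library's exact normalization may differ):

```python
import numpy as np

def nmse(x: np.ndarray, x_hat: np.ndarray) -> float:
    """Normalized MSE: squared reconstruction error divided by the
    squared norm of the target. 0.0 means perfect reconstruction."""
    return float(np.sum((x - x_hat) ** 2) / np.sum(x ** 2))

x = np.array([1.0, 2.0, 3.0])
print(nmse(x, x))            # perfect reconstruction -> 0.0
print(nmse(x, np.zeros(3)))  # predicting all zeros -> 1.0
```

By this measure, an NMSE of `0.10` means the reconstruction captures 90% of the target's squared norm.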
Key configuration parameters are mapped to config classes via script arguments:
- `relu`: Standard ReLU activation.
- `batchtopk`: Selects a global top K features across all tokens in a batch, based on pre-activation values. The `k` can be an absolute number or a fraction. This is often used as a training-time differentiable approximation that can later be converted to `jumprelu`.
- `topk`: Selects top K features per token (row-wise top-k).
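The difference between the two top-k variants can be sketched in NumPy (illustrative only; the library applies these to pre-activation tensors):

```python
import numpy as np

def topk_per_token(pre_acts: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest pre-activations in each row (token)."""
    out = np.zeros_like(pre_acts)
    idx = np.argsort(pre_acts, axis=1)[:, -k:]   # top-k columns per row
    rows = np.arange(pre_acts.shape[0])[:, None]
    out[rows, idx] = pre_acts[rows, idx]
    return out

def batch_topk(pre_acts: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest pre-activations across the whole batch."""
    flat = pre_acts.ravel()
    out = np.zeros_like(flat)
    idx = np.argsort(flat)[-k:]                  # global top-k
    out[idx] = flat[idx]
    return out.reshape(pre_acts.shape)

acts = np.array([[5.0, 3.0, 0.5],
                 [1.0, 0.2, 0.1]])
print(topk_per_token(acts, 1))  # one feature per token: 5.0 and 1.0 survive
print(batch_topk(acts, 2))      # two features total: 5.0 and 3.0 survive
```

Note that `batchtopk` can allocate many features to some tokens and none to others, which is why it is treated as a batch-level approximation rather than a per-token constraint.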
**Decoder Tying Options** (`--decoder-tying`):
- `none` (default): Traditional untied decoders, with a separate decoder for each (source, destination) layer pair
- `per_source`: Share decoder weights per source layer; each source layer has one decoder used for all destinations
- `per_target`: Share decoder weights per destination layer; each destination layer has one decoder that combines features from all source layers
For example, a configuration combining these options might:

- Use `per_source` tying: 12 decoders instead of 78 for `gpt2-small`
- Enable feature scaling for better expressiveness
- Include skip connections to preserve input information
- Use BatchTopK with `k=256` for training (can be converted to JumpReLU later)
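How the tying schemes play out in reconstruction can be sketched as follows (a toy model with made-up dimensions; the real implementation differs in details such as skip connections and feature scaling):

```python
import numpy as np

n_layers, d_feat, d_model = 3, 4, 2
rng = np.random.default_rng(0)

# Feature activations per source layer (from that layer's encoder).
feats = [rng.standard_normal(d_feat) for _ in range(n_layers)]

def make_decoders(tying: str):
    """Map a (source, destination) pair to a decoder weight matrix."""
    if tying == "none":        # one matrix per (source, destination) pair
        mats = {(s, d): rng.standard_normal((d_feat, d_model))
                for s in range(n_layers) for d in range(s, n_layers)}
        return lambda s, d: mats[(s, d)]
    if tying == "per_source":  # shared across destinations
        mats = [rng.standard_normal((d_feat, d_model)) for _ in range(n_layers)]
        return lambda s, d: mats[s]
    if tying == "per_target":  # shared across sources
        mats = [rng.standard_normal((d_feat, d_model)) for _ in range(n_layers)]
        return lambda s, d: mats[d]
    raise ValueError(tying)

# Destination layer d reconstructs its output from all sources s <= d.
decoder = make_decoders("per_source")
recon = [sum(feats[s] @ decoder(s, d) for s in range(d + 1))
         for d in range(n_layers)]
print([r.shape for r in recon])  # one d_model-sized vector per layer
```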
### Multi-GPU Training (Tensor Parallelism)
This library supports feature-wise tensor parallelism built on PyTorch's distributed communication package (`torch.distributed`). This shards the model's parameters (encoders, decoders) across multiple GPUs, reducing memory usage per GPU and potentially speeding up computation.
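Feature-wise sharding can be illustrated without any GPUs: each rank holds a contiguous slice of the feature dimension, and the full activation is recovered by concatenating the per-rank results (a toy NumPy simulation, not the library's API):

```python
import numpy as np

d_model, d_features, world_size = 4, 8, 2
rng = np.random.default_rng(0)

W_enc = rng.standard_normal((d_model, d_features))  # full encoder weight
x = rng.standard_normal(d_model)                    # one token's activation

# Shard the feature (output) dimension across ranks.
shards = np.split(W_enc, world_size, axis=1)  # each (d_model, d_features // world_size)

# Each rank computes only its slice of the feature vector...
local_feats = [x @ W for W in shards]

# ...and an all-gather (here, a concatenation) recovers the full vector.
gathered = np.concatenate(local_feats)
assert np.allclose(gathered, x @ W_enc)
print(gathered.shape)  # (8,)
```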