Successfully implemented end-to-end learnable category assignments for VQ-VAE using Gumbel-Softmax, enabling gradient-based optimization of category assignments during training.
Status: ✅ All core modules implemented and tested
Files Created:
1. src/spinlock/encoding/learnable_assignment.py (~167 lines)
- SoftAssignmentMatrix: Learnable [D, K] (features × categories) assignment probabilities
- PerFamilyAssignmentMatrix: Per-family assignments with block-diagonal constraints
- initialize_from_clustering: Clustering → high-confidence logits
2. src/spinlock/encoding/soft_routing.py (~55 lines)
- SoftGroupedFeatureExtractor: Weighted feature routing
3. src/spinlock/encoding/training/assignment_losses.py (~88 lines)
- soft_orthogonality_loss: Inter-category correlation penalty
- soft_balance_loss: Category size variance penalty
- family_constraint_loss: Cross-family assignment penalty
4. src/spinlock/encoding/training/annealing.py (~37 lines)
- TemperatureScheduler: Linear/exponential/cosine annealing
Tests: 13/13 passing (100%)
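The Gumbel-Softmax relaxation behind SoftAssignmentMatrix and the clustering warm start of initialize_from_clustering can be illustrated without any framework. This is a minimal, dependency-free sketch that uses nested lists for the [D, K] logits. The function names and the `confidence` value are hypothetical. In a real module the logits would be trainable parameters updated by backpropagation rather than fixed lists.

```python
import math
import random

# Hypothetical helpers; they illustrate the technique, not the spinlock API.


def sample_gumbel(rng: random.Random, eps: float = 1e-12) -> float:
    """Draw Gumbel(0, 1) noise via the inverse CDF: -log(-log(U))."""
    u = min(max(rng.random(), eps), 1.0 - eps)  # keep U strictly inside (0, 1)
    return -math.log(-math.log(u))


def softmax(values):
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_rows(logits, temperature, rng):
    """Soft [D, K] assignment: each feature row becomes a relaxed one-hot vector.

    Lower temperatures push each sampled row closer to a one-hot vector.
    """
    return [
        softmax([(l + sample_gumbel(rng)) / temperature for l in row])
        for row in logits
    ]


def init_logits_from_clusters(labels, num_categories, confidence=5.0):
    """Warm start: a large logit on each feature's cluster, 0 elsewhere."""
    return [
        [confidence if k == label else 0.0 for k in range(num_categories)]
        for label in labels
    ]


def hard_assignments(logits):
    """Noise-free hard assignment: argmax category per feature."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


rng = random.Random(0)
logits = init_logits_from_clusters([0, 2, 1, 2], num_categories=3)
soft = gumbel_softmax_rows(logits, temperature=1.0, rng=rng)
print(hard_assignments(logits))  # [0, 2, 1, 2]
```

As the temperature anneals toward 0.1, the sampled rows approach one-hot vectors. The argmax of the logits then gives the discrete category each feature will keep.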
Files Created:
5. src/spinlock/encoding/learnable_categorical_vqvae.py (~282 lines)
- LearnableAssignmentConfig: Configuration dataclass
- LearnableCategoricalVQVAE: Integrated model with soft assignments
  - Forward pass with temperature-controlled Gumbel-Softmax
  - Assignment loss computation
  - Freeze functionality (convert to hard GroupedFeatureExtractor)
Tests: 3/3 integration tests passing (100%)
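One way to picture the soft forward pass and the freeze step:

- During the soft phase, each category receives every feature, scaled by its assignment probability.
- Freezing replaces the soft matrix with argmax index groups, the kind a hard GroupedFeatureExtractor would consume.

This is a hedged, list-based sketch. soft_route and freeze_to_groups are hypothetical names, and the elementwise weighting is an assumption about how SoftGroupedFeatureExtractor routes features.

```python
# Hypothetical helpers; elementwise weighting is an assumed routing scheme.


def soft_route(features, assignment):
    """Return K weighted copies of `features`, one per category.

    Category k sees feature d scaled by assignment[d][k]; with a one-hot
    assignment this reduces to selecting the features in group k.
    """
    num_categories = len(assignment[0])
    return [
        [x * assignment[d][k] for d, x in enumerate(features)]
        for k in range(num_categories)
    ]


def freeze_to_groups(assignment):
    """Convert a soft [D, K] matrix into hard per-category feature index lists."""
    num_categories = len(assignment[0])
    groups = [[] for _ in range(num_categories)]
    for d, row in enumerate(assignment):
        best = max(range(num_categories), key=row.__getitem__)
        groups[best].append(d)
    return groups


assignment = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]  # D=3 features, K=2 categories
features = [1.0, 2.0, 3.0]
routed = soft_route(features, assignment)
groups = freeze_to_groups(assignment)
print(groups)  # [[0, 2], [1]]
```

While assignments are soft, every category depends on every assignment probability, so gradients can reach all logits. After freezing, each feature feeds exactly one category.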
Files Created:
6. src/spinlock/encoding/training/learnable_trainer.py (~169 lines)
- LearnableVQVAETrainer: Extends VQVAETrainer
  - Dual optimizer (model + assignments)
  - Temperature annealing per epoch
  - Assignment loss integration
  - Gradient clipping for stability
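The following is a hedged sketch of how the balance and family-constraint losses could be computed from a soft [D, K] matrix and added to the reconstruction loss. It uses the loss weights from the example configuration. The formulas are assumptions that follow the one-line descriptions in assignment_losses.py: category size variance, and probability mass on other families' categories. The orthogonality term needs category-level feature statistics, so it would be passed in precomputed.

```python
# Assumed formulas; the real assignment_losses.py may normalize differently.


def soft_balance_loss(assignment):
    """Variance of soft category sizes (fraction of features per category)."""
    num_features = len(assignment)
    num_categories = len(assignment[0])
    fractions = [
        sum(row[k] for row in assignment) / num_features
        for k in range(num_categories)
    ]
    mean = sum(fractions) / num_categories
    return sum((f - mean) ** 2 for f in fractions) / num_categories


def family_constraint_loss(assignment, feature_family, category_family):
    """Average probability mass a feature places on other families' categories."""
    leaked = 0.0
    for d, row in enumerate(assignment):
        leaked += sum(
            p for k, p in enumerate(row) if category_family[k] != feature_family[d]
        )
    return leaked / len(assignment)


# Weights taken from the example learnable_assignment.yaml
LOSS_WEIGHTS = {"orthogonality": 0.1, "balance": 0.05, "family_constraint": 1.0}


def total_loss(reconstruction, assignment_terms, weights=LOSS_WEIGHTS):
    """Reconstruction loss plus the weighted assignment regularizers."""
    return reconstruction + sum(
        weights[name] * value for name, value in assignment_terms.items()
    )


balanced = [[1.0, 0.0], [0.0, 1.0]]
collapsed = [[1.0, 0.0], [1.0, 0.0]]
print(soft_balance_loss(balanced), soft_balance_loss(collapsed))  # 0.0 0.25
```

A collapsed assignment, where every feature lands in one category, is penalized while a balanced one costs nothing. This is the mechanism the balance loss relies on to prevent category collapse.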
Files Modified:
7. src/spinlock/cli/train_vqvae.py (~150 lines modified)
- Added --learnable CLI flag
- Branching logic for learnable vs static mode
- Assignment matrix initialization from clustering
- Separate learnable config section
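The branching reduces to a single predicate. Learnable mode is on when the flag is passed or when the config requests it; otherwise static assignment remains the default. The sketch below uses argparse and a plain dict standing in for the loaded YAML. It is a hypothetical stand-in, not the actual spinlock CLI code, which may use a different argument parser.

```python
import argparse

# Hypothetical mirror of the flags shown in the usage example.


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="train-vqvae")
    parser.add_argument("--config", required=True)
    parser.add_argument("--learnable", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    return parser


def use_learnable_mode(argv, config) -> bool:
    """True if --learnable is passed or training.category_assignment is 'learnable'."""
    args = build_parser().parse_args(argv)
    mode = config.get("training", {}).get("category_assignment", "static")
    return args.learnable or mode == "learnable"


static_cfg = {"training": {}}
learnable_cfg = {"training": {"category_assignment": "learnable"}}
print(use_learnable_mode(["--config", "a.yaml"], static_cfg))  # False
```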
Files Created:
8. configs/vqvae/learnable_assignment.yaml (~150 lines)
- Complete configuration example
- Learnable assignment parameters
- Temperature annealing config
- Assignment loss weights
- Clustering initialization params
9. tests/test_learnable_assignment.py (~351 lines)
- 13 unit tests for core modules
- Tests for soft assignments, routing, losses, annealing
- All passing
10. tests/test_learnable_integration.py (~165 lines)
- 3 integration tests
- End-to-end forward/backward pass
- Freeze functionality
- All passing
Files Modified:
11. src/spinlock/encoding/__init__.py
- Exported LearnableCategoricalVQVAE and LearnableAssignmentConfig
Key Design Decisions:
1. Created LearnableCategoricalVQVAE instead of modifying the existing CategoricalHierarchicalVQVAE
- Benefit: Clean separation, backward compatible, easier to test
2. Separate optimizer for the assignment matrix with a lower learning rate (0.001 vs 0.01)
- Benefit: Assignments converge more slowly, which prevents instability
3. Initial clustering still runs, but only as initialization (high-confidence logits)
- Benefit: The warm start accelerates convergence and provides reasonable initial categories
4. Linear temperature annealing from 1.0 → 0.1 over training epochs
- Benefit: Smooth transition from soft to hard assignments
5. Family constraints enforced as a hard constraint via a block-diagonal assignment matrix
- Benefit: Invalid cross-family assignments get exactly zero probability
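The block-diagonal constraint gives exactly zero probability because of how the softmax treats masked entries. If the logits of categories outside a feature's family are set to negative infinity before the softmax, their exponentials are exactly 0. This is a minimal sketch of that masking idea, assuming each category belongs to exactly one family. The function names are hypothetical. PerFamilyAssignmentMatrix may build its block structure differently, for example as one small matrix per family.

```python
import math

# Hypothetical masking sketch of a block-diagonal family constraint.


def block_diagonal_mask(feature_family, category_family):
    """mask[d][k] is True only when feature d and category k share a family."""
    return [[cf == ff for cf in category_family] for ff in feature_family]


def masked_softmax_rows(logits, mask, temperature=1.0):
    """Row-wise softmax where disallowed entries get probability exactly 0.

    Assumes every feature has at least one allowed category; otherwise the
    row max would be -inf and the result undefined.
    """
    out = []
    for row, allowed in zip(logits, mask):
        scaled = [l / temperature if ok else -math.inf for l, ok in zip(row, allowed)]
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]  # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


feature_family = ["a", "a", "b"]
category_family = ["a", "a", "b"]
mask = block_diagonal_mask(feature_family, category_family)
probs = masked_softmax_rows([[0.0, 0.0, 0.0]] * 3, mask)
print(probs[0], probs[2])  # [0.5, 0.5, 0.0] [0.0, 0.0, 1.0]
```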
New Files: 9 files, ~1,477 lines
- Core modules: ~347 lines
- Integrated model: ~282 lines
- Training integration: ~169 lines
- Tests: ~516 lines
- Config: ~150 lines
Modified Files: 2 files, ~150 lines changed
- train_vqvae.py (~150 lines): CLI flag, branching logic, initialization
- __init__.py: Exports
Total Implementation: ~1,627 lines
✅ All tests passing (16/16)
test_learnable_assignment.py: 13/13 passing
- SoftAssignmentMatrix: forward, temperature, hard assignments, gradients
- PerFamilyAssignmentMatrix: block-diagonal, no leakage
- SoftRouting: forward shape, gradients
- Assignment losses: balance, family constraints
- TemperatureScheduler: linear, exponential, cosine
test_learnable_integration.py: 3/3 passing
- Full forward pass with soft assignments
- Backward pass with gradient flow
- Freeze assignments functionality
# Train with learnable assignments
poetry run spinlock train-vqvae \
--config configs/vqvae/learnable_assignment.yaml \
--learnable \
--verbose
# Or enable learnable mode in the config and omit the flag
poetry run spinlock train-vqvae \
--config configs/vqvae/learnable_assignment.yaml \
--verbose

# Enable learnable mode in the YAML config
training:
  category_assignment: learnable

# Learnable assignment parameters
learnable_assignment:
  temperature_start: 1.0
  temperature_end: 0.1
  temperature_schedule: "linear"
  orthogonality_weight: 0.1
  balance_weight: 0.05
  family_constraint_weight: 1.0
  assignment_lr: 0.001
  freeze_after_epochs: null  # Optional: freeze after N epochs
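The learnable_assignment section maps naturally onto a dataclass. Each of the three schedules named for TemperatureScheduler can be written as a function of training progress. This is a hedged sketch: the field names mirror the YAML keys, but the actual LearnableAssignmentConfig and TemperatureScheduler may differ. The exponential and cosine formulas are common choices assumed here, not taken from the source.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class LearnableAssignmentSketch:
    """Hypothetical mirror of the learnable_assignment YAML section."""
    temperature_start: float = 1.0
    temperature_end: float = 0.1
    temperature_schedule: str = "linear"
    orthogonality_weight: float = 0.1
    balance_weight: float = 0.05
    family_constraint_weight: float = 1.0
    assignment_lr: float = 0.001
    freeze_after_epochs: Optional[int] = None


def temperature_at(cfg: LearnableAssignmentSketch, epoch: int, num_epochs: int) -> float:
    """Temperature for `epoch`, annealing from start (first epoch) to end (last epoch)."""
    t = min(max(epoch / max(num_epochs - 1, 1), 0.0), 1.0)  # progress in [0, 1]
    start, end = cfg.temperature_start, cfg.temperature_end
    if cfg.temperature_schedule == "linear":
        return start + (end - start) * t
    if cfg.temperature_schedule == "exponential":
        return start * (end / start) ** t  # geometric decay; needs start, end > 0
    if cfg.temperature_schedule == "cosine":
        return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * t))
    raise ValueError(f"unknown schedule: {cfg.temperature_schedule}")


cfg = LearnableAssignmentSketch()  # or LearnableAssignmentSketch(**yaml_section)
print([round(temperature_at(cfg, e, 10), 2) for e in (0, 9)])  # [1.0, 0.1]
```

Calling temperature_at once per epoch corresponds to the per-epoch annealing step in the trainer.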
Benefits:
1. Task-Optimal Categories
- Learned by minimizing reconstruction loss rather than from feature correlation alone
- Categories adapt to the specific modeling task
2. Automatic Balancing
- Balance loss prevents category collapse
- No manual dead-code reset needed
3. Simpler Configuration
- No max_category_size, max_split_recursion_depth, etc.
- Fewer hyperparameters to tune
4. End-to-End Differentiable
- Gradients flow through the entire pipeline
- Unified optimization objective
Limitations and Future Work:
1. Variable-Length Temporal Support
- Currently disabled in variable-length mode
- Would require runtime temporal encoding before soft routing
2. Hybrid INITIAL Encoding Support
- Currently falls back to static mode
- Would need integration with the VQVAEWithInitial wrapper
3. Advanced Initialization Methods
- Currently only clustering initialization is supported
- Could add: random, pretrained, user-specified
4. Curriculum Learning
- Gradually increase the difficulty of assignment learning
- Could start with more categories and merge them during training
5. Adaptive Temperature Scheduling
- Adjust temperature based on assignment convergence
- Stop annealing once assignments stabilize
6. Category Pruning
- Automatically remove unused categories during training
- Dynamically adjust the category count
7. Multi-Stage Training
- Train the model with frozen assignments first
- Fine-tune assignments in a second stage
- Alternate between model and assignment optimization
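Adaptive temperature scheduling is listed as future work. One possible shape assumes convergence is measured as the fraction of features whose hard (argmax) category changed since the previous epoch. A wrapper follows the base schedule until that fraction stays below a threshold for a few epochs, then holds the temperature fixed. The class name, `change_threshold`, and `patience` are hypothetical, and none of this is part of the current implementation.

```python
from typing import Callable, Optional, Sequence

# Hypothetical future-work sketch; not part of the current implementation.


class AdaptiveAnnealer:
    """Follow a base schedule, then hold the temperature once assignments stabilize."""

    def __init__(self, schedule: Callable[[int], float],
                 change_threshold: float = 0.01, patience: int = 2):
        self.schedule = schedule  # maps epoch -> temperature
        self.change_threshold = change_threshold
        self.patience = patience
        self.previous: Optional[list] = None
        self.stable_epochs = 0
        self.held_temperature: Optional[float] = None

    def step(self, epoch: int, hard_assignments: Sequence[int]) -> float:
        if self.held_temperature is not None:
            return self.held_temperature  # annealing already stopped
        temperature = self.schedule(epoch)
        if self.previous is not None:
            changed = sum(a != b for a, b in zip(hard_assignments, self.previous))
            if changed / len(hard_assignments) < self.change_threshold:
                self.stable_epochs += 1
            else:
                self.stable_epochs = 0
        self.previous = list(hard_assignments)
        if self.stable_epochs >= self.patience:
            self.held_temperature = temperature
        return temperature


annealer = AdaptiveAnnealer(lambda epoch: 1.0 - 0.1 * epoch)
temps = [annealer.step(e, [0, 1, 1]) for e in range(5)]
```

With unchanged assignments the temperature follows the schedule for three epochs and is then held. Assignments that keep flipping let annealing continue.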
Design Principles:
- Single Responsibility: Each core module is under 200 lines and does one thing
- Composition over Inheritance: Built from small, reusable pieces
- Explicit Data Flow: No hidden state, clear inputs/outputs
- Backward Compatible: Static assignment (default) + learnable (opt-in)
Related Work:
- Gumbel-Softmax: Jang et al. (2016) - Categorical Reparameterization with Gumbel-Softmax
- VQ-VAE: van den Oord et al. (2017) - Neural Discrete Representation Learning
- Temperature Annealing: Standard practice for Gumbel-Softmax relaxations in discrete latent variable models
Successfully implemented a complete, well-tested learnable category assignment system for VQ-VAE. The implementation:
- ✅ Is backward compatible (static mode still works)
- ✅ Follows clean code principles (small, focused modules)
- ✅ Has comprehensive tests (16/16 passing)
- ✅ Includes example configuration
- ✅ Integrates smoothly with existing training pipeline
- ✅ Provides clear path for future enhancements
Total implementation time: ~4 hours (as estimated in plan)