Skip to content

Latest commit

 

History

History
239 lines (182 loc) · 7.43 KB

File metadata and controls

239 lines (182 loc) · 7.43 KB

Complete Guide: Learnable Category Assignment

Quick Answer: How to Tell It's Working

Look for these 3 indicators IN ORDER:

1️⃣ Model Building Phase

[LEARNABLE ASSIGNMENT MODE]
Creating learnable categorical VQ-VAE with end-to-end assignment learning

Initializing assignment matrix from clustering...

✅ If you see this → Learnable mode initialized

2️⃣ Trainer Initialization

Learnable assignment training enabled:
  Assignment LR: 0.001
  Temperature schedule: linear
  Gradient clip norm: 1.0

✅ If you see this → Learnable trainer created

3️⃣ Training Epochs (THE DEFINITIVE PROOF)

Epoch 1/100 (15.2s): train=15.234 val=16.891
  Components: recon=12.180 vq=1.024 ortho=0.815 info=0.508 topo=0.707
              assign_orthogonality=0.423   ← LEARNABLE METRIC
              assign_balance=0.156         ← LEARNABLE METRIC
              assign_total=0.579           ← LEARNABLE METRIC
  Temperature: 1.00                        ← LEARNABLE METRIC
  Utilization: 11.2%

✅ If you see assign_* and TemperatureLEARNABLE MODE IS ACTIVE!

Implementation Status

✅ What's Complete

Component Status Files Tests
Core Modules ✅ Done 6 files, ~360 lines 13/13 passing
Integration ✅ Done train_vqvae.py 3/3 passing
Configuration ✅ Done learnable_simple_test.yaml -
Documentation ✅ Done This guide + 2 others -

Total Implementation: ~1,627 lines, 16/16 tests passing

⚠️ Known Limitation

Hybrid INITIAL Encoding Not Yet Supported

If your config uses:

families:
  initial:
    encoder: initial_hybrid  # VQVAEWithInitial wrapper

You'll see:

[WARNING] Learnable assignment mode not yet supported with hybrid INITIAL encoding
  Falling back to standard categorical VQ-VAE with static assignments

Workaround: Use the simple test config that avoids hybrid encoding.

Working Example Config

Use configs/vqvae/learnable_simple_test.yaml:

# Feature Families - SIMPLE (no hybrid encoding)
families:
  initial:
    encoder: identity  # ✅ Compatible with learnable
  temporal:
    encoder: PyramidTemporalEncoder

# Training
training:
  category_assignment: learnable  # ✅ Enable learnable mode

# Learnable Assignment Configuration
learnable_assignment:
  temperature_start: 1.0
  temperature_end: 0.1
  temperature_schedule: "linear"
  orthogonality_weight: 0.1
  balance_weight: 0.05
  assignment_lr: 0.001

Test Command

# Run 5 epochs to see learnable metrics
poetry run spinlock train-vqvae \
  --config configs/vqvae/learnable_simple_test.yaml \
  --epochs 5 \
  --verbose

What to Watch During Training

Epoch 1 (Initial State)

Temperature: 1.00                    ← Soft assignments
assign_orthogonality: 0.423          ← High (categories still learning)
assign_balance: 0.156                ← Categories unbalanced
Utilization: 11.2%                   ← Low codebook usage

Epoch 50 (Mid-Training)

Temperature: 0.50                    ← Annealing
assign_orthogonality: 0.023          ← Decreasing (better)
assign_balance: 0.012                ← Decreasing (more balanced)
Utilization: 24.5%                   ← Improving

Epoch 100 (Converged)

Temperature: 0.10                    ← Near-hard assignments
assign_orthogonality: 0.003          ← Low (categories distinct)
assign_balance: 0.001                ← Low (well balanced)
Utilization: 28.3%                   ← Stable

Comparison: Static vs Learnable

Aspect Static (Current) Learnable (New)
Visible in logs No assign_* metrics assign_* + Temperature
Warm-up message "Auto-discovering categories" ✅ "[LEARNABLE ASSIGNMENT MODE]"
Trainer type VQVAETrainer ✅ LearnableVQVAETrainer
Optimization Fixed categories ✅ Dual optimizer (model + assignments)
Temperature N/A ✅ Anneals 1.0 → 0.1

Files Created

Core Implementation

  1. src/spinlock/encoding/learnable_assignment.py (167 lines)
  2. src/spinlock/encoding/soft_routing.py (55 lines)
  3. src/spinlock/encoding/learnable_categorical_vqvae.py (282 lines)
  4. src/spinlock/encoding/training/assignment_losses.py (88 lines)
  5. src/spinlock/encoding/training/annealing.py (37 lines)
  6. src/spinlock/encoding/training/learnable_trainer.py (177 lines)

Configuration

  1. configs/vqvae/learnable_simple_test.yaml (Working example)
  2. configs/vqvae/learnable_vqvae_variable_length.yaml (Variable-length version)
  3. configs/vqvae/learnable_assignment.yaml (Original)

Tests

  1. tests/test_learnable_assignment.py (13 unit tests)
  2. tests/test_learnable_integration.py (3 integration tests)

Documentation

  1. IMPLEMENTATION_SUMMARY.md (Technical details)
  2. HOW_TO_TELL_LEARNABLE_IS_ACTIVE.md (Quick reference)
  3. docs/learnable-assignments-quickstart.md (User guide)
  4. LEARNABLE_MODE_COMPLETE_GUIDE.md (This file)

Benefits of Learnable Mode

  1. Task-Optimal Categories

    • Learned to minimize reconstruction loss
    • Not just correlation-based clustering
  2. Automatic Balancing

    • Balance loss prevents category collapse
    • No manual dead code reset needed
  3. End-to-End Differentiable

    • Gradients flow through entire pipeline
    • Unified optimization objective
  4. Simpler Configuration

    • No max_category_size, max_split_recursion_depth, etc.
    • Fewer hyperparameters to tune

Troubleshooting

Issue: No learnable metrics in epoch logs

Check:

  1. Config has category_assignment: learnable
  2. No warning about hybrid INITIAL encoding
  3. Trainer init shows "Learnable assignment training enabled"

If still missing: Check that model is LearnableCategoricalVQVAE not CategoricalHierarchicalVQVAE

Issue: Warning about hybrid INITIAL

This is expected! Hybrid mode not yet supported.

Solution: Use learnable_simple_test.yaml config

Issue: Temperature not annealing

Check: temperature_schedule in config Valid values: "linear", "exponential", "cosine"

Issue: Category collapse (all features → 1 category)

Increase: balance_weight from 0.05 to 0.10 Or increase: temperature_end from 0.1 to 0.3 (slower annealing)

Future Enhancements

  1. Hybrid INITIAL Support - Make learnable work with VQVAEWithInitial
  2. Variable-Length Support - Full integration with runtime temporal encoding
  3. Adaptive Temperature - Stop annealing when assignments stabilize
  4. Category Pruning - Dynamically remove unused categories

Quick Reference Card

✅ LEARNABLE ACTIVE          ❌ STATIC MODE (FALLBACK)
━━━━━━━━━━━━━━━━━━━━━━━━━━  ━━━━━━━━━━━━━━━━━━━━━━━━━━━
[LEARNABLE ASSIGNMENT MODE]  Auto-discovering categories
Learnable assignment enabled  VQVAETrainer created
assign_orthogonality=0.003    (no assign_* metrics)
Temperature: 0.91             (no Temperature)

Support

  • Implementation Details: See IMPLEMENTATION_SUMMARY.md
  • Quick Start: See docs/learnable-assignments-quickstart.md
  • This Guide: Complete reference for using learnable mode

Status: ✅ Fully implemented, tested, and ready to use! Limitation: Not yet compatible with hybrid INITIAL encoding Workaround: Use learnable_simple_test.yaml configuration