
Final Solution for Candle Training

The Core Problem

Candle's VarBuilder is designed for inference only and returns immutable Tensor objects instead of trainable Var objects. Since gradients are only accumulated for Vars, nothing obtained through VarBuilder can be updated by an optimizer, which makes training impossible without workarounds.
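
To illustrate, here is a minimal sketch of the problem (candle_core and candle_nn as dependencies are assumed): anything fetched through VarBuilder comes back as a plain Tensor, so there is nothing for an optimizer to update.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

fn illustrate_problem() -> Result<()> {
    let device = Device::Cpu;
    // A zero-initialized VarBuilder, as used for inference.
    let vb = VarBuilder::zeros(DType::F32, &device);
    // `w` is an immutable Tensor, not a Var: no gradient is tracked
    // for it, so an optimizer has nothing to step.
    let w: Tensor = vb.get((4, 8), "layer.weight")?;
    println!("{:?}", w.shape());
    Ok(())
}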

The Solution: Generic Linear

Instead of modifying VarBuilder or fighting with Candle's internals, we use Rust's type system to create layers that work with both:

  • Tensor for inference (frozen weights)
  • Var for training (trainable parameters)

1. Generic Linear Layer

use candle_core::{Module, Result, Tensor, Var};

// One layer definition, generic over how its parameters are stored.
pub struct Linear<T> {
    pub weight: T,
    pub bias: Option<T>,
}

// Implementation for Tensor (inference/frozen)
impl Module for Linear<Tensor> {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let w = self.weight.t()?;
        let x = x.matmul(&w)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

// Implementation for Var (training)
impl Module for Linear<Var> {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let w = self.weight.as_tensor().t()?;
        let x = x.matmul(&w)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias.as_tensor()),
        }
    }
}
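
As a usage sketch (dimensions are made up for illustration), the same forward call works for both instantiations:

use candle_core::{DType, Device, Module, Result, Tensor, Var};

fn demo(device: &Device) -> Result<()> {
    let x = Tensor::zeros((1, 8), DType::F32, device)?;

    // Frozen layer: plain Tensors, no gradients tracked.
    let frozen = Linear {
        weight: Tensor::zeros((4, 8), DType::F32, device)?,
        bias: None,
    };
    let _y = frozen.forward(&x)?;

    // Trainable layer: Vars participate in backprop.
    let trainable = Linear {
        weight: Var::zeros((4, 8), DType::F32, device)?,
        bias: None,
    };
    let _y = trainable.forward(&x)?;
    Ok(())
}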

2. Loading Weights Without VarBuilder

// Load weights directly from the safetensors file
let weights = candle_core::safetensors::load("model.safetensors", &device)?;

// For LoRA training - frozen base, trainable adapters
// (rank, in_features, out_features, and scale are LoRA hyperparameters)
let base_weight = weights.get("layer.weight").unwrap().clone();            // Tensor (frozen)
let lora_down = Var::randn(0f32, 0.02f32, (rank, in_features), &device)?;  // Var (trainable)
let lora_up = Var::zeros((out_features, rank), DType::F32, &device)?;      // Var (trainable)

// Build layers
let base_layer = Linear { weight: base_weight, bias: None };
let lora_layer = LoRAAdapter { down: lora_down, up: lora_up, scale };
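
The LoRAAdapter type is not defined in this file; a minimal sketch of what such an adapter could look like (the struct layout and scale handling here are assumptions, not the exact implementation) is:

use candle_core::{Result, Tensor, Var};

// Low-rank adapter: y = scale * ((x @ down^T) @ up^T)
pub struct LoRAAdapter {
    pub down: Var,  // (rank, in_features), gaussian-initialized
    pub up: Var,    // (out_features, rank), zero-initialized
    pub scale: f64, // typically alpha / rank
}

impl LoRAAdapter {
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let h = x.matmul(&self.down.as_tensor().t()?)?; // project down to rank
        let h = h.matmul(&self.up.as_tensor().t()?)?;   // project back up
        h.affine(self.scale, 0.0)                       // apply the LoRA scale
    }
}

During the forward pass the adapter output is simply added to the frozen base layer's output.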

3. Why This Works

  1. Type Safety: Rust ensures we never mix frozen and trainable parameters
  2. No VarBuilder: We bypass it entirely by loading weights directly
  3. Clean Separation: Base model stays as Tensor, only LoRA is Var, so only the LoRA parameters ever reach the optimizer (see the sketch below)
  4. Zero Overhead: Generic dispatch is resolved at compile time
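
A minimal optimizer sketch with candle_nn's AdamW (the function names here are illustrative):

use candle_core::{Result, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};

// Only the trainable LoRA Vars are registered with the optimizer;
// the frozen base weights are plain Tensors and never appear here.
fn make_optimizer(lora_vars: Vec<Var>) -> Result<AdamW> {
    AdamW::new(lora_vars, ParamsAdamW { lr: 1e-4, ..Default::default() })
}

fn train_step(opt: &mut AdamW, loss: &Tensor) -> Result<()> {
    // Backprop through the graph and update only the LoRA Vars.
    opt.backward_step(loss)
}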

The make_var() Fix

In addition to the generic approach, we fixed a critical bug in Candle where Var::from_tensor() would fail on CUDA when called on tensors that are part of a computation graph. The fix in tensor.rs:

pub(crate) fn make_var(&self) -> Result<Tensor> {
    match self.device() {
        Device::Cuda(_) => {
            // Detach and make contiguous first
            let safe_tensor = if self.is_contiguous() {
                self.detach()
            } else {
                self.contiguous()?
            };
            
            // Create fresh tensor to break graph connections
            // (double-copy approach to avoid CUDA kernel issues)
            // ... implementation details ...
        }
        _ => {
            // CPU path unchanged
        }
    }
}

This fix enables gradient checkpointing by allowing Var creation from intermediate activations.
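
As an illustration of what the fix unlocks (a sketch, not the actual checkpointing code): Var::from_tensor can now be called on an activation produced mid-graph, even on CUDA.

use candle_core::{Result, Tensor, Var};

// Wrap an intermediate activation as a Var so the graph can be cut here
// and recomputed later - the core trick behind gradient checkpointing.
fn checkpoint_boundary(activation: &Tensor) -> Result<Var> {
    // Before the make_var fix this could fail on CUDA for tensors produced
    // by other ops; with the fix it succeeds.
    Var::from_tensor(activation)
}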

Results

With this approach:

  • ✅ SDXL LoRA training works
  • ✅ SD3.5 LoRA training works
  • ✅ Flux LoRA training works
  • ✅ No VarBuilder modifications needed
  • ✅ Gradient checkpointing possible (with make_var fix)
  • ✅ Can train at 1024x1024 on 24GB GPUs