Train a semantic segmentation model that detects Retrogressive Thaw Slumps (RTS) under extreme class imbalance (~0.1% positive pixels), optimising for high precision at acceptable recall to minimise false positives in the final pan-arctic map.
| Resource | Specification |
|---|---|
| Cloud | Google Cloud Platform |
| GPUs | A100 or H100 VM (multi-GPU spec TBD with PDG team) |
| Budget | $70,000 (training + inference combined) |
| Framework | PyTorch 2.x |
| IDE | VSCode + Remote-SSH (GCP VMs only — no Colab) |
| AI-assist | Claude Code |
| Dev/test | L4 VM (gpu-vm-l4) — cheaper, same Docker image |
| Setting | Value | Purpose |
|---|---|---|
| Random seed | 42 | Fixed for all stochastic processes |
| CUDNN deterministic | True | Reproducible convolution results |
| CUDNN benchmark | False | Disable auto-tuning for reproducibility |
| Python hash seed | 42 | Reproducible dictionary ordering |
Note: Deterministic mode may reduce training speed by 10-20%. For hyperparameter search, disable deterministic mode; enable for final runs.
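A minimal seeding helper implementing the table above (the `set_seed` name is illustrative; the torch calls are the standard PyTorch 2.x determinism switches and are guarded so the sketch also runs where torch is unavailable):

```python
import os
import random

import numpy as np


def set_seed(seed: int = 42, deterministic: bool = True) -> None:
    """Seed all stochastic processes; optionally enable deterministic cuDNN."""
    os.environ["PYTHONHASHSEED"] = str(seed)  # reproducible hashing
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic  # disable auto-tuning
    except ImportError:
        pass  # torch not installed in this environment
```

For hyperparameter search, call `set_seed(seed, deterministic=False)` to recover the 10-20% speed loss; re-enable for final runs.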
| Component | Choice | Rationale |
|---|---|---|
| Architecture | UNet++ (via smp library) | Close to UNet3+ (v1 best performer); battle-tested library |
| Encoder backbone | EfficientNet-B5 | Strong ImageNet features, good capacity/memory balance for 512×512 |
| Pretrained weights | ImageNet | Transfer learning for faster convergence |
| Input size | 512×512×3 | RGB channels |
| Output | Binary segmentation mask | Single-class prediction |
EfficientNet variants B3/B7 may be tested for capacity trade-offs. B5 is the baseline default.
Experiment in priority order (stop when diminishing returns):
| Priority | Model | Notes |
|---|---|---|
| 1 (baseline) | UNet++ + EfficientNet-B5 (smp) | Proven architecture class; strong CNN baseline |
| 2 | SegFormer-B5 | Efficient Vision Transformer; strong on dense prediction tasks |
| 3 | DINOv3 encoder + dense head | Latest DINO self-supervised ViT; confirm model version at time of implementation |
SAM is not a direct fit for pixel-level semantic segmentation (prompt-based mask decoder). Skip unless UNet++ and SegFormer both fail to meet precision targets and a dedicated feasibility study is done. Skip Prithvi, SATMAE, SwinTransformer, Mask2Former unless experiments clearly plateau.
Fusion strategies should be tested in order of complexity:
| Order | Strategy | Description | When to Use |
|---|---|---|---|
| 1 | RGB baseline | No auxiliary data | Establish performance baseline |
| 2 | Individual channels | RGB + one auxiliary channel at a time | Identify which channels help |
| 3 | Early fusion | Channel stack (RGB + helpful auxiliaries → single encoder) | Simple, often sufficient |
| 4 | Late fusion | Separate encoders → feature-level fusion | Only if early fusion underperforms |
Critical: The same normalization statistics used during training must be used during inference. The inference pipeline loads normalization_stats.json from the model directory and applies identical normalization.
If 2025 imagery has significantly different radiometric properties than 2024 training data, this will manifest as degraded performance. Monitor inference predictions for systematic shifts.
Focal loss down-weights easy examples, focusing learning on hard cases. Particularly suited for class imbalance.
Formula: FL(p_t) = -α_t × (1 - p_t)^γ × log(p_t)
| Parameter | Baseline | Tuning Range | Effect |
|---|---|---|---|
| γ (gamma) | 2 | [1, 2, 3, 5] | Higher = more focus on hard examples |
| α (alpha) | 0.25 | [0.1, 0.25, 0.5, 0.75] | Weight for positive class |
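A NumPy sketch of the focal loss formula above (illustrative, not the training code — the real implementation would operate on torch tensors with autograd):

```python
import numpy as np


def focal_loss(probs: np.ndarray, targets: np.ndarray,
               gamma: float = 2.0, alpha: float = 0.25,
               eps: float = 1e-7) -> float:
    """Mean binary focal loss: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).

    probs: predicted P(RTS) per pixel; targets: {0, 1} labels.
    """
    probs = np.clip(probs, eps, 1.0 - eps)
    p_t = np.where(targets == 1, probs, 1.0 - probs)        # prob of true class
    alpha_t = np.where(targets == 1, alpha, 1.0 - alpha)    # class weighting
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))
```

With `gamma=0` the `(1 - p_t)^gamma` factor vanishes and the loss reduces to alpha-weighted BCE; raising gamma shrinks the contribution of well-classified pixels.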
Tversky loss allows explicit control over false positive vs false negative penalty.
Formula: TL = 1 - (TP + ε) / (TP + α×FN + β×FP + ε)
| Parameter | Range | Notes |
|---|---|---|
| α | [0.3, 0.5, 0.7] | Weight on false negatives |
| β | [0.3, 0.5, 0.7] | Weight on false positives |
For precision-focused training: Set β > α to penalize false positives more heavily.
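The Tversky formula translates directly to soft (probabilistic) counts — a sketch, with the precision-focused defaults β > α:

```python
import numpy as np


def tversky_loss(probs: np.ndarray, targets: np.ndarray,
                 alpha: float = 0.3, beta: float = 0.7,
                 eps: float = 1e-7) -> float:
    """Soft Tversky loss: 1 - (TP + eps) / (TP + alpha*FN + beta*FP + eps)."""
    tp = np.sum(probs * targets)            # soft true positives
    fn = np.sum((1.0 - probs) * targets)    # soft false negatives
    fp = np.sum(probs * (1.0 - targets))    # soft false positives
    return float(1.0 - (tp + eps) / (tp + alpha * fn + beta * fp + eps))
```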
Combine pixel-level and region-level objectives. Focal handles pixel-level calibration; Dice directly optimizes region overlap and is insensitive to the overwhelming number of true-negative pixels. This combination consistently outperforms single-component losses in segmentation benchmarks.
Formula: L = λ₁ × Focal(pred, target) + λ₂ × Dice(pred, target)
| Parameter | Baseline | Tuning Range |
|---|---|---|
| λ₁ (focal weight) | 1.0 | [0.5, 1.0, 2.0] |
| λ₂ (dice weight) | 1.0 | [0.5, 1.0, 2.0] |
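A self-contained sketch of the combined objective (NumPy for illustration; function names are hypothetical):

```python
import numpy as np


def soft_dice_loss(probs, targets, eps=1e-7):
    """1 - Dice coefficient on soft predictions (region-overlap term)."""
    inter = np.sum(probs * targets)
    return float(1.0 - (2.0 * inter + eps) / (np.sum(probs) + np.sum(targets) + eps))


def focal_term(probs, targets, gamma=2.0, alpha=0.25, eps=1e-7):
    """Pixel-level focal term (same formula as Section above)."""
    probs = np.clip(probs, eps, 1.0 - eps)
    p_t = np.where(targets == 1, probs, 1.0 - probs)
    a_t = np.where(targets == 1, alpha, 1.0 - alpha)
    return float(np.mean(-a_t * (1.0 - p_t) ** gamma * np.log(p_t)))


def combined_loss(probs, targets, lam1=1.0, lam2=1.0):
    """L = lam1 * Focal + lam2 * Dice."""
    return lam1 * focal_term(probs, targets) + lam2 * soft_dice_loss(probs, targets)
```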
Weight loss inversely proportional to class frequency. Options for computing weights:
- Linear: weight_rts = num_bg_pixels / num_rts_pixels
- Square root: weight_rts = sqrt(num_bg_pixels / num_rts_pixels)
- Log: weight_rts = log(num_bg_pixels / num_rts_pixels)
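The three weighting options as a small helper (hypothetical name; sqrt and log temper the extreme ratios produced by ~0.1% prevalence):

```python
import numpy as np


def rts_class_weight(num_bg: int, num_rts: int, mode: str = "sqrt") -> float:
    """Positive-class loss weight from pixel counts."""
    ratio = num_bg / num_rts
    if mode == "linear":
        return float(ratio)
    if mode == "sqrt":
        return float(np.sqrt(ratio))
    if mode == "log":
        return float(np.log(ratio))
    raise ValueError(f"unknown mode: {mode}")
```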
Label boundaries may be uncertain due to resolution mismatch or inherent ambiguity in RTS edges.
Both approaches will be implemented and selected via YAML config for ablation:
```yaml
boundary_handling: none   # options: none | ignore | soft_labels
boundary_ignore_width: 3  # pixels (used when boundary_handling: ignore)
soft_label_value: 0.05    # P(background near boundary) when boundary_handling: soft_labels
```

Approach 1: Ignore Regions (`boundary_handling: ignore`)
- Exclude pixels within `boundary_ignore_width` pixels of label boundaries from loss computation (set to ignore index 255)
- Applied on-the-fly in the DataLoader using scipy binary dilation on the label mask
- Simple, proven in medical imaging segmentation

Approach 2: Soft Labels (`boundary_handling: soft_labels`)
- Near-boundary pixels get softened labels: background → `soft_label_value`, RTS → `1 - soft_label_value`
- Options: constant soft values (0.05/0.95) or distance-based softening
- Requires BCE with soft targets (not cross-entropy with integer labels)
Experiment order: Run the baseline with `none` first; then ablate `ignore` with a narrow band (1–2 pixels), and `soft_labels` with a small soft value (~0.1).
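Approach 1's boundary band can be generated with scipy as described — a sketch assuming a binary {0, 1} label mask and ignore index 255 (the helper name is illustrative):

```python
import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion


def apply_ignore_band(label: np.ndarray, width: int = 3,
                      ignore_index: int = 255) -> np.ndarray:
    """Mark pixels within `width` px of the RTS/background boundary as ignore."""
    fg = label == 1
    # Boundary band = dilation minus erosion of the foreground, `width` iterations each.
    band = binary_dilation(fg, iterations=width) & ~binary_erosion(fg, iterations=width)
    out = label.copy()
    out[band] = ignore_index
    return out
```

The loss function then masks out pixels equal to `ignore_index` before reduction.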
| Metric | Formula | Use |
|---|---|---|
| IoU_RTS | TP / (TP + FP + FN) | Primary pixel-level metric |
| F1_RTS | 2TP / (2TP + FP + FN) | At operating threshold; for literature comparison |
Object-level evaluation treats each connected component as a detection instance.
| Metric | Description |
|---|---|
| Object Precision | Fraction of predicted objects that match ground truth |
| Object Recall | Fraction of ground truth objects that are detected |
| Object F1 | 2 × Obj_Precision × Obj_Recall / (Obj_Precision + Obj_Recall) — at operating threshold |
IoU Threshold for Matching:
| Threshold | Use Case | Recommendation |
|---|---|---|
| 0.5 | Standard (COCO default) | Requires good shape match |
| 0.3 | Relaxed | Preferred — approximate detections acceptable |
| 0.1 | Very relaxed | "Did we find something here?" |
Matching Algorithm: Greedy 1-to-1 matching:
- Threshold probability map → binary mask; extract connected components (blobs) for both prediction and ground truth
- Compute pairwise IoU for all (predicted blob, GT blob) pairs
- Sort predicted blobs by mean probability (highest first)
- Match each predicted blob to its highest-IoU GT blob, only if IoU ≥ threshold and that GT blob is unmatched
- Matched pairs → TP; unmatched predictions → FP; unmatched GT blobs → FN
Edge cases (expected to be rare given RTS morphology — noted for awareness, not implemented):
- One large prediction overlapping multiple GT objects → matched to the best-IoU GT; remaining GT blobs count as FN
- Multiple predictions overlapping one GT → only the first (highest confidence) matches; the rest count as FP
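The greedy matching steps above can be sketched with scipy connected components (a simplification: for each prediction, the best IoU is taken over still-unmatched GT blobs only):

```python
import numpy as np
from scipy import ndimage


def match_objects(prob_map, gt_mask, thresh=0.5, iou_min=0.3):
    """Greedy 1-to-1 blob matching; returns (tp, fp, fn) at the object level."""
    pred_lbl, n_pred = ndimage.label(prob_map >= thresh)
    gt_lbl, n_gt = ndimage.label(gt_mask.astype(bool))
    # Sort predicted blobs by mean probability, highest first.
    order = sorted(range(1, n_pred + 1),
                   key=lambda i: -prob_map[pred_lbl == i].mean())
    matched_gt, tp = set(), 0
    for i in order:
        p = pred_lbl == i
        best_iou, best_j = 0.0, None
        for j in range(1, n_gt + 1):
            if j in matched_gt:
                continue  # each GT blob matches at most once
            g = gt_lbl == j
            iou = np.logical_and(p, g).sum() / np.logical_or(p, g).sum()
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_j is not None and best_iou >= iou_min:
            matched_gt.add(best_j)
            tp += 1
    return tp, n_pred - tp, n_gt - len(matched_gt)  # TP, FP, FN
```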
Threshold for in-training reporting: A fixed reference threshold of 0.5 is applied to extract connected components during training. This threshold is for monitoring trends across epochs, not for deployment — focal-loss outputs are not calibrated, so absolute object precision/recall values at 0.5 should be interpreted as relative to other epochs of the same run. The deployment threshold is selected post-training via the calibration procedure in §6.4 and §12.2.
| Metric | Formula | Use Case |
|---|---|---|
| PR-AUC | Area under precision-recall curve | Overall performance under imbalance |
Two approaches for selecting operating threshold:
| Approach | Description | Pros | Cons |
|---|---|---|---|
| Global threshold | Single threshold for all regions | Simple, consistent | May underperform in some regions |
| Region-specific thresholds | Calibrate per Arctic subregion | Adapts to regional characteristics | More complex, requires per-region validation data |
Recommendation: Start with global threshold. If post-inference analysis reveals systematic regional performance differences, consider region-specific thresholds.
Threshold calibration is run once, post-training, on the EMA-weight final model. In-training object metrics use the fixed 0.5 reference threshold (§6.2) to avoid a circular dependence between calibration and stopping decisions.
Real-world RTS prevalence is ~0.1-0.5%. With naive random sampling:
- Most batches contain zero or near-zero positive pixels
- Gradients dominated by easy negatives
- Model may collapse to "predict all background"
| Technique | Description | Effect |
|---|---|---|
| Balanced batch sampling | Each batch has ~50% positive tiles, ~50% negative tiles | Ensures model sees positives every batch |
| Focal loss | Down-weights easy examples | Focuses on hard cases |
| Curriculum learning | Gradually increase negative ratio during training | Prevents early collapse |
Concrete Schedule (based on 300 max epochs):
| Epoch Range | Pos:Neg Ratio | Rationale |
|---|---|---|
| 1–10 | 1:1 | Learn basic RTS features with maximum positive exposure |
| 11–30 | 1:5 | Introduce more negatives, start discriminating |
| 31–50 | 1:10 | Standard training ratio |
| 51–100 | 1:15 | Approaching realistic conditions |
| 101–300 | 1:20 | Near-realistic ratio for final refinement |
Implementation: Step-wise ratio changes at epoch boundaries (not interpolated); batch composition is recalculated at the start of each epoch.
Early Stopping Note: With patience=20 on Val-Realistic, training will likely stop before epoch 300. The curriculum ensures the model has seen realistic ratios before convergence.
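The step schedule above as a lookup (hypothetical helper; the sampler would query it once per epoch):

```python
def curriculum_neg_ratio(epoch: int) -> int:
    """Negative tiles per positive tile at a given 1-based epoch (step-wise)."""
    schedule = [(10, 1), (30, 5), (50, 10), (100, 15), (300, 20)]
    for last_epoch, neg in schedule:
        if epoch <= last_epoch:
            return neg
    return 20  # past max epochs: keep the final near-realistic ratio
```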
RTS range from ~50m to 2+ km. At 512×512 tiles with 3m resolution (~1.5km coverage):
- Small RTS (50-200m): Well captured within single tile
- Medium RTS (200m-1km): Well captured within single tile
- Large RTS (1-2+ km): Span multiple tiles, may never appear complete
Run inference at multiple effective resolutions to catch different RTS scales. See Inference Guide for detailed procedure.
| Scale | Effective Resolution | Field of View | Target RTS Size |
|---|---|---|---|
| 1.0 | 3m (native) | 1.5 km | Small to medium |
| 0.5 | 6m | 3 km | Medium to large |
Current recommendation: Train at native resolution only. Multi-resolution inference is sufficient for most cases.
Trigger for multi-resolution training: If post-inference analysis shows recall for large RTS (>1km) is significantly worse than small/medium RTS, consider adding downscaled training samples.
Model Configuration:
| Parameter | Value |
|---|---|
| Architecture | UNet++ (smp) |
| Backbone | EfficientNet-B5 |
| Pretrained weights | ImageNet |
| Input channels | 3 (RGB) |
| Input size | 512×512 |
Loss Configuration:
| Parameter | Value |
|---|---|
| Loss function | Focal |
| Gamma (γ) | 2 |
| Alpha (α) | 0.25 |
| boundary_handling | none |
Optimizer Configuration:
| Parameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning rate | 1e-4 |
| Weight decay | 1e-2 |
| Gradient clipping | Max norm 1.0 |
Learning Rate Schedule:
Phase 1 (frozen backbone) uses constant frozen_lr. Phase 2 (after unfreezing) uses cosine annealing with warmup.
| Parameter | Value | Applies to |
|---|---|---|
| Scheduler | Cosine annealing | Phase 2 only |
| Minimum LR | 1e-6 | Phase 2 only |
| Warmup epochs | 5 | Phase 2 only (epochs 11-15) |
| Warmup start LR | 1e-6 | Phase 2 only |
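The Phase 2 schedule as a closed-form function (a sketch assuming Phase 2 starts at epoch 11 and cosine annealing runs to max epoch 300; in practice this would be a torch LR scheduler):

```python
import math


def phase2_lr(epoch, base_lr=1e-4, min_lr=1e-6, warmup_start=1e-6,
              warmup_epochs=5, phase2_start=11, max_epochs=300):
    """LR at a given epoch in Phase 2: linear warmup, then cosine annealing."""
    t = epoch - phase2_start  # 0-based epoch index within Phase 2
    if t < warmup_epochs:
        return warmup_start + (t / warmup_epochs) * (base_lr - warmup_start)
    # Cosine from base_lr down to min_lr over the remaining epochs.
    span = (max_epochs - phase2_start) - warmup_epochs
    progress = (t - warmup_epochs) / span
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```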
Backbone Freeze Strategy:
| Phase | Epochs | Backbone | Decoder | LR |
|---|---|---|---|---|
| Phase 1 | 1–freeze_epochs | Frozen | Training | frozen_lr |
| Phase 2 | freeze_epochs+ | Training | Training | base_lr (backbone: base_lr × backbone_lr_multiplier) |
All LR values configurable in YAML:
```yaml
lr:
  frozen_lr: 1e-3              # decoder-only phase (suggested default)
  base_lr: 1e-4                # full fine-tuning base LR
  backbone_lr_multiplier: 0.1  # backbone LR = base_lr × multiplier
  freeze_backbone_epochs: 10   # number of epochs for Phase 1
```

After unfreezing, the backbone uses `backbone_lr_multiplier × base_lr` to prevent catastrophic forgetting.
EMA (Exponential Moving Average):
| Parameter | Value |
|---|---|
| Enabled | Yes |
| Decay | 0.999 |
| Used for validation | Yes (swap EMA in for validation, swap live weights back for training) |
EMA maintains a smoothed copy of model weights. Final model uses EMA weights.
Training Configuration:
| Parameter | Value |
|---|---|
| Mixed precision | FP16 (enabled) |
| Batch size (per GPU) | 32 |
| Effective batch size | 32 × n_gpus |
| Multi-GPU (DDP) | Not implemented initially; code structured to allow DDP addition later |
| Max epochs | 300 |
| Early stopping patience | 20 epochs |
| Early stopping metric | Val-Realistic PR-AUC at 1:200, moving average over last 3 validations |
| Early stopping min delta | 0.005 (placeholder; calibrate empirically — see §10.4) |
| Early stopping start epoch | 50 |
| Validation frequency | Every 5 epochs (configurable: val_frequency) |
Data Loading:
| Parameter | Value |
|---|---|
| Num workers (per GPU) | 8 |
| Pin memory | True |
| Prefetch factor | 2 |
| Persistent workers | True |
| Drop last batch | True |
Batch Sampling:
| Parameter | Value |
|---|---|
| Balanced sampling | Enabled |
| Positive fraction per batch | 0.5 |
| Training ratio (epoch-level) | Curriculum (see Section 7.3) |
Checkpointing:
| Parameter | Value |
|---|---|
| Save best metric | Val-Realistic PR-AUC |
| Save every N epochs | 10 |
| Keep last N checkpoints | 3 |
Reproducibility:
| Parameter | Value |
|---|---|
| Random seed | 42 |
| Deterministic mode | True |
| Seeds for final model | [42, 43, 44] |
Applied on-the-fly during training using Albumentations library.
Geometric Augmentations:
| Augmentation | Parameters | Probability |
|---|---|---|
| Random 90° rotation | — | 0.5 |
| Horizontal flip | — | 0.5 |
| Vertical flip | — | 0.5 |
| Shift-scale-rotate | shift=0.1, scale=0.2, rotate=45° | 0.5 |
| Elastic transform | alpha=120, sigma=6 | 0.3 |
| Affine transform | shear=(-10°, 10°) | 0.3 |
Color/Radiometric Augmentations (RGB channels only):
| Augmentation | Parameters | Probability |
|---|---|---|
| Brightness | ±0.2 | 0.5 |
| Contrast | ±0.2 | 0.5 |
| Saturation | ±0.2 | 0.5 |
| Gaussian noise | var_limit=(10, 50) | 0.3 |
| CLAHE | clip_limit=4.0, tile_grid_size=8×8 | 0.2 |
Multi-Scale Augmentation (applied to all channels):
| Augmentation | Parameters | Probability |
|---|---|---|
| RandomScale + PadIfNeeded/CenterCrop | scale=(0.5, 1.0), pad to 512×512 | 0.3 |
RandomScale simulates the effective resolution variation seen during multi-scale inference (0.5x scale). This reduces the train-inference scale gap by exposing the model to downscaled imagery during training.
Note: Color/radiometric augmentations apply only to RGB channels, not auxiliary bands in EXTRA dataset. Geometric and multi-scale augmentations apply to all channels and masks.
Data Preparation:
- Data validation checks pass
- Normalisation statistics computed and saved
- If `boundary_handling: ignore`, boundary ignore masks created for all labels
- Balanced batch sampler configured
- Spatial blocking verified (no geographic overlap between splits)
Create a standalone script check_data.py that iterates through the DataLoader (not just the files) to ensure that augmentation, normalization, tensor collation, etc. actually work as expected. This prevents burning expensive GPU hours on bad data.
Environment:
- Docker container built and tested
- GPU memory profiled, batch size confirmed
- MLflow tracking server running
- Library versions pinned in requirements.txt
Configuration:
- Config file committed to version control
- Git commit hash recorded
- Baseline config validated
Calibration:
- Validation noise floor measured: Run validation 5× on same checkpoint with augmentation disabled; compute std of PR-AUC at 1:200; set early_stopping_min_delta = 2 × std in baseline config.
Phase 1: Backbone Frozen (Epochs 1–10)
- Freeze all encoder (backbone) parameters
- Train decoder with higher learning rate (1e-3)
- Purpose: Adapt decoder to RTS segmentation task without disturbing pretrained features
Phase 2: Full Fine-Tuning (Epochs 11+)
- Unfreeze backbone with lower learning rate (0.1× base LR)
- Apply curriculum learning schedule for negative ratio
- Update EMA weights after each optimizer step
- Validate on Val-Realistic every val_frequency epochs (configurable in YAML; suggested default 5) using EMA weights: swap EMA weights into the model for the validation pass, then restore live weights before the next training step. All validation metrics, early-stopping decisions, and best-checkpoint comparisons use EMA weights
- Check the early stopping criterion on Val-Realistic PR-AUC. Early stopping is gated to begin at epoch 50, when the curriculum reaches near-realistic ratios; before this, validation runs and best-so-far checkpoints are saved, but stopping is disabled
- Save checkpoint if best metric achieved
| Dataset | Ratios | Purpose | Tune On? |
|---|---|---|---|
| Val-Balanced | 1:1 | Quick sanity checks | No |
| Val-Realistic | 1:200, 1:500, 1:1000 | Early stopping, threshold calibration | Yes |
| Test-Realistic | 1:200, 1:500, 1:1000 | Final reporting | Never |
Efficient Multi-Ratio Evaluation: Run inference once on all validation samples, then subsample negatives to compute metrics at each ratio. No additional GPU time required.
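The subsampling step can be sketched as follows, assuming cached per-tile results are held as lists (the representation and helper name are illustrative; a fixed seed keeps the subsample reproducible across epochs):

```python
import random


def subsample_for_ratio(pos_tiles: list, neg_tiles: list,
                        neg_per_pos: int, seed: int = 42) -> list:
    """Subsample cached negative-tile results to a 1:neg_per_pos prevalence."""
    rng = random.Random(seed)
    k = min(len(neg_tiles), len(pos_tiles) * neg_per_pos)
    return pos_tiles + rng.sample(neg_tiles, k)
```

Metrics at 1:200, 1:500, and 1:1000 are then computed on the returned subsets from the same cached predictions.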
- Confirm EMA weights for final model: validation already used EMA throughout training, so the final saved model is the EMA copy of the best-validation checkpoint. No metric change is expected at this step.
- Temperature scaling calibration: Learn temperature parameter T on Val-Realistic to calibrate prediction confidence
- Threshold selection: Using Val-Realistic, plot PR curves and select threshold where Precision ≥ target
- Test-Time Augmentation evaluation: Evaluate with and without TTA to quantify benefit
- Final evaluation: Report all metrics on Test-Realistic at all ratios (1:200, 1:500, 1:1000)
- Multi-seed runs: Train final configuration with seeds [42, 43, 44], report mean ± std
Monitor for these warning signs:
| Indicator | Sign | Remedy |
|---|---|---|
| Train-val divergence | Train loss decreasing while val loss increasing | Increase dropout, stronger augmentation |
| Large IoU gap | Train IoU > 0.9, Val IoU < 0.5 | Reduce model capacity, earlier stopping |
| Balanced vs realistic gap | Val-Balanced >> Val-Realistic | Model overfitting to balanced distribution |
TTA runs inference multiple times with different augmentations and averages predictions. Typically gives 1-3% IoU improvement at the cost of N× inference time.
| Setting | Transforms | Speed | Expected Gain |
|---|---|---|---|
| Minimal | Identity, horizontal flip | 2× slower | ~1% |
| Standard | Identity, hflip, vflip, rot180 | 4× slower | ~2% |
| Full | All 8 D4 symmetries | 8× slower | ~2-3% |
Recommendation: Use Minimal first; move to Standard (4 transforms) only if necessary.
For each input image:
- Apply each transform (e.g., horizontal flip)
- Run model inference
- Apply inverse transform to prediction
- Average all predictions pixel-wise
- Apply threshold to averaged probabilities
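The transform/inverse/average loop above, sketched for a model treated as a function from image to probability map (names are illustrative; horizontal flip is its own inverse):

```python
import numpy as np


def tta_predict(model, image, transforms):
    """Average probability maps over (transform, inverse-transform) pairs."""
    preds = []
    for fwd, inv in transforms:
        preds.append(inv(model(fwd(image))))  # predict on transformed input, map back
    return np.mean(preds, axis=0)


# Minimal setting: identity + horizontal flip.
MINIMAL_TTA = [
    (lambda x: x, lambda x: x),
    (np.fliplr, np.fliplr),
]
```

Thresholding is applied to the averaged map, not to the individual predictions.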
Neural networks are often overconfident. Temperature scaling learns a single parameter T to calibrate probabilities.
Procedure:
- Freeze all model weights
- Compute logits on Val-Realistic
- Find T that minimizes negative log-likelihood
- Apply calibrated probabilities: P_calibrated = sigmoid(logits / T)
Typical T values range from 1.0 to 3.0.
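The procedure fits to a one-dimensional optimization over T — a sketch with frozen logits, using scipy's bounded scalar minimizer (function name and bounds are assumptions):

```python
import numpy as np
from scipy.optimize import minimize_scalar


def fit_temperature(logits: np.ndarray, targets: np.ndarray) -> float:
    """Find T > 0 minimizing binary NLL of sigmoid(logits / T)."""
    def nll(T):
        p = 1.0 / (1.0 + np.exp(-logits / T))
        p = np.clip(p, 1e-7, 1 - 1e-7)
        return -np.mean(targets * np.log(p) + (1 - targets) * np.log(1 - p))
    res = minimize_scalar(nll, bounds=(0.05, 10.0), method="bounded")
    return float(res.x)
```

T > 1 indicates overconfidence (probabilities are pulled toward 0.5); T ≈ 1 means the model was already well calibrated.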
For each prevalence ratio (1:200, 1:500, 1:1000):
- Compute precision-recall curve on Val-Realistic
- Find threshold achieving target precision (e.g., Precision ≥ 0.8)
- Record corresponding recall
- Document threshold and expected performance
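The threshold search can be sketched as a sweep over candidate thresholds, returning the lowest one meeting the precision target (to maximize recall subject to the constraint; the helper name is hypothetical):

```python
import numpy as np


def pick_threshold(probs, targets, target_precision=0.8):
    """Lowest threshold whose precision meets the target; returns (thr, prec, rec)."""
    for thr in np.linspace(0.05, 0.95, 91):
        pred = probs >= thr
        tp = np.sum(pred & (targets == 1))
        fp = np.sum(pred & (targets == 0))
        fn = np.sum(~pred & (targets == 1))
        if tp + fp == 0:
            continue  # no predictions at this threshold
        prec = tp / (tp + fp)
        if prec >= target_precision:
            return float(thr), float(prec), float(tp / (tp + fn))
    return None  # target precision unattainable on this set
```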
Single-run results are noisy. For final model and key comparisons:
- Run with seeds [42, 43, 44]
- Report mean ± standard deviation
- Example format: IoU_RTS: 0.723 ± 0.012
Final results table should include:
| Metric | 1:200 | 1:500 | 1:1000 |
|---|---|---|---|
| IoU_RTS | X.XX ± X.XX | X.XX ± X.XX | X.XX ± X.XX |
| PR-AUC | X.XX ± X.XX | X.XX ± X.XX | X.XX ± X.XX |