Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 82 additions & 102 deletions src/ndt/analysis/activation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
- Geometry analysis (effective dimensionality, singular value distributions)
"""

from typing import Dict, List, Optional, Tuple, Any
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, DBSCAN
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score

try:
import umap

HAS_UMAP = True
except ImportError:
HAS_UMAP = False
Expand Down Expand Up @@ -68,7 +68,7 @@ def pca_analysis(
self,
activations: np.ndarray,
n_components: Optional[int] = None,
return_transformed: bool = True
return_transformed: bool = True,
) -> Dict[str, Any]:
"""Perform PCA analysis on activations.

Expand Down Expand Up @@ -101,27 +101,23 @@ def pca_analysis(
n_95 = np.searchsorted(cumulative_var, 0.95) + 1

results = {
'explained_variance_ratio': pca.explained_variance_ratio_,
'cumulative_variance': cumulative_var,
'components': pca.components_,
'singular_values': pca.singular_values_,
'n_components_90': min(n_90, n_components),
'n_components_95': min(n_95, n_components),
'mean': pca.mean_,
'pca_object': pca
"explained_variance_ratio": pca.explained_variance_ratio_,
"cumulative_variance": cumulative_var,
"components": pca.components_,
"singular_values": pca.singular_values_,
"n_components_90": min(n_90, n_components),
"n_components_95": min(n_95, n_components),
"mean": pca.mean_,
"pca_object": pca,
}

if return_transformed:
results['transformed'] = transformed
results["transformed"] = transformed

return results

def cluster_analysis(
    self, activations: np.ndarray, method: str = "kmeans", n_clusters: int = 5, **kwargs
) -> Dict[str, Any]:
    """Perform clustering analysis on activations.

    Args:
        activations: Activation matrix handed directly to the clusterer's
            ``fit_predict`` (presumably shaped (n_samples, n_features) --
            confirm against callers).
        method: Clustering algorithm, either 'kmeans' or 'dbscan'.
        n_clusters: Number of clusters to fit (used by 'kmeans' only).
        **kwargs: For 'kmeans', forwarded to ``KMeans``; a caller-supplied
            ``random_state`` overrides the default seed of 42. For 'dbscan',
            only ``eps`` (default 0.5) and ``min_samples`` (default 5) are read.

    Returns:
        Dictionary containing:
            - 'labels': Cluster label for each sample
            - 'n_clusters': Requested (kmeans) or discovered (dbscan) cluster count
            - 'cluster_centers': Cluster centroids (kmeans only)
            - 'inertia': Within-cluster sum of squares (kmeans only)
            - 'n_noise': Number of samples labelled -1 (dbscan only)
            - 'silhouette_score': Present only when the score is defined,
              i.e. the non-noise samples form at least 2 clusters and fewer
              clusters than samples

    Raises:
        ValueError: If ``method`` is not a recognised clustering method.
    """
    if method == "kmeans":
        # Seed by default for reproducibility, but let the caller override it;
        # passing random_state both explicitly and via **kwargs would raise
        # "got multiple values for keyword argument".
        kmeans_params = {"random_state": 42, **kwargs}
        clusterer = KMeans(n_clusters=n_clusters, **kmeans_params)
        labels = clusterer.fit_predict(activations)

        results = {
            "labels": labels,
            "n_clusters": n_clusters,
            "cluster_centers": clusterer.cluster_centers_,
            "inertia": clusterer.inertia_,
        }

    elif method == "dbscan":
        eps = kwargs.get("eps", 0.5)
        min_samples = kwargs.get("min_samples", 5)
        clusterer = DBSCAN(eps=eps, min_samples=min_samples)
        labels = clusterer.fit_predict(activations)

        # DBSCAN marks noise with the label -1; it is not a real cluster.
        n_found = len(set(labels)) - (1 if -1 in labels else 0)
        results = {"labels": labels, "n_clusters": n_found, "n_noise": np.sum(labels == -1)}
    else:
        raise ValueError(f"Unknown clustering method: {method}")

    # Score only the non-noise samples.
    if -1 in set(labels):
        keep = labels != -1
        scored_points, scored_labels = activations[keep], labels[keep]
    else:
        scored_points, scored_labels = activations, labels

    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1;
    # outside that range it raises, so the key is simply omitted.
    n_scored_clusters = len(set(scored_labels))
    if 2 <= n_scored_clusters <= len(scored_labels) - 1:
        results["silhouette_score"] = silhouette_score(scored_points, scored_labels)

    return results

def manifold_embedding(
    self, activations: np.ndarray, method: str = "tsne", n_components: int = 2, **kwargs
) -> Dict[str, Any]:
    """Compute low-dimensional manifold embedding.

    Args:
        activations: Activation matrix; ``activations.shape[0]`` is used as
            the sample count when choosing default neighbourhood sizes.
        method: Embedding algorithm, either 'tsne' or 'umap'.
        n_components: Dimensionality of the embedding.
        **kwargs: Forwarded to ``TSNE`` / ``umap.UMAP``. ``perplexity``
            (t-SNE) defaults to ``min(30, n_samples - 1)`` and
            ``n_neighbors`` (UMAP) to ``min(15, n_samples - 1)``. A
            caller-supplied ``random_state`` overrides the default seed of 42.

    Returns:
        Dictionary containing:
            - 'embedding': Output of the reducer's ``fit_transform``
            - 'method': Method used
            - 'params': Parameters used (neighbourhood size and n_components)

    Raises:
        ImportError: If ``method`` is 'umap' and umap-learn is not installed.
        ValueError: If ``method`` is not a recognised embedding method.
    """
    n_samples = activations.shape[0]

    if method == "tsne":
        perplexity = kwargs.get("perplexity", min(30, n_samples - 1))
        # Build one keyword dict so a caller-supplied random_state replaces the
        # default seed instead of colliding with it (duplicate-keyword TypeError).
        tsne_params = {"random_state": 42, **kwargs, "perplexity": perplexity}
        tsne = TSNE(n_components=n_components, **tsne_params)
        embedding = tsne.fit_transform(activations)

        results = {
            "embedding": embedding,
            "method": "tsne",
            "params": {"perplexity": perplexity, "n_components": n_components},
        }

    elif method == "umap":
        if not HAS_UMAP:
            raise ImportError("UMAP not installed. Install with: pip install umap-learn")

        n_neighbors = kwargs.get("n_neighbors", min(15, n_samples - 1))
        # Same merge as above: default seed, overridable by the caller.
        umap_params = {"random_state": 42, **kwargs, "n_neighbors": n_neighbors}
        reducer = umap.UMAP(n_components=n_components, **umap_params)
        embedding = reducer.fit_transform(activations)

        results = {
            "embedding": embedding,
            "method": "umap",
            "params": {"n_neighbors": n_neighbors, "n_components": n_components},
        }
    else:
        raise ValueError(f"Unknown embedding method: {method}")

    return results

def singular_value_analysis(
self,
activations: np.ndarray
) -> Dict[str, Any]:
def singular_value_analysis(self, activations: np.ndarray) -> Dict[str, Any]:
"""Analyze singular value distribution of activations.

Computes detailed statistics about the singular value spectrum,
Expand All @@ -271,10 +254,10 @@ def singular_value_analysis(
S_norm = S / S.sum()

# Stable rank
stable_rank = (S ** 2).sum() / (S[0] ** 2) if S[0] > 0 else 0
stable_rank = (S**2).sum() / (S[0] ** 2) if S[0] > 0 else 0

# Participation ratio
participation_ratio = 1.0 / (S_norm ** 2).sum() if S_norm.sum() > 0 else 0
participation_ratio = 1.0 / (S_norm**2).sum() if S_norm.sum() > 0 else 0

# Spectral entropy
S_norm_nonzero = S_norm[S_norm > 0]
Expand All @@ -284,18 +267,16 @@ def singular_value_analysis(
condition_number = S[0] / S[-1] if S[-1] > 0 else np.inf

return {
'singular_values': S,
'normalized_sv': S_norm,
'stable_rank': stable_rank,
'participation_ratio': participation_ratio,
'spectral_entropy': entropy,
'condition_number': condition_number
"singular_values": S,
"normalized_sv": S_norm,
"stable_rank": stable_rank,
"participation_ratio": participation_ratio,
"spectral_entropy": entropy,
"condition_number": condition_number,
}

def compare_activations(
self,
activations_before: np.ndarray,
activations_after: np.ndarray
self, activations_before: np.ndarray, activations_after: np.ndarray
) -> Dict[str, Any]:
"""Compare activations before and after a critical moment (e.g., jump).

Expand Down Expand Up @@ -324,15 +305,19 @@ def compare_activations(

# Dimensionality change
dim_change = {
'stable_rank': sv_after['stable_rank'] - sv_before['stable_rank'],
'participation_ratio': sv_after['participation_ratio'] - sv_before['participation_ratio'],
'n_components_90': pca_after['n_components_90'] - pca_before['n_components_90']
"stable_rank": sv_after["stable_rank"] - sv_before["stable_rank"],
"participation_ratio": (
sv_after["participation_ratio"] - sv_before["participation_ratio"]
),
"n_components_90": (pca_after["n_components_90"] - pca_before["n_components_90"]),
}

# Subspace overlap (using principal angles)
n_components = min(10, min(pca_before['components'].shape[0], pca_after['components'].shape[0]))
V1 = pca_before['components'][:n_components].T
V2 = pca_after['components'][:n_components].T
n_components = min(
10, min(pca_before["components"].shape[0], pca_after["components"].shape[0])
)
V1 = pca_before["components"][:n_components].T
V2 = pca_after["components"][:n_components].T

# Compute principal angles via SVD
M = V1.T @ V2
Expand All @@ -341,19 +326,17 @@ def compare_activations(
subspace_overlap = np.cos(principal_angles).mean()

return {
'pca_before': pca_before,
'pca_after': pca_after,
'sv_before': sv_before,
'sv_after': sv_after,
'dim_change': dim_change,
'subspace_overlap': subspace_overlap,
'principal_angles': principal_angles
"pca_before": pca_before,
"pca_after": pca_after,
"sv_before": sv_before,
"sv_after": sv_after,
"dim_change": dim_change,
"subspace_overlap": subspace_overlap,
"principal_angles": principal_angles,
}

def neuron_importance(
self,
activations: np.ndarray,
method: str = 'variance'
self, activations: np.ndarray, method: str = "variance"
) -> Dict[str, Any]:
"""Compute importance scores for individual neurons.

Expand All @@ -370,11 +353,11 @@ def neuron_importance(
- 'dead_neurons': Indices of neurons with zero activation
- 'top_10_indices': Top 10 most important neurons
"""
if method == 'variance':
if method == "variance":
scores = np.var(activations, axis=0)
elif method == 'mean_abs':
elif method == "mean_abs":
scores = np.mean(np.abs(activations), axis=0)
elif method == 'sparsity':
elif method == "sparsity":
# Lower sparsity = higher importance
scores = 1.0 - np.mean(activations == 0, axis=0)
else:
Expand All @@ -384,18 +367,15 @@ def neuron_importance(
dead_neurons = np.where(np.all(activations == 0, axis=0))[0]

return {
'scores': scores,
'ranking': ranking,
'dead_neurons': dead_neurons,
'n_dead': len(dead_neurons),
'top_10_indices': ranking[:10],
'method': method
"scores": scores,
"ranking": ranking,
"dead_neurons": dead_neurons,
"n_dead": len(dead_neurons),
"top_10_indices": ranking[:10],
"method": method,
}

def activation_statistics(self, activations: np.ndarray) -> Dict[str, float]:
    """Compute summary statistics for activations.

    Args:
        activations: 2-D array (or array-like) of activations. Its first
            axis is reported as ``n_samples`` and its second as ``n_features``.

    Returns:
        Dictionary of statistics:
            - 'mean', 'std', 'min', 'max': Computed over all entries
            - 'sparsity': Fraction of entries exactly equal to zero
            - 'n_samples', 'n_features': The two dimensions of the input

    Raises:
        ValueError: If ``activations`` is not 2-D or contains no elements.
    """
    data = np.asarray(activations)

    # Validate up front: otherwise a 1-D input fails with a bare IndexError
    # only after every reduction has already been computed, and an empty
    # input fails inside NumPy with an opaque reduction error.
    if data.ndim != 2:
        raise ValueError(
            f"activations must be 2-D (n_samples, n_features), got {data.ndim}-D input"
        )
    if data.size == 0:
        raise ValueError("activations must contain at least one element")

    n_samples, n_features = data.shape
    return {
        "mean": float(np.mean(data)),
        "std": float(np.std(data)),
        "min": float(np.min(data)),
        "max": float(np.max(data)),
        "sparsity": float(np.mean(data == 0)),
        "n_samples": n_samples,
        "n_features": n_features,
    }
Loading
Loading