Skip to content

Commit 14209cf

Browse files
committed
Update stopping criteria and progress bar
1 parent 6f16fd2 commit 14209cf

1 file changed

Lines changed: 105 additions & 49 deletions

File tree

src/scdef/models/_scdef.py

Lines changed: 105 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def _get_layer_init(init_value, layer_idx):
685685
self.local_params = [self.local_params[0]]
686686
self.global_params = [self.global_params[0]]
687687
# BRD
688-
m_brd = tfd.Gamma(100.0, 100.0 / self.brd_mean).sample(
688+
m_brd = tfd.Gamma(100.0, 100.0 / self.brd_mean * 0.1).sample(
689689
seed=rngs[0],
690690
sample_shape=[self.layer_sizes[0], 1],
691691
) # self.brd_mean
@@ -1485,13 +1485,20 @@ def fit(
14851485
self._has_fit = True
14861486
self._fit_revision += 1
14871487

1488-
def _compute_median_parents(self, local_params=None, global_params=None):
1488+
def _compute_median_parents(
1489+
self,
1490+
local_params=None,
1491+
global_params=None,
1492+
return_active_l0_count: bool = False,
1493+
):
14891494
"""Estimate median effective parent count using variational means.
14901495
14911496
Uses only the minimal variational parameters needed for the metric and
14921497
avoids materializing full posterior summaries.
14931498
"""
14941499
if self.n_layers < 2:
1500+
if return_active_l0_count:
1501+
return 1.0, 0
14951502
return 1.0
14961503

14971504
if local_params is None:
@@ -1524,6 +1531,8 @@ def _compute_median_parents(self, local_params=None, global_params=None):
15241531
active_1 = np.where(counts >= 1)[0]
15251532

15261533
if len(active_0) == 0 or len(active_1) == 0:
1534+
if return_active_l0_count:
1535+
return 1.0, int(len(active_0))
15271536
return 1.0
15281537

15291538
# L1 -> L0 weight posterior means
@@ -1535,7 +1544,10 @@ def _compute_median_parents(self, local_params=None, global_params=None):
15351544
H = -np.sum(p * np.log(p + 1e-12), axis=0)
15361545
n_parents = np.exp(H)
15371546

1538-
return np.median(n_parents)
1547+
median_parents = np.median(n_parents)
1548+
if return_active_l0_count:
1549+
return median_parents, int(len(active_0))
1550+
return median_parents
15391551

15401552
def _learn(
15411553
self,
@@ -1563,7 +1575,7 @@ def _learn(
15631575
target_parents=1.5,
15641576
max_elbo_drop=0.05,
15651577
alpha_max=None,
1566-
damping=0.5,
1578+
damping=1.0,
15671579
**kwargs,
15681580
):
15691581
if "n_epochs" in kwargs:
@@ -1771,7 +1783,10 @@ def data_stream():
17711783
alpha_trace_epochs = []
17721784
n_eff_parents_trace = []
17731785
trace_epoch = []
1786+
active_l0_factor_counts_trace = []
17741787
stop_message = None
1788+
interrupted = False
1789+
alpha_updates_enabled = True
17751790
if anneal_alpha and int(check_every) <= 0:
17761791
raise ValueError("check_every must be > 0 when anneal_alpha=True.")
17771792

@@ -1822,19 +1837,19 @@ def data_stream():
18221837
alpha_trace.append(float(self.alpha))
18231838
alpha_trace_epochs.append(int(total_epochs))
18241839

1825-
# --- Standard _learn behavior (no alpha annealing) ---
1826-
if not anneal_alpha and epochs_since_alpha_change >= min_epochs:
1840+
relative_improvement = np.nan
1841+
if epochs_since_alpha_change >= min_epochs:
18271842
if min_loss == np.inf:
18281843
min_loss = current_loss
1829-
18301844
relative_improvement = (min_loss - current_loss) / np.abs(min_loss)
18311845
min_loss = min(min_loss, current_loss)
1832-
18331846
if relative_improvement < tolerance:
18341847
early_stop_counter += 1
18351848
else:
18361849
early_stop_counter = 0
18371850

1851+
# --- Standard _learn behavior (no alpha annealing) ---
1852+
if not anneal_alpha and epochs_since_alpha_change >= min_epochs:
18381853
pbar.set_postfix(
18391854
{
18401855
"Loss": current_loss,
@@ -1849,58 +1864,89 @@ def data_stream():
18491864
break
18501865
# --- Alpha annealing mode: check every `check_every` epochs ---
18511866
elif anneal_alpha:
1852-
postfix = {"Loss": current_loss, "alpha": float(self.alpha)}
1867+
postfix = {"Loss": current_loss}
1868+
if epochs_since_alpha_change >= min_epochs:
1869+
postfix["Rel. impr."] = relative_improvement
1870+
postfix["alpha"] = float(self.alpha)
18531871
if len(n_eff_parents_trace) > 0:
18541872
postfix["n_eff"] = float(n_eff_parents_trace[-1])
1873+
if len(active_l0_factor_counts_trace) > 0:
1874+
postfix["active_L0"] = int(active_l0_factor_counts_trace[-1])
18551875
pbar.set_postfix(postfix)
1876+
1877+
if (
1878+
(not alpha_updates_enabled)
1879+
and epochs_since_alpha_change >= min_epochs
1880+
and early_stop_counter >= patience
1881+
):
1882+
stop_message = (
1883+
f"Converged at epoch {total_epochs}, alpha={self.alpha:.2f}"
1884+
)
1885+
break
1886+
18561887
if (total_epochs % int(check_every)) == 0:
1857-
median_parents = self._compute_median_parents(
1888+
median_parents, active_l0_count = self._compute_median_parents(
18581889
local_params=local_params,
18591890
global_params=global_params,
1891+
return_active_l0_count=True,
18601892
)
18611893
trace_epoch.append(int(total_epochs))
18621894
n_eff_parents_trace.append(float(median_parents))
1863-
if median_parents <= target_parents:
1864-
stop_message = (
1865-
"Stopping annealed learning: target reached at epoch "
1866-
f"{total_epochs} (n_eff_parents={float(median_parents):.3f} "
1867-
f"<= {float(target_parents):.3f}, alpha={float(self.alpha):.4f})."
1868-
)
1869-
break
1870-
1871-
if best_elbo is None:
1872-
best_elbo = current_loss
1873-
relative_drop = (current_loss - best_elbo) / abs(best_elbo)
1874-
if relative_drop > max_elbo_drop:
1875-
stop_message = (
1876-
"Stopping annealed learning: ELBO drop exceeded threshold "
1877-
f"at epoch {total_epochs} (drop={float(relative_drop):.4f} "
1878-
f"> {float(max_elbo_drop):.4f}, alpha={float(self.alpha):.4f})."
1879-
)
1880-
break
1881-
best_elbo = min(best_elbo, current_loss)
1882-
1883-
alpha_mult = jnp.asarray(
1884-
(median_parents / target_parents) ** damping,
1885-
dtype=alpha_jnp.dtype,
1895+
active_l0_factor_counts_trace.append(int(active_l0_count))
1896+
pbar.set_postfix(
1897+
{
1898+
"Loss": current_loss,
1899+
"Rel. impr.": relative_improvement,
1900+
"alpha": float(self.alpha),
1901+
"n_eff": float(median_parents),
1902+
"active_L0": int(active_l0_count),
1903+
}
18861904
)
1887-
alpha_jnp = alpha_jnp * alpha_mult
1888-
if alpha_max is not None and float(alpha_jnp) >= float(
1889-
alpha_max
1890-
):
1891-
alpha_jnp = jnp.minimum(
1892-
alpha_jnp,
1893-
jnp.asarray(alpha_max, dtype=alpha_jnp.dtype),
1905+
if alpha_updates_enabled and median_parents <= target_parents:
1906+
alpha_updates_enabled = False
1907+
# Continue optimization with fixed alpha, but reset
1908+
# optimizer momentum/state at the current variational params.
1909+
local_opt_state = local_optimizer.init(local_params)
1910+
global_opt_state = global_optimizer.init(global_params)
1911+
# Restart convergence tracking for the post-annealing phase.
1912+
epochs_since_alpha_change = 0
1913+
min_loss = np.inf
1914+
early_stop_counter = 0
1915+
1916+
if alpha_updates_enabled:
1917+
if best_elbo is None:
1918+
best_elbo = current_loss
1919+
relative_drop = (current_loss - best_elbo) / abs(best_elbo)
1920+
if relative_drop > max_elbo_drop:
1921+
stop_message = (
1922+
"Stopping annealed learning: ELBO drop exceeded threshold "
1923+
f"at epoch {total_epochs} (drop={float(relative_drop):.4f} "
1924+
f"> {float(max_elbo_drop):.4f}, alpha={float(self.alpha):.4f})."
1925+
)
1926+
break
1927+
best_elbo = min(best_elbo, current_loss)
1928+
1929+
alpha_mult = jnp.asarray(
1930+
(median_parents / target_parents) ** damping,
1931+
dtype=alpha_jnp.dtype,
18941932
)
1895-
self.alpha = float(alpha_jnp)
1896-
stop_message = (
1897-
"Stopping annealed learning: reached alpha_max at epoch "
1898-
f"{total_epochs} (alpha={float(self.alpha):.4f}, "
1899-
f"n_eff_parents={float(median_parents):.3f})."
1900-
)
1901-
break
1933+
alpha_jnp = alpha_jnp * alpha_mult
1934+
if alpha_max is not None and float(alpha_jnp) >= float(
1935+
alpha_max
1936+
):
1937+
alpha_jnp = jnp.minimum(
1938+
alpha_jnp,
1939+
jnp.asarray(alpha_max, dtype=alpha_jnp.dtype),
1940+
)
1941+
self.alpha = float(alpha_jnp)
1942+
stop_message = (
1943+
"Stopping annealed learning: reached alpha_max at epoch "
1944+
f"{total_epochs} (alpha={float(self.alpha):.4f}, "
1945+
f"n_eff_parents={float(median_parents):.3f})."
1946+
)
1947+
break
19021948

1903-
self.alpha = float(alpha_jnp)
1949+
self.alpha = float(alpha_jnp)
19041950
else:
19051951
pbar.set_postfix(
19061952
{
@@ -1909,11 +1955,12 @@ def data_stream():
19091955
)
19101956

19111957
except KeyboardInterrupt:
1958+
interrupted = True
19121959
self.logger.info("Interrupted. Exiting safely...")
19131960
finally:
19141961
pbar.close()
19151962

1916-
if stop_message is None:
1963+
if stop_message is None and not interrupted:
19171964
if anneal_alpha:
19181965
stop_message = (
19191966
"Stopping annealed learning: reached max epochs "
@@ -1941,6 +1988,15 @@ def data_stream():
19411988
self.adata.uns[
19421989
"n_eff_parents_trace_epochs"
19431990
] = self.n_eff_parents_trace_epochs.copy()
1991+
if anneal_alpha:
1992+
self.active_l0_factor_counts_trace = np.asarray(
1993+
active_l0_factor_counts_trace, dtype=int
1994+
)
1995+
self.adata.uns[
1996+
"active_l0_factor_counts_trace"
1997+
] = self.active_l0_factor_counts_trace.copy()
1998+
else:
1999+
self.adata.uns.pop("active_l0_factor_counts_trace", None)
19442000

19452001
self.set_posterior_means()
19462002
self.set_posterior_variances()

0 commit comments

Comments (0)