@@ -685,7 +685,7 @@ def _get_layer_init(init_value, layer_idx):
         self.local_params = [self.local_params[0]]
         self.global_params = [self.global_params[0]]
         # BRD
-        m_brd = tfd.Gamma(100.0, 100.0 / self.brd_mean).sample(
+        m_brd = tfd.Gamma(100.0, 100.0 / self.brd_mean * 0.1).sample(
             seed=rngs[0],
             sample_shape=[self.layer_sizes[0], 1],
         )  # self.brd_mean
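For context on this hunk: `tfd.Gamma(concentration, rate)` has mean `concentration / rate`, so multiplying the rate by `0.1` scales the mean of the sampled BRD initialization up by a factor of 10 relative to `self.brd_mean`. A minimal sketch of the arithmetic (the `brd_mean` value below is hypothetical):

```python
# Sketch only: effect of the rate rescaling on the Gamma mean (mean = concentration / rate).
concentration = 100.0
brd_mean = 2.0  # hypothetical stand-in for self.brd_mean

old_rate = concentration / brd_mean          # 50.0 -> mean = 100 / 50 = 2.0  (= brd_mean)
new_rate = concentration / brd_mean * 0.1    # 5.0  -> mean = 100 / 5  = 20.0 (10 * brd_mean)

print(concentration / old_rate, concentration / new_rate)
```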
@@ -1485,13 +1485,20 @@ def fit(
         self._has_fit = True
         self._fit_revision += 1

-    def _compute_median_parents(self, local_params=None, global_params=None):
+    def _compute_median_parents(
+        self,
+        local_params=None,
+        global_params=None,
+        return_active_l0_count: bool = False,
+    ):
         """Estimate median effective parent count using variational means.

         Uses only the minimal variational parameters needed for the metric and
         avoids materializing full posterior summaries.
         """
         if self.n_layers < 2:
+            if return_active_l0_count:
+                return 1.0, 0
             return 1.0

         if local_params is None:
@@ -1524,6 +1531,8 @@ def _compute_median_parents(self, local_params=None, global_params=None):
         active_1 = np.where(counts >= 1)[0]

         if len(active_0) == 0 or len(active_1) == 0:
+            if return_active_l0_count:
+                return 1.0, int(len(active_0))
             return 1.0

         # L1 -> L0 weight posterior means
@@ -1535,7 +1544,10 @@ def _compute_median_parents(self, local_params=None, global_params=None):
         H = -np.sum(p * np.log(p + 1e-12), axis=0)
         n_parents = np.exp(H)

-        return np.median(n_parents)
+        median_parents = np.median(n_parents)
+        if return_active_l0_count:
+            return median_parents, int(len(active_0))
+        return median_parents

     def _learn(
         self,
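`_compute_median_parents` appears to score each layer-0 factor with an exponentiated-entropy ("perplexity-style") count over a per-factor distribution `p` built from the L1 -> L0 posterior-mean weights (the construction of `p` sits outside the shown context): `exp(H(p))` equals the number of parents when the distribution is uniform and approaches 1 when a single parent dominates. A self-contained sketch with made-up weights:

```python
import numpy as np

# Made-up weights, not model output: 4 candidate L1 parents x 3 L0 factors.
rng = np.random.default_rng(0)
w = np.abs(rng.normal(size=(4, 3)))
p = w / w.sum(axis=0, keepdims=True)          # per-factor (column) distribution
H = -np.sum(p * np.log(p + 1e-12), axis=0)    # entropy per L0 factor
n_parents = np.exp(H)                         # effective parent count per factor
print(np.median(n_parents))                   # 4.0 if uniform, -> 1.0 if one parent dominates
```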
@@ -1563,7 +1575,7 @@ def _learn(
         target_parents=1.5,
         max_elbo_drop=0.05,
         alpha_max=None,
-        damping=0.5,
+        damping=1.0,
         **kwargs,
     ):
         if "n_epochs" in kwargs:
@@ -1771,7 +1783,10 @@ def data_stream():
         alpha_trace_epochs = []
         n_eff_parents_trace = []
         trace_epoch = []
+        active_l0_factor_counts_trace = []
         stop_message = None
+        interrupted = False
+        alpha_updates_enabled = True
         if anneal_alpha and int(check_every) <= 0:
             raise ValueError("check_every must be > 0 when anneal_alpha=True.")

@@ -1822,19 +1837,19 @@ def data_stream():
                     alpha_trace.append(float(self.alpha))
                     alpha_trace_epochs.append(int(total_epochs))

-                    # --- Standard _learn behavior (no alpha annealing) ---
-                    if not anneal_alpha and epochs_since_alpha_change >= min_epochs:
+                    relative_improvement = np.nan
+                    if epochs_since_alpha_change >= min_epochs:
                         if min_loss == np.inf:
                             min_loss = current_loss
-
                         relative_improvement = (min_loss - current_loss) / np.abs(min_loss)
                         min_loss = min(min_loss, current_loss)
-
                         if relative_improvement < tolerance:
                             early_stop_counter += 1
                         else:
                             early_stop_counter = 0

+                    # --- Standard _learn behavior (no alpha annealing) ---
+                    if not anneal_alpha and epochs_since_alpha_change >= min_epochs:
                         pbar.set_postfix(
                             {
                                 "Loss": current_loss,
@@ -1849,58 +1864,89 @@ def data_stream():
                         break
                 # --- Alpha annealing mode: check every `check_every` epochs ---
                 elif anneal_alpha:
-                    postfix = {"Loss": current_loss, "alpha": float(self.alpha)}
+                    postfix = {"Loss": current_loss}
+                    if epochs_since_alpha_change >= min_epochs:
+                        postfix["Rel. impr."] = relative_improvement
+                    postfix["alpha"] = float(self.alpha)
                     if len(n_eff_parents_trace) > 0:
                         postfix["n_eff"] = float(n_eff_parents_trace[-1])
+                    if len(active_l0_factor_counts_trace) > 0:
+                        postfix["active_L0"] = int(active_l0_factor_counts_trace[-1])
                     pbar.set_postfix(postfix)
+
+                    if (
+                        (not alpha_updates_enabled)
+                        and epochs_since_alpha_change >= min_epochs
+                        and early_stop_counter >= patience
+                    ):
+                        stop_message = (
+                            f"Converged at epoch {total_epochs}, alpha={self.alpha:.2f}"
+                        )
+                        break
+
                     if (total_epochs % int(check_every)) == 0:
-                        median_parents = self._compute_median_parents(
+                        median_parents, active_l0_count = self._compute_median_parents(
                             local_params=local_params,
                             global_params=global_params,
+                            return_active_l0_count=True,
                         )
                         trace_epoch.append(int(total_epochs))
                         n_eff_parents_trace.append(float(median_parents))
-                        if median_parents <= target_parents:
-                            stop_message = (
-                                "Stopping annealed learning: target reached at epoch "
-                                f"{total_epochs} (n_eff_parents={float(median_parents):.3f} "
-                                f"<= {float(target_parents):.3f}, alpha={float(self.alpha):.4f})."
-                            )
-                            break
-
-                        if best_elbo is None:
-                            best_elbo = current_loss
-                        relative_drop = (current_loss - best_elbo) / abs(best_elbo)
-                        if relative_drop > max_elbo_drop:
-                            stop_message = (
-                                "Stopping annealed learning: ELBO drop exceeded threshold "
-                                f"at epoch {total_epochs} (drop={float(relative_drop):.4f} "
-                                f"> {float(max_elbo_drop):.4f}, alpha={float(self.alpha):.4f})."
-                            )
-                            break
-                        best_elbo = min(best_elbo, current_loss)
-
-                        alpha_mult = jnp.asarray(
-                            (median_parents / target_parents) ** damping,
-                            dtype=alpha_jnp.dtype,
+                        active_l0_factor_counts_trace.append(int(active_l0_count))
+                        pbar.set_postfix(
+                            {
+                                "Loss": current_loss,
+                                "Rel. impr.": relative_improvement,
+                                "alpha": float(self.alpha),
+                                "n_eff": float(median_parents),
+                                "active_L0": int(active_l0_count),
+                            }
                         )
-                        alpha_jnp = alpha_jnp * alpha_mult
-                        if alpha_max is not None and float(alpha_jnp) >= float(
-                            alpha_max
-                        ):
-                            alpha_jnp = jnp.minimum(
-                                alpha_jnp,
-                                jnp.asarray(alpha_max, dtype=alpha_jnp.dtype),
+                        if alpha_updates_enabled and median_parents <= target_parents:
+                            alpha_updates_enabled = False
+                            # Continue optimization with fixed alpha, but reset
+                            # optimizer momentum/state at the current variational params.
+                            local_opt_state = local_optimizer.init(local_params)
+                            global_opt_state = global_optimizer.init(global_params)
+                            # Restart convergence tracking for the post-annealing phase.
+                            epochs_since_alpha_change = 0
+                            min_loss = np.inf
+                            early_stop_counter = 0
+
+                        if alpha_updates_enabled:
+                            if best_elbo is None:
+                                best_elbo = current_loss
+                            relative_drop = (current_loss - best_elbo) / abs(best_elbo)
+                            if relative_drop > max_elbo_drop:
+                                stop_message = (
+                                    "Stopping annealed learning: ELBO drop exceeded threshold "
+                                    f"at epoch {total_epochs} (drop={float(relative_drop):.4f} "
+                                    f"> {float(max_elbo_drop):.4f}, alpha={float(self.alpha):.4f})."
+                                )
+                                break
+                            best_elbo = min(best_elbo, current_loss)
+
+                            alpha_mult = jnp.asarray(
+                                (median_parents / target_parents) ** damping,
+                                dtype=alpha_jnp.dtype,
                             )
-                            self.alpha = float(alpha_jnp)
-                            stop_message = (
-                                "Stopping annealed learning: reached alpha_max at epoch "
-                                f"{total_epochs} (alpha={float(self.alpha):.4f}, "
-                                f"n_eff_parents={float(median_parents):.3f})."
-                            )
-                            break
+                            alpha_jnp = alpha_jnp * alpha_mult
+                            if alpha_max is not None and float(alpha_jnp) >= float(
+                                alpha_max
+                            ):
+                                alpha_jnp = jnp.minimum(
+                                    alpha_jnp,
+                                    jnp.asarray(alpha_max, dtype=alpha_jnp.dtype),
+                                )
+                                self.alpha = float(alpha_jnp)
+                                stop_message = (
+                                    "Stopping annealed learning: reached alpha_max at epoch "
+                                    f"{total_epochs} (alpha={float(self.alpha):.4f}, "
+                                    f"n_eff_parents={float(median_parents):.3f})."
+                                )
+                                break

-                        self.alpha = float(alpha_jnp)
+                            self.alpha = float(alpha_jnp)
                 else:
                     pbar.set_postfix(
                         {
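The multiplicative annealing step is preserved: at each `check_every` checkpoint, alpha is scaled by `(median_parents / target_parents) ** damping` and capped at `alpha_max` (with the new default `damping=1.0`, the full ratio is applied rather than its square root). The behavioral change is what happens when the target is reached: instead of stopping, alpha updates are disabled, the optimizer states are re-initialized at the current variational parameters, and training continues under the ordinary patience-based convergence check. A standalone sketch of just the schedule, with hypothetical checkpoint values:

```python
# Sketch of the multiplicative alpha schedule (hypothetical numbers, not model output).
def next_alpha(alpha, median_parents, target_parents=1.5, damping=1.0, alpha_max=None):
    alpha *= (median_parents / target_parents) ** damping
    return min(alpha, alpha_max) if alpha_max is not None else alpha

alpha = 1.0
for n_eff in [6.0, 4.0, 2.5, 1.8]:  # effective parent counts at successive checkpoints
    alpha = next_alpha(alpha, n_eff, damping=1.0, alpha_max=100.0)
    print(round(alpha, 2))  # 4.0, 10.67, 17.78, 21.33 -- grows while n_eff > target
```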
@@ -1909,11 +1955,12 @@ def data_stream():
                     )

         except KeyboardInterrupt:
+            interrupted = True
             self.logger.info("Interrupted. Exiting safely...")
         finally:
             pbar.close()

-        if stop_message is None:
+        if stop_message is None and not interrupted:
             if anneal_alpha:
                 stop_message = (
                     "Stopping annealed learning: reached max epochs "
@@ -1941,6 +1988,15 @@ def data_stream():
         self.adata.uns[
             "n_eff_parents_trace_epochs"
         ] = self.n_eff_parents_trace_epochs.copy()
+        if anneal_alpha:
+            self.active_l0_factor_counts_trace = np.asarray(
+                active_l0_factor_counts_trace, dtype=int
+            )
+            self.adata.uns[
+                "active_l0_factor_counts_trace"
+            ] = self.active_l0_factor_counts_trace.copy()
+        else:
+            self.adata.uns.pop("active_l0_factor_counts_trace", None)

         self.set_posterior_means()
         self.set_posterior_variances()
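With `anneal_alpha=True`, the loop now stores the per-checkpoint count of active layer-0 factors both as `self.active_l0_factor_counts_trace` and in `adata.uns`; it is appended at the same checkpoints as `n_eff_parents_trace_epochs`, so the two should align. A usage sketch (assumes a fitted instance named `model`; the plotting is illustrative only):

```python
import matplotlib.pyplot as plt

# `model` is assumed to be a fitted instance trained with anneal_alpha=True.
epochs = model.adata.uns["n_eff_parents_trace_epochs"]
active_l0 = model.adata.uns["active_l0_factor_counts_trace"]

plt.step(epochs, active_l0, where="post")
plt.xlabel("epoch")
plt.ylabel("active L0 factors")
plt.show()
```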