
[MRG] Optimization enhancements: CMA-ES algorithm and correlation loss function#1221

Open
ntolley wants to merge 21 commits into jonescompneurolab:master from ntolley:cma

Conversation

ntolley (Collaborator) commented Jan 25, 2026

This PR makes a lot of changes to implement new optimization functions, namely a CMA-ES solver (solver="cma") and a correlation loss function (obj_fun="dipole_corr").

Rather than trying to merge this PR in one go, I think it should be broken into smaller tasks, as it required a decent amount of changes all over the code base. However, there is value in having a fully functioning branch to refer back to and play around with.

ntolley commented Jan 25, 2026

Here is some testing code to try it out; you will need to pip install cma for the code to run. I'm still working on improving the default optimization parameters and fixing random seeds to make everything reproducible.

With more intelligently chosen constraints you can do better, but honestly it's nice to see how well this approach does under such extreme conditions.

NOTE: This is a simulation-hungry algorithm; it works well because it runs many simulations in a batch on every epoch (100 by default). I'd recommend running this on an HPC with access to many cores (64 cores in my case).

import numpy as np
import matplotlib.pyplot as plt
from hnn_core import jones_2009_model, simulate_dipole, JoblibBackend
from hnn_core.network_models import add_erp_drives_to_jones_model
from hnn_core.optimization import Optimizer, add_opt_drives, set_params_opt_drives

tstop = 200
dt = 0.5  # Just used for testing, 0.025 is the default but will be slower
scaling_factor = 3000
smooth_win = 30

# Create target dipole
net_target = jones_2009_model()
add_erp_drives_to_jones_model(net_target)
target_dpl = simulate_dipole(net_target, tstop=tstop, dt=dt, verbose=False)
target_dpl = target_dpl[0].copy().smooth(smooth_win).scale(scaling_factor)

# Create base network with drives to be optimized
net_base = jones_2009_model()
net_base._verbose = False
constraints, initial_params = add_opt_drives(net_base, n_prox=2, n_dist=1)

# Run optimization
max_iter = 100
optim = Optimizer(net_base, tstop=tstop, constraints=constraints, solver='cma',
                  set_params=set_params_opt_drives, initial_params=initial_params,
                  max_iter=max_iter, obj_fun="dipole_corr")

popsize = 500  # number of simulations per epoch, bigger is often better but very expensive!
optim.fit(target=target_dpl, n_trials=1, scale_factor=scaling_factor,
          smooth_window_len=smooth_win, dt=dt, popsize=popsize)

# Simulate best parameters
with JoblibBackend(n_jobs=10):
    dpl_opt = simulate_dipole(optim.net_, tstop=tstop, dt=dt, n_trials=10)

opt_data = np.stack([dpl.copy().scale(scaling_factor).smooth(smooth_win).data['agg'] for dpl in dpl_opt])

target_std = np.std(target_dpl.data['agg'])
opt_std = np.std(opt_data)

opt_scaling = target_std / opt_std
opt_data *= opt_scaling

# Get best results
fontsize = 14
ticksize = 10
plt.figure(figsize=(8, 3))
plt.subplot(1,2,1)
plt.plot(target_dpl.data['agg'] , color='k', label='Target')
plt.plot(opt_data.T, color='C2', linewidth=1, alpha=0.4)
plt.plot(np.mean(opt_data, axis=0), color='C2', label='Optimized')

plt.xlabel('Time (ms)', fontsize=fontsize)
plt.ylabel('Dipole (nAm)', fontsize=fontsize)
plt.xticks(fontsize=ticksize)
plt.yticks(fontsize=ticksize)
plt.legend(fontsize=10)

plt.subplot(1,2,2)
plt.plot(1 - np.array(optim.obj_))
plt.xlabel('Epochs', fontsize=fontsize)
plt.ylabel('Correlation', fontsize=fontsize)
plt.xticks(fontsize=ticksize)
plt.yticks(fontsize=ticksize)
plt.tight_layout()
plt.ylim(None, 1.01)
[figure: target vs. optimized dipole traces (left) and correlation across optimization epochs (right)]

# "The current Network instance has external "
# + "drives, provide a Network object with no "
# + "external drives."
# )
A reviewer (Collaborator) commented:
but aren't you then just adding on new drives with each iteration?

ntolley (Author) replied:
In this instance, I found that it is easier to write a set_params function that directly modifies the properties of existing drives. Technically I wrote a workaround, but in any case I don't necessarily see why we should force users to write a function that either adds brand-new drives every time or modifies existing ones.
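For illustration, a set_params function along those lines might look like the sketch below. The drive name 'evprox1' and the dict layout are hypothetical stand-ins rather than the exact hnn-core internals, and a toy class replaces the real Network so the sketch runs standalone:

```python
# Hypothetical sketch: a set_params function that mutates properties of
# drives already attached to the network, instead of appending new drives
# on every optimizer iteration.
def set_params_existing_drives(net, params):
    """Update the timing of an already-attached proximal drive in place."""
    dyn = net.external_drives['evprox1']['dynamics']
    dyn['mu'] = params['evprox1_mu']
    # clamp sigma so optimizer samples near the lower bound stay valid
    dyn['sigma'] = max(0.01, params['evprox1_sigma'])


# Toy stand-in for a Network object so this sketch runs without hnn_core
class FakeNet:
    def __init__(self):
        self.external_drives = {
            'evprox1': {'dynamics': {'mu': 20.0, 'sigma': 3.0}}
        }


net = FakeNet()
set_params_existing_drives(net, {'evprox1_mu': 35.0, 'evprox1_sigma': -0.5})
print(net.external_drives['evprox1']['dynamics'])
# → {'mu': 35.0, 'sigma': 0.01}
```

Because the function only rewrites fields on drives that already exist, repeated optimizer iterations leave the drive count unchanged.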

f"Joblib will run {n_trials} trial(s) in parallel by "
f"distributing trials over {self.n_jobs} jobs."
)
if net._verbose:
ntolley commented:
A batch simulate test is failing that I'm at a loss to explain: it indicates that simulating serially is faster than simulating in parallel.

The only real changes to batch simulate or the parallel backends have to do with the verbose setting.

Is it possible that these if blocks are slowing down parallel execution in a non-trivial way? It doesn't make sense to me why serial would be faster...
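One way to sanity-check that hypothesis (illustrative only, not part of the PR): time the guarded check in isolation. If it costs nanoseconds per call, the if blocks are unlikely to explain a serial-vs-parallel timing flip on their own.

```python
import timeit


class FakeNet:
    """Minimal stand-in with the _verbose flag the new if blocks check."""
    _verbose = False


net = FakeNet()


def guarded():
    # mirrors the `if net._verbose:` pattern added around the log messages
    if net._verbose:
        print("Joblib will run trials in parallel...")


n_calls = 100_000
total = timeit.timeit(guarded, number=n_calls)
print(f"~{total * 1e9 / n_calls:.0f} ns per guarded verbose check")
```

On typical hardware the attribute lookup plus branch is tens of nanoseconds, so any real slowdown would more likely come from how the backend is configured than from the guards themselves.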

"bounds": constraints,
"tolfun": obj_fun_kwargs.get("tolfun", 0.01),
"maxiter": max_iter,
"popsize": obj_fun_kwargs.get("popsize", 100),
ntolley commented:
Maybe the default popsize value should be lower.

# I think that's also why you had to add the max(0.01, ...) for sigma in set_params_opt_drives
# _b_obj_func = cma.BoundDomainTransform(_obj_func, constraints) # evaluates fun only in the bounded domain

sigma = 1 / (np.array(constraints[1]) - np.array(constraints[0]))
ntolley commented:
@katduecker I actually forgot I already implemented a sigma that is scaled by the size of the bounds here!

# I think that's also why you had to add the max(0.01, ...) for sigma in set_params_opt_drives
# _b_obj_func = cma.BoundDomainTransform(_obj_func, constraints) # evaluates fun only in the bounded domain

sigma = 0.25 * (np.array(constraints[1]) - np.array(constraints[0]))
ntolley commented Mar 5, 2026:
@katduecker this is the right way to set it! The one earlier was way too small.

The performance gain is quite impressive; you can use substantially smaller popsizes.
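To make the difference concrete (the bounds below are made-up numbers): the earlier formula shrinks sigma as a parameter's range grows, while the corrected one scales it to a quarter of the range, in line with the common guidance of setting the initial step size to roughly a quarter of the search domain.

```python
import numpy as np

# Made-up bounds for three drive parameters (lower, upper)
lb = np.array([0.0, 10.0, 0.1])
ub = np.array([1.0, 100.0, 5.0])

sigma_old = 1 / (ub - lb)     # earlier formula: shrinks as the range grows
sigma_new = 0.25 * (ub - lb)  # corrected: a quarter of each parameter's range

print(sigma_old)  # the 10..100 parameter gets sigma ≈ 0.011 -- far too small
print(sigma_new)  # the same parameter gets sigma = 22.5
```

With the old formula, wide-ranged parameters received tiny initial steps and the search could barely move; the corrected scaling lets CMA-ES explore each dimension proportionally to its bounds.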

asoplata added a commit to asoplata/hnn-core that referenced this pull request Mar 10, 2026
asoplata added a commit to asoplata/hnn-core that referenced this pull request Mar 10, 2026
This also removes ALL tests of the gen_opt.py file in order to continue
investigating the failing test in test_batch_simulate.py
asoplata added a commit to asoplata/hnn-core that referenced this pull request Mar 10, 2026
@ntolley
Copy link
Collaborator Author

ntolley commented Mar 10, 2026

@asoplata it was indeed the high n_jobs that was causing the issues!

asoplata (Collaborator) commented:
Awesome, congrats! Yes, I did my own secondary run of the same code here https://github.com/asoplata/hnn-core/actions/runs/22916219851 to test stochasticity, and it looks like you solved it! 😌 Congrats!

return opt_params, obj, net_


def add_opt_drives(net, tstop=200, n_prox=2, n_dist=1):
ntolley commented:
@asoplata @katduecker what do you think about removing these functions from the PR and including them in a textbook example in a different PR?

They're mostly convenience functions, but as @katduecker brought up below, they need to be documented better.

@ntolley ntolley changed the title WIP: Optimization enhancements: CMA-ES algorithm and correlation loss function [MRG] Optimization enhancements: CMA-ES algorithm and correlation loss function Mar 16, 2026
@ntolley ntolley requested review from asoplata and katduecker March 16, 2026 12:02
@ntolley
Copy link
Collaborator Author

ntolley commented Mar 16, 2026

@katduecker @asoplata @carolinafernandezp this PR is ready to review!

I'm thinking that we can add examples of use in a separate PR (perhaps on the textbook website?). The main goal here is just to make sure that the newly added solver="cma" and obj_fun="dipole_corr" work with the appropriate tests.

Important notes on testing:

  • initial_params serves no role in the "cma" solver, so I've added a warning in case it is passed
  • I'm not sure how to approach testing a user-defined objective function. While def test_user_obj_fun(solver): exists, it mainly verifies that it's possible to pass a function. I think what is actually needed is a function like validate_obj_fun() that tells a user whether their custom objective function works for all solvers (CMA requires the ability to batch simulate).
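A validate_obj_fun() along those lines could be sketched as follows. The name, signature, and checks are hypothetical, not an existing hnn-core API; the key idea is that CMA-ES evaluates a whole population per epoch, so the objective must also accept a batch of parameter vectors:

```python
import numpy as np


def validate_obj_fun(obj_fun, n_params, popsize=4):
    """Hypothetical validator: check a user objective works for all solvers.

    CMA-ES scores a whole population per epoch, so the objective must
    handle a batch of parameter vectors, not just a single one.
    """
    single = np.zeros(n_params)
    batch = np.zeros((popsize, n_params))
    try:
        val = obj_fun(single)
        if np.ndim(val) != 0:
            return "objective must return a scalar for a single vector"
    except Exception as err:
        return f"fails on a single parameter vector: {err}"
    try:
        vals = obj_fun(batch)
        if len(np.atleast_1d(vals)) != popsize:
            return "objective must return one value per population member"
    except Exception as err:
        return f"fails on a batch (required for solver='cma'): {err}"
    return "ok"


# A toy objective that handles both shapes by reducing over the last axis
sphere = lambda x: (np.asarray(x) ** 2).sum(axis=-1)
print(validate_obj_fun(sphere, n_params=3))  # → ok
```

Running this before fit() would let users discover up front that their objective breaks under batching, instead of failing mid-optimization.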

Related to making a future textbook example, it's unclear to me if add_opt_drives() and set_params_opt_drives() should be in the source code, as they are primarily convenience functions. Maybe they belong only in a textbook example?
