Skip to content

Commit 3533027

Browse files
ChrisRackauckas-ClaudeChrisRackauckasclaude
authored
Add ChainRulesCore support for SCCNonlinearProblem (#757)
This enables automatic differentiation through SCCNonlinearProblem solves by adding a ChainRulesCore extension that hooks into SciMLSensitivity's adjoint machinery. Changes: - Add `scc_solve_up` internal function that can be hooked by ChainRulesCore - Route public `solve` through `scc_solve_up` → `_scc_solve` - Add ChainRulesCore weak dependency and extension - Extension defines rrule for `scc_solve_up` that calls `_concrete_solve_adjoint` This is needed to fix zero gradients when differentiating through MTK v10 initialization which uses SCCNonlinearProblem. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: ChrisRackauckas <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent 517270a commit 3533027

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

lib/SCCNonlinearSolve/Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,16 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1010
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1111
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1212

13+
[weakdeps]
14+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
15+
16+
[extensions]
17+
SCCNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
18+
1319
[compat]
1420
Aqua = "0.8"
1521
BenchmarkTools = "1.5.0"
22+
ChainRulesCore = "1"
1623
CommonSolve = "0.2.4"
1724
ExplicitImports = "1.5"
1825
Hwloc = "3"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module SCCNonlinearSolveChainRulesCoreExt
2+
3+
using SCCNonlinearSolve
4+
using SCCNonlinearSolve: SCCAlg, scc_solve_up
5+
using SciMLBase: SCCNonlinearProblem, AbstractSensitivityAlgorithm, ChainRulesOriginator,
6+
_concrete_solve_adjoint
7+
8+
import ChainRulesCore
9+
10+
function ChainRulesCore.rrule(
11+
::typeof(scc_solve_up), prob::SCCNonlinearProblem,
12+
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
13+
u0, p, alg::SCCAlg; kwargs...)
14+
_concrete_solve_adjoint(prob, alg, sensealg, u0, p, ChainRulesOriginator(); kwargs...)
15+
end
16+
17+
end

lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,30 @@ end
2323

2424
SCCAlg(; nlalg = nothing, linalg = nothing) = SCCAlg(nlalg, linalg)
2525

26-
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem; kwargs...)
27-
CommonSolve.solve(prob, SCCAlg(nothing, nothing); kwargs...)
26+
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem;
27+
sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
28+
CommonSolve.solve(prob, SCCAlg(nothing, nothing); sensealg, u0, p, kwargs...)
2829
end
2930

3031
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem,
31-
alg::SciMLBase.AbstractNonlinearAlgorithm; kwargs...)
32-
CommonSolve.solve(prob, SCCAlg(alg, nothing); kwargs...)
32+
alg::SciMLBase.AbstractNonlinearAlgorithm;
33+
sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
34+
CommonSolve.solve(prob, SCCAlg(alg, nothing); sensealg, u0, p, kwargs...)
35+
end
36+
37+
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg;
38+
sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
39+
u0 = u0 !== nothing ? u0 : prob.u0
40+
p = p !== nothing ? p : prob.p
41+
scc_solve_up(prob, sensealg, u0, p, alg; kwargs...)
42+
end
43+
44+
"""
45+
Internal solve function that can be hooked by ChainRulesCore for AD.
46+
"""
47+
function scc_solve_up(prob::SciMLBase.SCCNonlinearProblem, sensealg, u0, p, alg::SCCAlg;
48+
kwargs...)
49+
_scc_solve(prob, alg; kwargs...)
3350
end
3451

3552
probvec(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}) = prob.u0
@@ -60,7 +77,11 @@ function iteratively_build_sols(alg, sols, (prob, explicitfun), args...; kwargs.
6077
iteratively_build_sols(alg, (sols..., _sol), args...; kwargs...)
6178
end
6279

63-
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwargs...)
80+
"""
81+
Internal solve implementation for SCCNonlinearProblem.
82+
This is called by scc_solve_up and should NOT be hooked by ChainRulesCore.
83+
"""
84+
function _scc_solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwargs...)
6485
numscc = length(prob.probs)
6586
sols = iteratively_build_sols(
6687
alg, (), zip(prob.probs, prob.explicitfuns!)...; kwargs...)
@@ -79,4 +100,6 @@ function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwa
79100
SciMLBase.build_solution(prob, alg, u, resid; retcode, original = sols)
80101
end
81102

103+
export scc_solve_up
104+
82105
end

0 commit comments

Comments
 (0)