Skip to content

Commit 879c96d

Browse files
Use SciMLBase.set_mooncakeoriginator_if_mooncake for Mooncake AD support
NonlinearSolveBase was passing `SciMLBase.ChainRulesOriginator()` directly without wrapping it in `set_mooncakeoriginator_if_mooncake`. This caused the `originator` to remain as `ChainRulesOriginator` when using Mooncake AD, instead of being converted to `MooncakeOriginator`. This fix: - Removes the unused local `set_mooncakeoriginator_if_mooncake` definition from utils.jl - Updates solve.jl to wrap originator with `SciMLBase.set_mooncakeoriginator_if_mooncake` This enables proper Mooncake AD support detection in downstream packages like SciMLSensitivity. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 2e44205 commit 879c96d

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing,
9090
p,
9191
args...;
9292
alias = alias_spec,
93-
originator = SciMLBase.ChainRulesOriginator(),
93+
originator = SciMLBase.set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
9494
verbose,
9595
kwargs...))
9696
else
@@ -100,7 +100,7 @@ function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing,
100100
p,
101101
args...;
102102
alias = alias_spec,
103-
originator = SciMLBase.ChainRulesOriginator(),
103+
originator = SciMLBase.set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
104104
verbose,
105105
kwargs...)
106106
end

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,4 @@ function clean_sprint_struct(x, indent::Int)
320320
return "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))"
321321
end
322322

323-
set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x
324-
325323
end

0 commit comments

Comments
 (0)