Skip to content

Commit 4fd127c

Browse files
termi-officialChrisRackauckas
authored andcommitted
Fix some corner cases
1 parent ed5d83d commit 4fd127c

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

lib/NonlinearSolveFirstOrder/src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function SciMLBase.__init(
198198

199199
has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing
200200
has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing
201-
has_forcing = alg.forcing !== missing && alg.forcing !== nothing
201+
has_forcing = alg.forcing !== missing && alg.forcing !== nothing && !(u isa Number) && !(J isa Diagonal)
202202

203203
if has_trustregion && has_linesearch
204204
error("TrustRegion and LineSearch methods are algorithmically incompatible.")
@@ -274,7 +274,7 @@ function InternalAPI.step!(
274274
end
275275
end
276276

277-
has_forcing = cache.forcing_cache !== nothing && cache.forcing_cache !== missing
277+
has_forcing = cache.forcing_cache !== nothing && cache.forcing_cache !== missing && !(cache.u isa Number) && !(J isa Diagonal)
278278

279279
if has_forcing
280280
pre_step_forcing!(cache.forcing_cache, cache.descent_cache, J, cache.u, cache.fu, cache.nsteps)

lib/NonlinearSolveFirstOrder/test/rootfind_tests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ end
99
using BenchmarkTools: @ballocated
1010
using StaticArrays: @SVector
1111

12-
u0s=([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
13-
1412
@testset for (concrete_jac, linsolve) in (
1513
(Val(false), KrylovJL_CG(; precs = nothing)),
1614
(Val(false), KrylovJL_GMRES(; precs = nothing)),
@@ -24,7 +22,7 @@ end
2422
),
2523
),
2624
)
27-
@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
25+
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0])
2826
solver = NewtonRaphson(; forcing=EisenstatWalkerForcing2(), linsolve, concrete_jac)
2927
sol = solve_oop(quadratic_f, u0; solver)
3028
@test SciMLBase.successful_retcode(sol)

0 commit comments

Comments
 (0)