3535"""
3636 VecJac(f, u, [p, t]; fu = nothing, autodiff = AutoFiniteDiff())
3737
38- Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `df/du * v`.
38+ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product
39+ `(df/du)ᵀ * v`.
3940
4041!!! note
4142
@@ -45,11 +46,11 @@ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `
4546```julia
4647L = VecJac(f, u)
4748
48- L * v # = df/du * v
49- mul!(w, L, v) # = df/du * v
49+ L * v # = ( df/du)ᵀ * v
50+ mul!(w, L, v) # = ( df/du)ᵀ * v
5051
51- L(v, p, t; VJP_input = w) # = df/dw * v
52- L(x, v, p, t; VJP_input = w) # = df/dw * v
52+ L(v, p, t; VJP_input = w) # = ( df/du)ᵀ * v
53+ L(x, v, p, t; VJP_input = w) # = ( df/du)ᵀ * v
5354```
5455
5556## Allowed Function Signatures for `f`
@@ -72,7 +73,7 @@ f(du, u) # Otherwise
7273"""
7374function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ; fu = nothing ,
7475 autodiff = AutoFiniteDiff (), kwargs... )
75- ff = VecJacFunctionWrapper (f, fu, u, p, t)
76+ ff = JacFunctionWrapper (f, fu, u, p, t)
7677
7778 if ! __internal_oop (ff) && autodiff isa AutoZygote
7879 msg = " Zygote requires an out of place method with signature f(u)."
@@ -83,82 +84,12 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
8384
8485 op = _vecjac (ff, fu, u, autodiff)
8586
86- # FIXME : FunctionOperator is terribly type unstable. It makes it `::Any`
8787 # NOTE: We pass `p`, `t` to Function Operator but we always use the cached version from
88- # VecJacFunctionWrapper
89- return FunctionOperator (op, fu, u; p, t, isinplace = true , outofplace = true ,
88+ # JacFunctionWrapper
89+ return FunctionOperator (op, fu, u; p, t, isinplace = Val ( true ) , outofplace = Val ( true ) ,
9090 islinear = true , accepted_kwargs = (:VJP_input ,), kwargs... )
9191end
9292
93- mutable struct VecJacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function
94- f:: F
95- fu:: FU
96- p:: P
97- t:: T
98- end
99-
100- function SciMLOperators. update_coefficients! (L:: VecJacFunctionWrapper{iip, oop, mode} , _,
101- p, t) where {iip, oop, mode}
102- mode == 1 && (L. t = t)
103- mode == 2 && (L. p = p)
104- return L
105- end
106- function SciMLOperators. update_coefficients (L:: VecJacFunctionWrapper{iip, oop, mode} , _, p,
107- t) where {iip, oop, mode}
108- return VecJacFunctionWrapper{iip, oop, mode, typeof (L. f), typeof (L. fu), typeof (p),
109- typeof (t)}(L. f, L. fu, p,
110- t)
111- end
112-
113- __internal_iip (:: VecJacFunctionWrapper{iip} ) where {iip} = iip
114- __internal_oop (:: VecJacFunctionWrapper{iip, oop} ) where {iip, oop} = oop
115-
116- (f:: VecJacFunctionWrapper{true, oop, 1} )(fu, u) where {oop} = f. f (fu, u, f. p, f. t)
117- (f:: VecJacFunctionWrapper{true, oop, 2} )(fu, u) where {oop} = f. f (fu, u, f. p)
118- (f:: VecJacFunctionWrapper{true, oop, 3} )(fu, u) where {oop} = f. f (fu, u)
119- (f:: VecJacFunctionWrapper{true, true, 1} )(u) = f. f (u, f. p, f. t)
120- (f:: VecJacFunctionWrapper{true, true, 2} )(u) = f. f (u, f. p)
121- (f:: VecJacFunctionWrapper{true, true, 3} )(u) = f. f (u)
122- (f:: VecJacFunctionWrapper{true, false, 1} )(u) = (f. f (f. fu, u, f. p, f. t); copy (f. fu))
123- (f:: VecJacFunctionWrapper{true, false, 2} )(u) = (f. f (f. fu, u, f. p); copy (f. fu))
124- (f:: VecJacFunctionWrapper{true, false, 3} )(u) = (f. f (f. fu, u); copy (f. fu))
125-
126- (f:: VecJacFunctionWrapper{false, true, 1} )(fu, u) = (vec (fu) .= vec (f. f (u, f. p, f. t)))
127- (f:: VecJacFunctionWrapper{false, true, 2} )(fu, u) = (vec (fu) .= vec (f. f (u, f. p)))
128- (f:: VecJacFunctionWrapper{false, true, 3} )(fu, u) = (vec (fu) .= vec (f. f (u)))
129- (f:: VecJacFunctionWrapper{false, true, 1} )(u) = f. f (u, f. p, f. t)
130- (f:: VecJacFunctionWrapper{false, true, 2} )(u) = f. f (u, f. p)
131- (f:: VecJacFunctionWrapper{false, true, 3} )(u) = f. f (u)
132-
133- function VecJacFunctionWrapper (f:: F , fu_, u, p, t) where {F}
134- fu = fu_ === nothing ? copy (u) : copy (fu_)
135- if t != = nothing
136- iip = static_hasmethod (f, typeof ((fu, u, p, t)))
137- oop = static_hasmethod (f, typeof ((u, p, t)))
138- if ! iip && ! oop
139- throw (ArgumentError (" `f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`" ))
140- end
141- return VecJacFunctionWrapper {iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)} (f,
142- fu, p, t)
143- elseif p != = nothing
144- iip = static_hasmethod (f, typeof ((fu, u, p)))
145- oop = static_hasmethod (f, typeof ((u, p)))
146- if ! iip && ! oop
147- throw (ArgumentError (" `f(u, p)` or `f(fu, u, p)` not defined for `f`" ))
148- end
149- return VecJacFunctionWrapper {iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)} (f,
150- fu, p, t)
151- else
152- iip = static_hasmethod (f, typeof ((fu, u)))
153- oop = static_hasmethod (f, typeof ((u,)))
154- if ! iip && ! oop
155- throw (ArgumentError (" `f(u)` or `f(fu, u)` not defined for `f`" ))
156- end
157- return VecJacFunctionWrapper {iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)} (f,
158- fu, p, t)
159- end
160- end
161-
16293function _vecjac (f:: F , fu, u, autodiff:: AutoFiniteDiff ) where {F}
16394 cache = (similar (fu), similar (fu))
16495 pullback = nothing
0 commit comments