@@ -13,20 +13,21 @@ import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
1313import ADTypes: AutoZygote, AutoSparseZygote
1414
1515# # Satisfying High-Level Interface for Sparse Jacobians
16- function __gradient (:: Union{AutoSparseZygote, AutoZygote} , f, x, cols)
16+ function __gradient (:: Union{AutoSparseZygote, AutoZygote} , f:: F , x, cols) where {F}
1717 _, ∂x, _ = Zygote. gradient (__f̂, f, x, cols)
1818 return vec (∂x)
1919end
2020
21- function __gradient! (:: Union{AutoSparseZygote, AutoZygote} , f!, fx, x, cols)
21+ function __gradient! (:: Union{AutoSparseZygote, AutoZygote} , f!:: F , fx, x, cols) where {F}
2222 return error (" Zygote.jl cannot differentiate in-place (mutating) functions." )
2323end
2424
2525# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
2626# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
2727import Zygote: _jvec, _eyelike, _gradcopy!
2828
29- @views function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparseZygote, AutoZygote} , f, x)
29+ @views function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparseZygote, AutoZygote} , f:: F ,
30+ x) where {F}
3031 y, back = Zygote. pullback (_jvec ∘ f, x)
3132 δ = _eyelike (y)
3233 for k in LinearIndices (y)
@@ -36,13 +37,13 @@ import Zygote: _jvec, _eyelike, _gradcopy!
3637 return J
3738end
3839
39- function __jacobian! (J , :: Union{AutoSparseZygote, AutoZygote} , f!, fx, x)
40+ function __jacobian! (_ , :: Union{AutoSparseZygote, AutoZygote} , f!:: F , fx, x) where {F}
4041 return error (" Zygote.jl cannot differentiate in-place (mutating) functions." )
4142end
4243
4344# ## Jac, Hes products
4445
45- function numback_hesvec! (dy, f, x, v, cache1 = similar (v), cache2 = similar (v))
46+ function numback_hesvec! (dy, f:: F , x, v, cache1 = similar (v), cache2 = similar (v)) where {F}
4647 g = let f = f
4748 (dx, x) -> dx .= first (Zygote. gradient (f, x))
4849 end
@@ -57,15 +58,14 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
5758 @. dy = (cache1 - cache2) / (2 ϵ)
5859end
5960
60- function numback_hesvec (f, x, v)
61- g = x -> first (Zygote. gradient (f, x))
61+ function numback_hesvec (f:: F , x, v) where {F}
6262 T = eltype (x)
6363 # Should it be min? max? mean?
6464 ϵ = sqrt (eps (real (T))) * max (one (real (T)), abs (norm (x)))
6565 x += ϵ * v
66- gxp = g (x )
66+ gxp = first (Zygote . gradient (f, x) )
6767 x -= 2 ϵ * v
68- gxm = g (x )
68+ gxm = first (Zygote . gradient (f, x) )
6969 (gxp - gxm) / (2 ϵ)
7070end
7171
9494# # VecJac products
9595
9696# VJP methods
97- function auto_vecjac! (du, f, x, v)
97+ function auto_vecjac! (du, f:: F , x, v) where {F}
9898 ! static_hasmethod (f, typeof ((x,))) &&
9999 error (" For inplace function use autodiff = AutoFiniteDiff()" )
100100 du .= reshape (SparseDiffTools. auto_vecjac (f, x, v), size (du))
101101end
102102
103- function auto_vecjac (f, x, v)
103+ function auto_vecjac (f:: F , x, v) where {F}
104104 y, back = Zygote. pullback (f, x)
105- return vec (back (reshape (v, size (y)))[ 1 ] )
105+ return vec (only ( back (reshape (v, size (y)))) )
106106end
107107
108108# overload operator interface
109- function SparseDiffTools. _vecjac (f, u, autodiff:: AutoZygote )
110- cache = ()
109+ function SparseDiffTools. _vecjac (f:: F , _, u, autodiff:: AutoZygote ) where {F}
110+ ! static_hasmethod (f, typeof ((u,))) &&
111+ error (" For inplace function use autodiff = AutoFiniteDiff()" )
111112 pullback = Zygote. pullback (f, u)
112-
113- return AutoDiffVJP (f, u, cache, autodiff, pullback)
113+ return AutoDiffVJP (f, u, (), autodiff, pullback)
114114end
115115
116116function update_coefficients (L:: AutoDiffVJP{<:AutoZygote} , u, p, t; VJP_input = nothing )
117117 VJP_input != = nothing && (@set! L. u = VJP_input)
118-
119118 @set! L. f = update_coefficients (L. f, L. u, p, t)
120119 @set! L. pullback = Zygote. pullback (L. f, L. u)
120+ return L
121121end
122122
123123function update_coefficients! (L:: AutoDiffVJP{<:AutoZygote} , u, p, t; VJP_input = nothing )
124124 VJP_input != = nothing && copy! (L. u, VJP_input)
125-
126125 update_coefficients! (L. f, L. u, p, t)
127126 L. pullback = Zygote. pullback (L. f, L. u)
128-
129127 return L
130128end
131129
132130# Interpret the call as df/du' * v
133131function (L:: AutoDiffVJP{<:AutoZygote} )(v, p, t; VJP_input = nothing )
134132 # ignore VJP_input as pullback was computed in update_coefficients(...)
135133 y, back = L. pullback
136- V = reshape (v, size (y))
137-
138- return vec (first (back (V)))
134+ return vec (only (back (reshape (v, size (y)))))
139135end
140136
141137# prefer non in-place method
142- function (L:: AutoDiffVJP{<:AutoZygote, IIP, true} )(dv, v, p, t;
143- VJP_input = nothing ) where {IIP}
138+ function (L:: AutoDiffVJP{<:AutoZygote} )(dv, v, p, t; VJP_input = nothing )
144139 # ignore VJP_input as pullback was computed in update_coefficients!(...)
145-
146- _dv = L (v, p, t; VJP_input = VJP_input)
140+ _dv = L (v, p, t; VJP_input)
147141 copy! (dv, _dv)
148142end
149143
150- function (L:: AutoDiffVJP{<:AutoZygote, true, false} )(args... ; kwargs... )
151- error (" Zygote requires an out of place method with signature f(u)." )
152- end
153-
154144end # module
0 commit comments