
Commit 4379ec3

refactor: use Lux primitives for AD (#995)
* refactor: use Lux primitives for AD
* fix: workaround SciML/Optimization.jl#848
1 parent a576e39 commit 4379ec3

File tree

3 files changed, +17 −40 lines


examples/Basics/main.jl

Lines changed: 12 additions & 37 deletions
@@ -214,50 +214,25 @@ f(x) = x .* x ./ 2
 x = randn(rng, Float32, 5)
 v = ones(Float32, 5)

-# Construct the pushforward function. We will write out the function here but in
-# practice we recommend using
-# [SparseDiffTools.auto_jacvec](https://docs.sciml.ai/SparseDiffTools/stable/#Jacobian-Vector-and-Hessian-Vector-Products)!
-
-# First we need to create a Tag for ForwardDiff. It is enough to know that this is something
-# that you must do. For more details, see the
-# [ForwardDiff Documentation](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Custom-tags-and-tag-checking)!
-struct TestTag end
-
-# Going in the details of what is function is doing is beyond the scope of this tutorial.
-# But in short, it is constructing a new Dual Vector with the partials set to the input
-# to the pushforward function. When this is propagated through the original function
-# we get the value and the jvp
-function pushforward_forwarddiff(f, x)
-    T = eltype(x)
-    function pushforward(v)
-        v_ = reshape(v, axes(x))
-        y = ForwardDiff.Dual{
-            ForwardDiff.Tag{TestTag, T}, T, 1}.(x, ForwardDiff.Partials.(tuple.(v_)))
-        res = vec(f(y))
-        return ForwardDiff.value.(res), vec(ForwardDiff.partials.(res, 1))
-    end
-    return pushforward
-end
-
-pf_f = pushforward_forwarddiff(f, x)
+# !!! warning "Using DifferentiationInterface"
+#
+#     While DifferentiationInterface provides these functions for a wider range of backends,
+#     we currently don't recommend using them with Lux models, since the functions presented
+#     here come with additional goodies like
+#     [fast second-order derivatives](@ref nested_autodiff).

-# Compute the jvp.
+# Compute the jvp. `AutoForwardDiff` specifies that we want to use `ForwardDiff.jl` for the
+# Jacobian-Vector Product

-val, jvp = pf_f(v)
-println("Computed Value: f(", x, ") = ", val)
-println("JVP: ", jvp[1])
+jvp = jacobian_vector_product(f, AutoForwardDiff(), x, v)
+println("JVP: ", jvp)

 # ### Vector-Jacobian Product

 # Using the same function and inputs, let us compute the VJP.

-val, pb_f = Zygote.pullback(f, x)
-
-# Compute the vjp.
-
-vjp = only(pb_f(v))
-println("Computed Value: f(", x, ") = ", val)
-println("VJP: ", vjp[1])
+vjp = vector_jacobian_product(f, AutoZygote(), x, v)
+println("VJP: ", vjp)

 # ## Linear Regression

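For reference, here is a minimal standalone sketch (not part of the commit) of the API the new tutorial code calls. It assumes `jacobian_vector_product` and `vector_jacobian_product` are available from `using Lux`, that the `AutoForwardDiff`/`AutoZygote` backend selectors come from ADTypes.jl, and that ForwardDiff.jl and Zygote.jl are installed; `Random.default_rng()` stands in for the tutorial's rng.

using ADTypes, ForwardDiff, Lux, Random, Zygote

# The tutorial's toy function and input shapes
f(x) = x .* x ./ 2
rng = Random.default_rng()
x = randn(rng, Float32, 5)
v = ones(Float32, 5)

# Forward-mode Jacobian-vector product (pushforward) via ForwardDiff
jvp = jacobian_vector_product(f, AutoForwardDiff(), x, v)

# Reverse-mode vector-Jacobian product (pullback) via Zygote
vjp = vector_jacobian_product(f, AutoZygote(), x, v)

# For f(x) = x .* x ./ 2 the Jacobian is diagonal in x, so both products equal x .* v
println("JVP: ", jvp)
println("VJP: ", vjp)
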
examples/NeuralODE/main.jl

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,8 @@
 # # MNIST Classification using Neural ODEs

 # To understand Neural ODEs, users should look up
-# [these lecture notes](https://book.sciml.ai/notes/11-Differentiable_Programming_and_Neural_Differential_Equations/). We recommend users to directly use
+# [these lecture notes](https://book.sciml.ai/notes/11-Differentiable_Programming_and_Neural_Differential_Equations/).
+# We recommend users to directly use
 # [DiffEqFlux.jl](https://docs.sciml.ai/DiffEqFlux/stable/), instead of implementing
 # Neural ODEs from scratch.

@@ -31,7 +32,8 @@ function loadmnist(batchsize, train_split)
         ## Use DataLoader to automatically minibatch and shuffle the data
         DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
         ## Don't shuffle the test data
-        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
+        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)
+    )
 end

 # ## Define the Neural ODE Layer

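The `loadmnist` hunk above only reflows existing code, but for readers unfamiliar with the call it closes, here is a small illustrative sketch of how `MLUtils.DataLoader` minibatches a tuple of arrays. The fake images and labels below are stand-ins, not data from the commit.

using MLUtils

# Stand-ins for (x_train, y_train): 128 flattened "images" and one-hot-style labels
x_train = rand(Float32, 28 * 28, 128)
y_train = rand(Bool, 10, 128)
batchsize = 32

# Same pattern as the tutorial: minibatch and shuffle the training data
train_loader = DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true)

for (xb, yb) in train_loader
    # Each iteration yields one shuffled minibatch; columns stay aligned across x and y
    @assert size(xb, 2) == size(yb, 2) == batchsize
end
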
examples/OptimizationIntegration/main.jl

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ function train_model(dataloader)
     opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)

     epochs = 25
-    res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, maxiters=epochs)
+    res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, epochs)

     ## Let's finetune a bit with L-BFGS
     opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t)))

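To see the change in context, here is a minimal self-contained sketch (not from the commit) of the pattern it relies on: when the `OptimizationProblem` carries a data iterator in its parameter slot, the epoch count goes through the `epochs` keyword instead of `maxiters`, which is the workaround for SciML/Optimization.jl#848 mentioned in the commit message. The toy data, loss, and initial parameters below are illustrative assumptions.

using MLUtils, Optimisers, Optimization, OptimizationOptimisers, Zygote

# Toy regression data: fit y = 2x + 1 from minibatches
xs = collect(Float32, range(0, 1; length=64))
ys = 2.0f0 .* xs .+ 1.0f0
dataloader = DataLoader((xs, ys); batchsize=16)

# With an iterator as the problem's `p`, the objective receives (parameters, batch)
function loss(ps, batch)
    x, y = batch
    yhat = ps[1] .* x .+ ps[2]
    return sum(abs2, yhat .- y)
end

opt_func = OptimizationFunction(loss, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, Float32[0.0, 0.0], dataloader)

# As in the diff above: pass the epoch count via `epochs` rather than `maxiters`
epochs = 25
res_adam = solve(opt_prob, Optimisers.Adam(0.001); epochs)
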
0 commit comments
