Description
Hi folks,
I want to run a temporal model with a custom cell in an RNN-like structure. The sequence length is around 500k univariate samples, and I want my RNN layer to return the entire output sequence. Basically, I am looking for the following kind of model:
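```julia
using Lux, Reactant

# Wraps a recurrent cell and returns the full output sequence.
# Input layout: (features, time, batch).
struct RNNLayerSave{C} <: Lux.LuxCore.AbstractLuxWrapperLayer{:cell}
    cell::C
end

function (rnn::RNNLayerSave)(x::AbstractArray{T,3}, ps, st) where {T}
    # Buffer for the outputs of every time step (hidden size 32 is hard-coded)
    memory = similar(x, 32, size(x, 2), size(x, 3))
    # First step: the cell initializes its own carry
    (y, carry), st = Lux.apply(rnn.cell, x[:, 1, :], ps, st)
    memory[:, 1, :] = y
    # Remaining steps: feed the carry back in; @trace keeps this as one traced loop
    @trace for i in 2:size(x, 2)
        (y, carry), st = Lux.apply(rnn.cell, (x[:, i, :], carry), ps, st)
        memory[:, i, :] = y
    end
    return memory, st
end
```

What I found is that the gradient computation scales quadratically with the sequence length. I ran the following benchmarking code (on CPU):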
```julia
using Enzyme, Lux, Reactant, TimerOutputs
const to = TimerOutput()
# Gradient of sum(model output) w.r.t. the input data and the parameters
function enzyme_gradient(layer, data, ps, st)
    Enzyme.gradient(Enzyme.Reverse, sum ∘ first ∘ Lux.apply, Const(layer), data, ps, Const(st))
end
@timeit to "Reactant compilation" compiled_gradient = @compile sync = true enzyme_gradient(model, x, ps, st)
@timeit to "First Julia call" g = compiled_gradient(model, x, ps, st);
```

and obtained the following timings for different sequence lengths N:
| N | Reactant compilation | First Julia call |
|---|---|---|
| 200 | 246 ms | 40.8 ms |
| 400 | 246 ms | 52.3 ms |
| 800 | 250 ms | 77.1 ms |
| 1600 | 268 ms | 227 ms |
| 3200 | 264 ms | 1.13 s |
| 6400 | 240 ms | 4.25 s |
| 12800 | 257 ms | 16.4 s |
With a target sequence length of 500k, I think my blocker is clear: compilation time stays flat at roughly 250 ms, but from N = 1600 onward every doubling of N multiplies the gradient time by roughly 4. I did not expect this computation to grow quadratically in the sequence length. Maybe I am doing something suboptimal in how I store the sequence; I suspect I accidentally store it in a way that creates a huge Jacobian. I would be very happy about some pointers here.
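To make the scaling concrete, here is a small sketch that estimates the growth exponent from the measured "First Julia call" timings in the table and extrapolates to N = 500k. The extrapolation assumes the runtime keeps following t ∝ N² beyond the measured range; the variable names are just for illustration.

```julia
# Timings copied from the table above ("First Julia call" column), in seconds
Ns = [200, 400, 800, 1600, 3200, 6400, 12800]
ts = [0.0408, 0.0523, 0.0771, 0.227, 1.13, 4.25, 16.4]

# Factor by which the runtime grows each time N doubles (≈ 4 means quadratic)
growth = ts[2:end] ./ ts[1:end-1]

# Log-log slope between the two largest runs estimates p in t ∝ N^p
p = log(ts[end] / ts[end-1]) / log(Ns[end] / Ns[end-1])

# Extrapolate to N = 500_000 assuming t ∝ N^2 (an assumption, not a measurement)
t_500k = ts[end] * (500_000 / Ns[end])^2

println("growth per doubling: ", round.(growth; digits = 2))
println("estimated exponent p ≈ ", round(p; digits = 2))
println("extrapolated time at N = 500k ≈ ", round(t_500k / 3600; digits = 1), " h")
```

For small N the fixed overhead dominates, so the growth factor only approaches 4 from N = 1600 on; the slope between the last two runs comes out at about 1.95, and the quadratic extrapolation lands at roughly seven hours for a single gradient call at N = 500k.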
I also went through the code and found the `Recurrence` layer with the `return_sequence` argument. I tried that as well, but it ran into a `StackOverflowError` rather quickly.