
Gradient computation in RNN layer with stored outputs scales quadratically in sequence length #1560

@bartvanerp


Hi folks,

I want to run a temporal model with a custom cell in an RNN-like structure. The sequence is around 500k univariate samples long, and I want my RNN layer to return the entire output sequence. Basically, I am looking for the following kind of model:

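```julia
using Lux, Reactant

# Wraps a recurrent cell and returns the full output sequence.
# Input layout: (features, time, batch)
struct RNNLayerSave{C} <: Lux.LuxCore.AbstractLuxWrapperLayer{:cell}
    cell::C
end
function (rnn::RNNLayerSave)(x::AbstractArray{T,3}, ps, st) where {T}
    # Output buffer (32, time, batch); 32 is the cell's output size (hard-coded)
    memory = similar(x, 32, size(x, 2), size(x, 3))
    # The first step initializes the carry; store its output too
    (y, carry), st = Lux.apply(rnn.cell, x[:, 1, :], ps, st)
    memory[:, 1, :] = y
    @trace for i in 2:size(x, 2)
        (y, carry), st = Lux.apply(rnn.cell, (x[:, i, :], carry), ps, st)
        memory[:, i, :] = y
    end
    return memory, st
end
```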

However, I found that the gradient computation time scales quadratically with the sequence length. I ran the following benchmark (on CPU):

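```julia
using Lux, Reactant, Enzyme, TimerOutputs
const to = TimerOutput()
# `model`, `x` (a Reactant array), `ps` and `st` are set up beforehand
# Reverse-mode gradient of sum(first(Lux.apply(layer, data, ps, st))) w.r.t. `data` and `ps`
function enzyme_gradient(layer, data, ps, st)
    Enzyme.gradient(Enzyme.Reverse, sum ∘ first ∘ Lux.apply, Const(layer), data, ps, Const(st))
end
@timeit to "Reactant compilation" compiled_gradient = @compile sync = true enzyme_gradient(model, x, ps, st)
@timeit to "First Julia call" g = compiled_gradient(model, x, ps, st);
```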

and obtained the following timings for different sequence lengths `N`:

| N | Reactant compilation | First Julia call |
|---:|---:|---:|
| 200 | 246 ms | 40.8 ms |
| 400 | 246 ms | 52.3 ms |
| 800 | 250 ms | 77.1 ms |
| 1600 | 268 ms | 227 ms |
| 3200 | 264 ms | 1.13 s |
| 6400 | 240 ms | 4.25 s |
| 12800 | 257 ms | 16.4 s |
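To check that the trend really is quadratic, one can estimate the local exponent `p` in `t ≈ c * N^p` from consecutive rows of the table. This plain-Julia sketch uses only the "First Julia call" column; because `N` doubles between rows, each `p` is simply the base-2 log of the timing ratio.

```julia
# Timings copied from the table: sequence length N and "First Julia call" time in seconds
Ns = [200, 400, 800, 1600, 3200, 6400, 12800]
ts = [0.0408, 0.0523, 0.0771, 0.227, 1.13, 4.25, 16.4]

# Empirical exponent p between consecutive rows, assuming t ≈ c * N^p locally
slopes = [log(ts[i+1] / ts[i]) / log(Ns[i+1] / Ns[i]) for i in 1:length(Ns)-1]

for (i, p) in enumerate(slopes)
    println("N = ", Ns[i], " → ", Ns[i+1], ": p ≈ ", round(p; digits = 2))
end
```

The last three ratios give `p` of roughly 2.3, 1.9 and 1.9, consistent with quadratic growth; the small-`N` rows are presumably dominated by fixed per-call overhead.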

Given the target sequence length of 500k, this is clearly a blocker for me. I also did not expect this computation to grow quadratically with sequence length. Maybe I am storing the sequence in a suboptimal way; I suspect I am accidentally storing it such that a huge Jacobian gets created. I would be very happy with some pointers here.
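The suspicion above can be made concrete with a back-of-the-envelope count. This is a hypothetical cost model, not a diagnosis of what Reactant or Enzyme actually do: it assumes that, per loop iteration, the reverse pass either copies the whole `(32, N, batch)` output buffer or only the slice that iteration wrote, and counts the elements touched. The names `full_buffer_copies` and `slice_copies` are illustrative only.

```julia
# Hypothetical cost model (an assumption, not a diagnosis): count how many
# elements are touched if each of the N-1 loop iterations copies
#   (a) the whole (hidden, N, batch) output buffer, versus
#   (b) only the (hidden, 1, batch) slice it writes.
hidden, batch = 32, 1

full_buffer_copies(N) = (N - 1) * hidden * N * batch   # grows like N^2
slice_copies(N)       = (N - 1) * hidden * batch       # grows like N

for N in (1_600, 3_200, 6_400, 12_800)
    println("N = ", N, ": full-buffer = ", full_buffer_copies(N),
            ", slice-only = ", slice_copies(N))
end

# Doubling N roughly quadruples (a) but only doubles (b)
ratio_full  = full_buffer_copies(12_800) / full_buffer_copies(6_400)
ratio_slice = slice_copies(12_800) / slice_copies(6_400)
```

Under this assumption, at `N = 500_000` the full-buffer variant would touch roughly 8 × 10^12 elements, while the slice-only variant stays around 1.6 × 10^7.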

I also went through the code and saw the `Recurrence` layer with the `return_sequence` argument. I tried that as well, but it quickly ran into a `StackOverflowError`.
