Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 0d6f365

Browse files
Make inference's job much easier by avoiding map
1 parent fb09091 commit 0d6f365

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,22 @@ void_setindex!(args...) = (setindex!(args...); return)
1414

1515
const default_chunk_size = ForwardDiff.pickchunksize
1616

17-
function ForwardColorJacCache(f,x,_chunksize = nothing;
17+
function ForwardColorJacCache(f::F,x,_chunksize = nothing;
1818
dx = nothing,
1919
colorvec=1:length(x),
20-
sparsity::Union{AbstractArray,Nothing}=nothing)
20+
sparsity::Union{AbstractArray,Nothing}=nothing) where {F}
2121

2222
if _chunksize isa Nothing
2323
chunksize = ForwardDiff.pickchunksize(maximum(colorvec))
2424
else
2525
chunksize = _chunksize
2626
end
2727

28-
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
28+
if x isa Array
29+
p = generate_chunked_partials(x,colorvec,chunksize)
30+
else
31+
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
32+
end
2933
_t = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),first(p))
3034
t = ArrayInterface.restructure(x,_t)
3135
if dx isa Nothing
@@ -50,7 +54,10 @@ function generate_chunked_partials(x,colorvec,::Val{chunksize}) where chunksize
5054
padding_matrix = BitMatrix(undef, length(x), padding_size)
5155
partials = hcat(partials, padding_matrix)
5256

53-
chunked_partials = map(i -> Tuple.(eachrow(partials[:,(i-1)*chunksize+1:i*chunksize])),1:num_of_chunks)
57+
chunked_partials = Vector{Vector{NTuple{chunksize,eltype(x)}}}(undef, num_of_chunks)
58+
for i in 1:num_of_chunks
59+
chunked_partials[i] = Tuple.(eachrow(@view(partials[:,(i-1)*chunksize+1:i*chunksize])))
60+
end
5461
chunked_partials
5562

5663
end

0 commit comments

Comments
 (0)