@@ -14,18 +14,22 @@ void_setindex!(args...) = (setindex!(args...); return)
1414
1515const default_chunk_size = ForwardDiff. pickchunksize
1616
17- function ForwardColorJacCache (f,x,_chunksize = nothing ;
17+ function ForwardColorJacCache (f:: F ,x,_chunksize = nothing ;
1818 dx = nothing ,
1919 colorvec= 1 : length (x),
20- sparsity:: Union{AbstractArray,Nothing} = nothing )
20+ sparsity:: Union{AbstractArray,Nothing} = nothing ) where {F}
2121
2222 if _chunksize isa Nothing
2323 chunksize = ForwardDiff. pickchunksize (maximum (colorvec))
2424 else
2525 chunksize = _chunksize
2626 end
2727
28- p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
28+ if x isa Array
29+ p = generate_chunked_partials (x,colorvec,chunksize)
30+ else
31+ p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
32+ end
2933 _t = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} .(vec (x),first (p))
3034 t = ArrayInterface. restructure (x,_t)
3135 if dx isa Nothing
@@ -50,7 +54,10 @@ function generate_chunked_partials(x,colorvec,::Val{chunksize}) where chunksize
5054 padding_matrix = BitMatrix (undef, length (x), padding_size)
5155 partials = hcat (partials, padding_matrix)
5256
53- chunked_partials = map (i -> Tuple .(eachrow (partials[:,(i- 1 )* chunksize+ 1 : i* chunksize])),1 : num_of_chunks)
57+ chunked_partials = Vector {Vector{NTuple{chunksize,eltype(x)}}} (undef, num_of_chunks)
58+ for i in 1 : num_of_chunks
59+ chunked_partials[i] = Tuple .(eachrow (@view (partials[:,(i- 1 )* chunksize+ 1 : i* chunksize])))
60+ end
5461 chunked_partials
5562
5663end
0 commit comments