@@ -14,20 +14,30 @@ void_setindex!(args...) = (setindex!(args...); return)
1414
1515const default_chunk_size = ForwardDiff. pickchunksize
1616
17- function ForwardColorJacCache (f,x,_chunksize = nothing ;
17+ function ForwardColorJacCache (f:: F ,x,_chunksize = nothing ;
1818 dx = nothing ,
1919 colorvec= 1 : length (x),
20- sparsity:: Union{AbstractArray,Nothing} = nothing )
20+ sparsity:: Union{AbstractArray,Nothing} = nothing ) where {F}
2121
2222 if _chunksize isa Nothing
2323 chunksize = ForwardDiff. pickchunksize (maximum (colorvec))
2424 else
2525 chunksize = _chunksize
2626 end
2727
28- p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
29- _t = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} .(vec (x),first (p))
30- t = ArrayInterface. restructure (x,_t)
28+ if x isa Array
29+ p = generate_chunked_partials (x,colorvec,chunksize)
30+ t = similar (x,Dual{typeof (ForwardDiff. Tag (f,eltype (vec (x))))})
31+ for i in eachindex (t)
32+ t[i] = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} (x[i],first (p)[1 ])
33+ end
34+ else
35+ p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
36+ _t = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x))))} .(vec (x),first (p))
37+ t = ArrayInterface. restructure (x,_t)
38+ end
39+
40+
3141 if dx isa Nothing
3242 fx = similar (t)
3343 _dx = similar (x)
@@ -46,13 +56,27 @@ function generate_chunked_partials(x,colorvec,::Val{chunksize}) where chunksize
4656 maxcolor = maximum (colorvec)
4757 num_of_chunks = Int (ceil (maxcolor / chunksize))
4858 padding_size = (chunksize - (maxcolor % chunksize)) % chunksize
49- partials = colorvec .== (1 : maxcolor)'
59+
60+ # partials = colorvec .== (1:maxcolor)'
61+ partials = BitMatrix (undef, length (colorvec), maxcolor)
62+ for i in 1 : maxcolor, j in 1 : length (colorvec)
63+ partials[j,i] = colorvec[j] == i
64+ end
65+
5066 padding_matrix = BitMatrix (undef, length (x), padding_size)
5167 partials = hcat (partials, padding_matrix)
5268
53- chunked_partials = map (i -> Tuple .(eachrow (partials[:,(i- 1 )* chunksize+ 1 : i* chunksize])),1 : num_of_chunks)
54- chunked_partials
5569
70+ # chunked_partials = map(i -> Tuple.(eachrow(partials[:,(i-1)*chunksize+1:i*chunksize])),1:num_of_chunks)
71+ chunked_partials = Vector {Vector{NTuple{chunksize,eltype(x)}}} (undef, num_of_chunks)
72+ for i in 1 : num_of_chunks
73+ tmp = Vector {NTuple{chunksize,eltype(x)}} (undef, size (partials,1 ))
74+ for j in 1 : size (partials,1 )
75+ tmp[j] = Tuple (@view partials[j,(i- 1 )* chunksize+ 1 : i* chunksize])
76+ end
77+ chunked_partials[i] = tmp
78+ end
79+ chunked_partials
5680end
5781
5882@inline function forwarddiff_color_jacobian (f,
@@ -280,11 +304,26 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
280304
281305 for i in eachindex (p)
282306 partial_i = p[i]
283- vect .= Dual {typeof(ForwardDiff.Tag(f,eltype(vecx)))} .(vecx, partial_i)
307+
308+ if vect isa Array
309+ @inbounds @simd ivdep for j in eachindex (vect)
310+ vect[j] = Dual {typeof(ForwardDiff.Tag(f,eltype(vecx)))} (vecx[j], partial_i[j])
311+ end
312+ else
313+ vect .= Dual {typeof(ForwardDiff.Tag(f,eltype(vecx)))} .(vecx, partial_i)
314+ end
315+
284316 f (fx,t)
285317 if ! (sparsity isa Nothing)
286318 for j in 1 : chunksize
287- dx .= partials .(fx, j)
319+
320+ if dx isa Array
321+ @inbounds @simd ivdep for k in eachindex (dx)
322+ dx[k] = partials (fx[k], j)
323+ end
324+ else
325+ dx .= partials .(fx, j)
326+ end
288327
289328 if ArrayInterface. fast_scalar_indexing (dx)
290329 # dx is implicitly used in vecdx
@@ -313,7 +352,13 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
313352 for j in 1 : chunksize
314353 col_index = (i- 1 )* chunksize + j
315354 (col_index > ncols) && return J
316- J[:, col_index] .= partials .(vecfx, j)
355+ if J isa Array
356+ @inbounds @simd for k in 1 : size (J,1 )
357+ J[k, col_index] = partials (vecfx[k], j)
358+ end
359+ else
360+ J[:, col_index] .= partials .(vecfx, j)
361+ end
317362 end
318363 end
319364 end
0 commit comments