1818
1919getsize (:: Val{N} ) where N = N
2020getsize (N:: Integer ) = N
21+ void_setindex! (args... ) = (setindex! (args... ); return )
2122
2223function ForwardColorJacCache (f,x,_chunksize = nothing ;
2324 dx = nothing ,
@@ -30,15 +31,15 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
3031 chunksize = _chunksize
3132 end
3233
33- p = adapt .(typeof (x),generate_chunked_partials (x,colorvec,chunksize))
34+ p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
3435 _t = Dual {ForwardDiff.Tag(f,eltype(vec(x)))} .(vec (x),first (p))
3536 t = ArrayInterface. restructure (x,_t)
3637 if dx isa Nothing
3738 fx = similar (t)
3839 _dx = similar (x)
3940 else
40- tup = first ( first (p) ) .* false
41- _pi = adapt .( typeof (dx),[tup for i in 1 : length (dx)])
41+ tup = ArrayInterface . allowed_getindex (ArrayInterface . allowed_getindex (p, 1 ), 1 ) .* false
42+ _pi = adapt ( parameterless_type (dx),[tup for i in 1 : length (dx)])
4243 fx = reshape (Dual {ForwardDiff.Tag(f,eltype(vec(x)))} .(vec (dx),_pi),size (dx)... )
4344 _dx = dx
4445 end
@@ -121,7 +122,7 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
121122 for j in 1 : chunksize
122123 col_index = (i- 1 )* chunksize + j
123124 (col_index > ncols) && return J
124- Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : zeros (nrows), hcat, 1 : ncols)
125+ Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : adapt ( parameterless_type (J), zeros (eltype (J), nrows) ), hcat, 1 : ncols)
125126 J = J + (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
126127 end
127128 end
0 commit comments