Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit eec7232

Browse files
Handle CuSparse CSC case
1 parent 39f2ee6 commit eec7232

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
294294
sparsity = jac_cache.sparsity
295295
chunksize = jac_cache.chunksize
296296
color_i = 1
297+
adaptedcolorvec = adapt(__parameterless_type(typeof(dx)),colorvec)
298+
297299
maxcolor = maximum(colorvec)
298300

299301
if J isa AbstractSparseMatrix
@@ -357,7 +359,15 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
357359
+= means requires a zero'd out start
358360
=#
359361
if J isa AbstractSparseMatrix
360-
@. void_setindex!(Ref(nonzeros(J)),getindex(Ref(nonzeros(J)),rows_index) + (getindex(Ref(colorvec),cols_index) == color_i) * getindex(Ref(vecdx),rows_index),rows_index)
362+
if J isa SparseMatrixCSC
363+
@. void_setindex!(Ref(nonzeros(J)),getindex(Ref(nonzeros(J)),rows_index) + (getindex(Ref(adaptedcolorvec),cols_index) == color_i) * getindex(Ref(vecdx),rows_index),rows_index)
364+
else
365+
nzval = @view nonzeros(J)[rows_index]
366+
cv = @view adaptedcolorvec[cols_index]
367+
vdx = @view dx[rows_index]
368+
tmp = cv .== color_i
369+
nzval .+= tmp .* vdx
370+
end
361371
else
362372
@. void_setindex!(Ref(J),getindex(Ref(J),rows_index, cols_index) + (getindex(Ref(colorvec),cols_index) == color_i) * getindex(Ref(vecdx),rows_index),rows_index, cols_index)
363373
end

test/test_gpu_ad.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
using SparseDiffTools, CUDA, Test, LinearAlgebra
22
using ArrayInterface: allowed_getindex, allowed_setindex!
3+
using SparseArrays
4+
35
function f(dx,x)
46
dx[2:end-1] = x[1:end-2] - 2x[2:end-1] + x[3:end]
57
allowed_setindex!(dx,-2allowed_getindex(x,1) + allowed_getindex(x,2),1)
68
allowed_setindex!(dx,-2allowed_getindex(x,30) + allowed_getindex(x,29),30)
79
nothing
810
end
9-
11+
x = rand(4)
1012
_J1 = similar(rand(30,30))
1113
_denseJ1 = cu(collect(_J1))
1214
x = cu(rand(30))
1315
CUDA.allowscalar(false)
14-
forwarddiff_color_jacobian!(_denseJ1, f, x)
15-
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, sparsity = _J1) isa Nothing
16-
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = _J1) isa Nothing
16+
_J2 = sparse(forwarddiff_color_jacobian!(_denseJ1, f, x))
17+
out = copy(_J2)
18+
forwarddiff_color_jacobian!(out, f, x, colorvec = repeat(1:3,10), sparsity = _J2)
19+
20+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, sparsity = cu(_J1)) isa Nothing
21+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = cu(_J1)) isa Nothing
1722
_Jt = similar(Tridiagonal(_J1))
1823
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = _Jt) isa Nothing
24+
_Jt2 = similar(Tridiagonal(cu(_J1)))
25+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = _Jt2) isa Nothing

0 commit comments

Comments
 (0)