Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 2beb289

Browse files
Merge pull request #181 from JuliaDiff/sparse_GPU_take_2
Sparse GPU take 2
2 parents c438420 + eec7232 commit 2beb289

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -294,6 +294,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
294294
sparsity = jac_cache.sparsity
295295
chunksize = jac_cache.chunksize
296296
color_i = 1
297+
adaptedcolorvec = adapt(__parameterless_type(typeof(dx)),colorvec)
298+
297299
maxcolor = maximum(colorvec)
298300

299301
if J isa AbstractSparseMatrix
@@ -357,9 +359,17 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
357359
+= means requires a zero'd out start
358360
=#
359361
if J isa AbstractSparseMatrix
360-
@. setindex!((J.nzval,),getindex((J.nzval,),rows_index) + (getindex((colorvec,),cols_index) == color_i) * getindex((vecdx,),rows_index),rows_index)
362+
if J isa SparseMatrixCSC
363+
@. void_setindex!(Ref(nonzeros(J)),getindex(Ref(nonzeros(J)),rows_index) + (getindex(Ref(adaptedcolorvec),cols_index) == color_i) * getindex(Ref(vecdx),rows_index),rows_index)
364+
else
365+
nzval = @view nonzeros(J)[rows_index]
366+
cv = @view adaptedcolorvec[cols_index]
367+
vdx = @view dx[rows_index]
368+
tmp = cv .== color_i
369+
nzval .+= tmp .* vdx
370+
end
361371
else
362-
@. setindex!((J,),getindex((J,),rows_index, cols_index) + (getindex((colorvec,),cols_index) == color_i) * getindex((vecdx,),rows_index),rows_index, cols_index)
372+
@. void_setindex!(Ref(J),getindex(Ref(J),rows_index, cols_index) + (getindex(Ref(colorvec),cols_index) == color_i) * getindex(Ref(vecdx),rows_index),rows_index, cols_index)
363373
end
364374
end
365375
color_i += 1

test/test_gpu_ad.jl

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,25 @@
11
using SparseDiffTools, CUDA, Test, LinearAlgebra
22
using ArrayInterface: allowed_getindex, allowed_setindex!
3+
using SparseArrays
4+
35
# In-place test function: fills `dx` from `x` and returns `nothing`.
# Interior entries receive the three-point stencil x[i-1] - 2x[i] + x[i+1];
# the first and last entries are written one element at a time through
# ArrayInterface's `allowed_getindex` / `allowed_setindex!`.
# NOTE(review): the last entry is hard-coded as index 30 (reading 29 and 30),
# so this presumably expects length-30 vectors — confirm against the callers.
# NOTE(review): the `allowed_*` accessors are presumably used so the scalar
# reads/writes still work once `CUDA.allowscalar(false)` is set — confirm.
function f(dx,x)
46
# Interior entries, written with a single sliced assignment.
dx[2:end-1] = x[1:end-2] - 2x[2:end-1] + x[3:end]
57
# First entry: -2x[1] + x[2].
allowed_setindex!(dx,-2allowed_getindex(x,1) + allowed_getindex(x,2),1)
68
# Last entry: -2x[30] + x[29].
allowed_setindex!(dx,-2allowed_getindex(x,30) + allowed_getindex(x,29),30)
79
# Results are delivered through `dx`; nothing is returned.
nothing
810
end
9-
11+
x = rand(4)
1012
_J1 = similar(rand(30,30))
1113
_denseJ1 = cu(collect(_J1))
1214
x = cu(rand(30))
1315
CUDA.allowscalar(false)
14-
forwarddiff_color_jacobian!(_denseJ1, f, x)
15-
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, sparsity = _J1) isa Nothing
16-
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = _J1) isa Nothing
16+
_J2 = sparse(forwarddiff_color_jacobian!(_denseJ1, f, x))
17+
out = copy(_J2)
18+
forwarddiff_color_jacobian!(out, f, x, colorvec = repeat(1:3,10), sparsity = _J2)
19+
20+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, sparsity = cu(_J1)) isa Nothing
21+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = cu(_J1)) isa Nothing
1722
_Jt = similar(Tridiagonal(_J1))
1823
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = _Jt) isa Nothing
24+
_Jt2 = similar(Tridiagonal(cu(_J1)))
25+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, colorvec = repeat(1:3,10), sparsity = _Jt2) isa Nothing

0 commit comments

Comments
 (0)