
Commit 1de99b3

Author: Katharine Hyatt (committed)

Sparse GPU array and broadcasting support

1 parent 2e884f7 · commit 1de99b3

10 files changed, +1087 -2 lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [weakdeps]
@@ -33,5 +34,6 @@ Random = "1"
 Reexport = "1"
 ScopedValues = "1"
 Serialization = "1"
+SparseArrays = "1"
 Statistics = "1"
 julia = "1.10"

lib/JLArrays/Project.toml

Lines changed: 4 additions & 0 deletions
@@ -7,11 +7,15 @@ version = "0.3.0"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [compat]
 Adapt = "2.0, 3.0, 4.0"
 GPUArrays = "11.1"
 KernelAbstractions = "0.9, 0.10"
+LinearAlgebra = "1"
 Random = "1"
+SparseArrays = "1"
 julia = "1.8"

lib/JLArrays/src/JLArrays.jl

Lines changed: 130 additions & 1 deletion
@@ -6,11 +6,14 @@
 
 module JLArrays
 
-export JLArray, JLVector, JLMatrix, jl, JLBackend
+export JLArray, JLVector, JLMatrix, jl, JLBackend, JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR
 
 using GPUArrays
 
 using Adapt
+using SparseArrays, LinearAlgebra
+
+import GPUArrays: _dense_array_type
 
 import KernelAbstractions
 import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
@@ -115,7 +118,90 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
     end
 end
 
+mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, Ti}
+    iPtr::JLArray{Ti, 1}
+    nzVal::JLArray{Tv, 1}
+    len::Int
+    nnz::Ti
+
+    function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
+                                    len::Integer) where {Tv, Ti <: Integer}
+        new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
+    end
+end
+SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal))
+SparseArrays.nnz(x::JLSparseVector) = x.nnz
+SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
+SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
+
+mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
+    colPtr::JLArray{Ti, 1}
+    rowVal::JLArray{Ti, 1}
+    nzVal::JLArray{Tv, 1}
+    dims::NTuple{2,Int}
+    nnz::Ti
+
+    function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
+                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+        new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
+    end
+end
+function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+    return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
+end
+SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))
+
+JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A
+
+function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
+    r1 = Int(@inbounds A.colPtr[j])
+    r2 = Int(@inbounds A.colPtr[j+1]-1)
+    (r1 > r2) && return zero(Tv)
+    r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
+    ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
+end
+
+mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
+    rowPtr::JLArray{Ti, 1}
+    colVal::JLArray{Ti, 1}
+    nzVal::JLArray{Tv, 1}
+    dims::NTuple{2,Int}
+    nnz::Ti
+
+    function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
+                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
+        new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
+    end
+end
+function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+    return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
+end
+function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
+    x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
+    return SparseMatrixCSC(transpose(x_transpose))
+end
+
+JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A
+
 GPUArrays.storage(a::JLArray) = a.data
+GPUArrays._dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N}
+GPUArrays._dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N}
+GPUArrays._dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1}
+GPUArrays._dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1}
+
+GPUArrays._sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC
+GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC
+GPUArrays._sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR
+GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
+GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector
+GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
+
+GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
 
 # conversion of untyped data to a typed Array
 function typed_data(x::JLArray{T}) where {T}
@@ -217,6 +303,41 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
 (::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x)
 JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
 
+function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
+    iPtr = JLVector{Ti}(undef, length(xs.nzind))
+    nzVal = JLVector{Tv}(undef, length(xs.nzval))
+    copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
+    copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
+    return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
+end
+Base.length(x::JLSparseVector) = x.len
+Base.size(x::JLSparseVector) = (x.len,)
+
+function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
+    colPtr = JLVector{Ti}(undef, length(xs.colptr))
+    rowVal = JLVector{Ti}(undef, length(xs.rowval))
+    nzVal = JLVector{Tv}(undef, length(xs.nzval))
+    copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
+    copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
+    copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
+    return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
+end
+Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
+Base.size(x::JLSparseMatrixCSC) = x.dims
+
+function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
+    csr_xs = SparseMatrixCSC(transpose(xs))
+    rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
+    colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
+    nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
+    copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
+    copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
+    copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
+    return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
+end
+Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
+Base.size(x::JLSparseMatrixCSR) = x.dims
+
 # idempotency
 JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs
 
@@ -358,9 +479,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
     R
 end
 
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
+    GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
+    GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
+    GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
+
 ## KernelAbstractions interface
 
 KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()
 
 function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
     return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
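
Usage sketch (not part of the commit): the host-side constructors and conversions added above can be exercised as follows, assuming this revision of JLArrays is loaded. The comparisons only illustrate the expected round-trips between SparseArrays types and their JL counterparts.

    using SparseArrays
    using JLArrays   # reference (CPU-backed) GPUArrays backend

    A  = sprand(Float64, 4, 4, 0.5)   # host CSC matrix
    dA = JLSparseMatrixCSC(A)         # copy colptr/rowval/nzval into JLArray buffers
    size(dA) == (4, 4)
    SparseMatrixCSC(dA) == A          # convert back to a host SparseMatrixCSC

    dB = JLSparseMatrixCSR(A)         # CSR storage, built from the transposed CSC
    SparseMatrixCSC(dB) == A          # also converts back through SparseMatrixCSC

    x  = sprand(Float64, 10, 0.3)     # host sparse vector
    dx = JLSparseVector(x)
    nnz(dx) == nnz(x)
    SparseVector(dx) == x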

src/GPUArrays.jl

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ using KernelAbstractions
 
 # device functionality
 include("device/abstractarray.jl")
+include("device/sparse.jl")
 
 # host abstractions
 include("host/abstractarray.jl")
@@ -34,6 +35,7 @@ include("host/random.jl")
 include("host/quirks.jl")
 include("host/uniformscaling.jl")
 include("host/statistics.jl")
+include("host/sparse.jl")
 include("host/alloc_cache.jl")
 
src/device/sparse.jl

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+# on-device sparse array types
+# should be excluded from coverage counts
+# COV_EXCL_START
+using SparseArrays
+
+# NOTE: this functionality is currently very bare-bones, only defining the array types
+# without any device-compatible sparse array functionality
+
+
+# core types
+
+export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR,
+       GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
+
+abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end
+
+
+struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
+    iPtr::Vi
+    nzVal::Vv
+    len::Int
+    nnz::Ti
+end
+
+Base.length(g::GPUSparseDeviceVector) = g.len
+Base.size(g::GPUSparseDeviceVector) = (g.len,)
+SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz
+SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr
+SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal
+
+struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
+    colPtr::Vi
+    rowVal::Vi
+    nzVal::Vv
+    dims::NTuple{2,Int}
+    nnz::Ti
+end
+
+SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal
+SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr
+SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
+
+struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+    rowPtr::Vi
+    colVal::Vi
+    nzVal::Vv
+    dims::NTuple{2, Int}
+    nnz::Ti
+end
+
+struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+    rowPtr::Vi
+    colVal::Vi
+    nzVal::Vv
+    dims::NTuple{2,Int}
+    blockDim::Ti
+    dir::Char
+    nnz::Ti
+end
+
+struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+    rowInd::Vi
+    colInd::Vi
+    nzVal::Vv
+    dims::NTuple{2,Int}
+    nnz::Ti
+end
+
+Base.length(g::AbstractGPUSparseDeviceMatrix) = prod(g.dims)
+Base.size(g::AbstractGPUSparseDeviceMatrix) = g.dims
+SparseArrays.nnz(g::AbstractGPUSparseDeviceMatrix) = g.nnz
+SparseArrays.getnzval(g::AbstractGPUSparseDeviceMatrix) = g.nzVal
+
+struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M} <: AbstractSparseArray{Tv, Ti, N}
+    rowPtr::Vi
+    colVal::Vi
+    nzVal::Vv
+    dims::NTuple{N, Int}
+    nnz::Ti
+end
+
+function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N}
+    @assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
+    GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
+end
+
+Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims)
+Base.size(g::GPUSparseDeviceArrayCSR) = g.dims
+SparseArrays.nnz(g::GPUSparseDeviceArrayCSR) = g.nnz
+SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal
+
+# input/output
+
+function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector)
+    println(io, "$(length(A))-element device sparse vector at:")
+    println(io, " iPtr: $(A.iPtr)")
+    print(io, " nzVal: $(A.nzVal)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR)
+    println(io, "$(length(A))-element device sparse matrix CSR at:")
+    println(io, " rowPtr: $(A.rowPtr)")
+    println(io, " colVal: $(A.colVal)")
+    print(io, " nzVal: $(A.nzVal)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC)
+    println(io, "$(length(A))-element device sparse matrix CSC at:")
+    println(io, " colPtr: $(A.colPtr)")
+    println(io, " rowVal: $(A.rowVal)")
+    print(io, " nzVal: $(A.nzVal)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR)
+    println(io, "$(length(A))-element device sparse matrix BSR at:")
+    println(io, " rowPtr: $(A.rowPtr)")
+    println(io, " colVal: $(A.colVal)")
+    print(io, " nzVal: $(A.nzVal)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO)
+    println(io, "$(length(A))-element device sparse matrix COO at:")
+    println(io, " rowInd: $(A.rowInd)")
+    println(io, " colInd: $(A.colInd)")
+    print(io, " nzVal: $(A.nzVal)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR)
+    println(io, "$(length(A))-element device sparse array CSR at:")
+    println(io, " rowPtr: $(A.rowPtr)")
+    println(io, " colVal: $(A.colVal)")
+    print(io, " nzVal: $(A.nzVal)")
+end
+
+# COV_EXCL_STOP
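
Illustrative sketch (not part of the commit): as the NOTE above says, this file only defines storage layouts and accessors, not device kernels. A backend kernel could nevertheless traverse a GPUSparseDeviceMatrixCSC column through the accessors defined here; the helper name _column_sum below is hypothetical.

    # Hypothetical helper: sum the stored entries of column j of a device CSC matrix,
    # using only SparseArrays.nzrange / rowvals / getnzval as defined in this file.
    function _column_sum(A::GPUSparseDeviceMatrixCSC, j::Integer)
        nz = SparseArrays.getnzval(A)
        s = zero(eltype(nz))
        for k in SparseArrays.nzrange(A, j)    # colPtr[j]:(colPtr[j+1]-1)
            # i = SparseArrays.rowvals(A)[k]   # row index of the k-th stored entry
            s += nz[k]
        end
        return s
    end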
