Skip to content

Commit d7ab9eb

Browse files
committed
A lot more sparse tests
1 parent ec8910c commit d7ab9eb

File tree

4 files changed

+280
-24
lines changed

4 files changed

+280
-24
lines changed

lib/JLArrays/src/JLArrays.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T
134134
new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
135135
end
136136
end
137-
SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal))
138137
SparseArrays.nnz(x::JLSparseVector) = x.nnz
139138
SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
140139
SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
@@ -159,6 +158,7 @@ SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)...,
159158
JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A
160159

161160
function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
161+
@boundscheck checkbounds(A, i, j)
162162
r1 = Int(@inbounds A.colPtr[j])
163163
r2 = Int(@inbounds A.colPtr[j+1]-1)
164164
(r1 > r2) && return zero(Tv)
@@ -186,9 +186,29 @@ function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
186186
return SparseMatrixCSC(transpose(x_transpose))
187187
end
188188

189+
JLSparseMatrixCSC(Mat::Union{Transpose{Tv, <:SparseMatrixCSC}, Adjoint{Tv, <:SparseMatrixCSC}}) where {Tv} = JLSparseMatrixCSC(JLSparseMatrixCSR(Mat))
190+
191+
function Base.size(g::JLSparseMatrixCSR, d::Integer)
192+
if 1 <= d <= 2
193+
return g.dims[d]
194+
elseif d > 1
195+
return 1
196+
else
197+
throw(ArgumentError("dimension must be ≥ 1, got $d"))
198+
end
199+
end
200+
201+
JLSparseMatrixCSR(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {Tv} =
202+
JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
203+
JLVector(parent(Mat).nzval), size(Mat))
204+
JLSparseMatrixCSR(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {Tv} =
205+
JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
206+
JLVector(conj.(parent(Mat).nzval)), size(Mat))
207+
189208
JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A
190209

191210
function Base.getindex(A::JLSparseMatrixCSR{Tv, Ti}, i0::Integer, i1::Integer) where {Tv, Ti}
211+
@boundscheck checkbounds(A, i0, i1)
192212
c1 = Int(A.rowPtr[i0])
193213
c2 = Int(A.rowPtr[i0+1]-1)
194214
(c1 > c2) && return zero(Tv)
@@ -220,6 +240,25 @@ GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
220240
GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
221241
GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
222242

243+
Base.similar(Mat::JLSparseMatrixCSR) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
244+
Base.similar(Mat::JLSparseMatrixCSR, T::Type) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat), T), size(Mat))
245+
246+
Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
247+
Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
248+
249+
Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
250+
Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
251+
252+
Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
253+
Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
254+
255+
Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
256+
Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)
257+
258+
JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x)))
259+
JLArray(x::JLSparseMatrixCSC) = JLArray(collect(SparseMatrixCSC(x)))
260+
JLArray(x::JLSparseMatrixCSR) = JLArray(collect(SparseMatrixCSC(x)))
261+
223262
# conversion of untyped data to a typed Array
224263
function typed_data(x::JLArray{T}) where {T}
225264
unsafe_wrap(Array, pointer(x), x.dims)
@@ -339,6 +378,7 @@ function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
339378
copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
340379
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
341380
end
381+
JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
342382
Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
343383
Base.size(x::JLSparseMatrixCSC) = x.dims
344384

@@ -352,12 +392,26 @@ function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
352392
copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
353393
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
354394
end
395+
JLSparseMatrixCSR(xs::SparseVector{Tv, Ti}) where {Ti, Tv} = JLSparseMatrixCSR(SparseMatrixCSC(xs))
355396
function JLSparseMatrixCSR(xs::JLSparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
356397
return JLSparseMatrixCSR(SparseMatrixCSC(xs))
357398
end
358399
function JLSparseMatrixCSC(xs::JLSparseMatrixCSR{Tv, Ti}) where {Ti, Tv}
359400
return JLSparseMatrixCSC(SparseMatrixCSC(xs))
360401
end
402+
function Base.copyto!(dst::JLSparseMatrixCSR, src::JLSparseMatrixCSR)
403+
if size(dst) != size(src)
404+
throw(ArgumentError("Inconsistent Sparse Matrix size"))
405+
end
406+
resize!(dst.rowPtr, length(src.rowPtr))
407+
resize!(dst.colVal, length(src.colVal))
408+
resize!(SparseArrays.nonzeros(dst), length(SparseArrays.nonzeros(src)))
409+
copyto!(dst.rowPtr, src.rowPtr)
410+
copyto!(dst.colVal, src.colVal)
411+
copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
412+
dst.nnz = src.nnz
413+
dst
414+
end
361415
Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
362416
Base.size(x::JLSparseMatrixCSR) = x.dims
363417

src/host/sparse.jl

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,36 @@ SparseArrays.nnz(g::T) where {T<:AbstractGPUSparseArray} = g.nnz
1616
SparseArrays.nonzeros(g::T) where {T<:AbstractGPUSparseArray} = g.nzVal
1717

1818
SparseArrays.nonzeroinds(g::T) where {T<:AbstractGPUSparseVector} = g.iPtr
19-
SparseArrays.rowvals(g::T) where {T<:AbstractGPUSparseVector} = nonzeroinds(g)
19+
SparseArrays.rowvals(g::T) where {T<:AbstractGPUSparseVector} = SparseArrays.nonzeroinds(g)
2020

2121
SparseArrays.rowvals(g::AbstractGPUSparseMatrixCSC) = g.rowVal
2222
SparseArrays.getcolptr(S::AbstractGPUSparseMatrixCSC) = S.colPtr
2323

2424
Base.convert(T::Type{<:AbstractGPUSparseArray}, m::AbstractArray) = m isa T ? m : T(m)
2525

26+
# collect to Array
27+
Base.collect(x::AbstractGPUSparseVector) = collect(SparseVector(x))
28+
Base.collect(x::AbstractGPUSparseMatrixCSC) = collect(SparseMatrixCSC(x))
29+
Base.collect(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x))
30+
Base.collect(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x))
31+
Base.collect(x::AbstractGPUSparseMatrixCOO) = collect(SparseMatrixCSC(x))
32+
33+
Base.Array(x::AbstractGPUSparseVector) = collect(SparseVector(x))
34+
Base.Array(x::AbstractGPUSparseMatrixCSC) = collect(SparseMatrixCSC(x))
35+
Base.Array(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x))
36+
Base.Array(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x))
37+
Base.Array(x::AbstractGPUSparseMatrixCOO) = collect(SparseMatrixCSC(x))
38+
39+
SparseArrays.SparseVector(x::AbstractGPUSparseVector) = SparseVector(length(x), Array(SparseArrays.nonzeroinds(x)), Array(SparseArrays.nonzeros(x)))
40+
SparseArrays.SparseMatrixCSC(x::AbstractGPUSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(SparseArrays.getcolptr(x)), Array(SparseArrays.rowvals(x)), Array(SparseArrays.nonzeros(x)))
41+
42+
# similar
43+
Base.similar(Vec::V) where {V<:AbstractGPUSparseVector} = V(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec)), length(Vec))
44+
Base.similar(Mat::M) where {M<:AbstractGPUSparseMatrixCSC} = M(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat)), size(Mat))
45+
46+
Base.similar(Vec::V, T::Type) where {Tv, Ti, V<:AbstractGPUSparseVector{Tv, Ti}} = sparse_array_type(V){T, Ti}(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec), T), length(Vec))
47+
Base.similar(Mat::M, T::Type) where {M<:AbstractGPUSparseMatrixCSC} = sparse_array_type(M)(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat), T), size(Mat))
48+
2649
dense_array_type(sa::SparseVector) = SparseVector
2750
dense_array_type(::Type{SparseVector}) = SparseVector
2851
sparse_array_type(sa::SparseVector) = SparseVector
@@ -207,6 +230,52 @@ Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon) = getindex(A, i, 1:s
207230
Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, i) = getindex(A, 1:size(A, 1), i)
208231
Base.getindex(A::AbstractGPUSparseMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
209232

233+
function Base.getindex(A::AbstractGPUSparseVector{Tv, Ti}, i::Integer) where {Tv, Ti}
234+
@boundscheck checkbounds(A, i)
235+
ii = searchsortedfirst(SparseArrays.nonzeroinds(A), convert(Ti, i))
236+
(ii > SparseArrays.nnz(A) || SparseArrays.nonzeroinds(A)[ii] != i) && return zero(Tv)
237+
SparseArrays.nonzeros(A)[ii]
238+
end
239+
240+
function Base.getindex(A::AbstractGPUSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
241+
@boundscheck checkbounds(A, i0, i1)
242+
r1 = Int(SparseArrays.getcolptr(A)[i1])
243+
r2 = Int(SparseArrays.getcolptr(A)[i1+1]-1)
244+
(r1 > r2) && return zero(T)
245+
r1 = searchsortedfirst(SparseArrays.rowvals(A), i0, r1, r2, Base.Order.Forward)
246+
(r1 > r2 || SparseArrays.rowvals(A)[r1] != i0) && return zero(T)
247+
SparseArrays.nonzeros(A)[r1]
248+
end
249+
250+
## copying between sparse GPU arrays
251+
Base.copy(Vec::AbstractGPUSparseVector) = copyto!(similar(Vec), Vec)
252+
253+
function Base.copyto!(dst::AbstractGPUSparseVector, src::AbstractGPUSparseVector)
254+
if length(dst) != length(src)
255+
throw(ArgumentError("Inconsistent Sparse Vector size"))
256+
end
257+
resize!(SparseArrays.nonzeroinds(dst), length(SparseArrays.nonzeroinds(src)))
258+
resize!(SparseArrays.nonzeros(dst), length(SparseArrays.nonzeros(src)))
259+
copyto!(SparseArrays.nonzeroinds(dst), SparseArrays.nonzeroinds(src))
260+
copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
261+
dst.nnz = src.nnz
262+
dst
263+
end
264+
265+
function Base.copyto!(dst::AbstractGPUSparseMatrixCSC, src::AbstractGPUSparseMatrixCSC)
266+
if size(dst) != size(src)
267+
throw(ArgumentError("Inconsistent Sparse Matrix size"))
268+
end
269+
resize!(SparseArrays.getcolptr(dst), length(SparseArrays.getcolptr(src)))
270+
resize!(SparseArrays.rowvals(dst), length(SparseArrays.rowvals(src)))
271+
resize!(SparseArrays.nonzeros(dst), length(SparseArrays.nonzeros(src)))
272+
copyto!(SparseArrays.getcolptr(dst), SparseArrays.getcolptr(src))
273+
copyto!(SparseArrays.rowvals(dst), SparseArrays.rowvals(src))
274+
copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
275+
dst.nnz = src.nnz
276+
dst
277+
end
278+
210279
### BROADCAST
211280

212281
# broadcast container type promotion for combinations of sparse arrays and other types
@@ -749,12 +818,12 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
749818
offsets = rowPtr = sparse_arg.rowPtr
750819
colVal = similar(sparse_arg.colVal)
751820
nzVal = similar(sparse_arg.nzVal, Tv)
752-
output = _sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
821+
output = sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
753822
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
754823
offsets = colPtr = sparse_arg.colPtr
755824
rowVal = similar(sparse_arg.rowVal)
756825
nzVal = similar(sparse_arg.nzVal, Tv)
757-
output = _sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
826+
output = sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
758827
end
759828
else
760829
# determine the number of non-zero elements per row so that we can create an
@@ -803,15 +872,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
803872
output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC}
804873
ixVal = similar(offsets, Ti, total_nnz)
805874
nzVal = similar(offsets, Tv, total_nnz)
806-
output_sparse_typ = _sparse_array_type(sparse_typ)
875+
output_sparse_typ = sparse_array_type(sparse_typ)
807876
output_sparse_typ(offsets, ixVal, nzVal, size(bc))
808877
elseif sparse_typ <: AbstractGPUSparseVector && !fpreszeros
809878
val_array = bc.args[first(sparse_args)].nzVal
810879
similar(val_array, Tv, size(bc))
811880
elseif sparse_typ <: AbstractGPUSparseVector && fpreszeros
812881
iPtr = similar(offsets, Ti, total_nnz)
813882
nzVal = similar(offsets, Tv, total_nnz)
814-
_sparse_array_type(sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
883+
sparse_array_type(sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
815884
end
816885
if sparse_typ <: AbstractGPUSparseVector && !fpreszeros
817886
nonsparse_args = map(bc.args) do arg
@@ -932,9 +1001,9 @@ function Base.mapreduce(f, op, A::AbstractGPUSparseMatrix; dims=:, init=nothing)
9321001
in(dims, [Colon(), 1, 2]) || error("only dims=:, dims=1 or dims=2 is supported")
9331002

9341003
if A isa AbstractGPUSparseMatrixCSR && dims == 1
935-
A = _csc_type(A)(A)
1004+
A = csc_type(A)(A)
9361005
elseif A isa AbstractGPUSparseMatrixCSC && dims == 2
937-
A = _csr_type(A)(A)
1006+
A = csr_type(A)(A)
9381007
end
9391008
m, n = size(A)
9401009
val_array = nonzeros(A)

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ include("testsuite.jl")
66
const init_code = quote
77
using Test, JLArrays, SparseArrays
88

9-
sparse_types(::Type{<:JLArray}) = (JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR)
10-
sparse_types(::Type{<:Array}) = (SparseVector, SparseMatrixCSC)
11-
129
include("testsuite.jl")
1310

11+
TestSuite.sparse_types(::Type{<:JLArray}) = (JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR)
12+
TestSuite.sparse_types(::Type{<:Array}) = (SparseVector, SparseMatrixCSC)
13+
1414
# Disable Float16-related tests until JuliaGPU/KernelAbstractions#600 is resolved
1515
if isdefined(JLArrays.KernelAbstractions, :POCL)
1616
TestSuite.supported_eltypes(::Type{<:JLArray}) =

0 commit comments

Comments
 (0)