@@ -16,13 +16,36 @@ SparseArrays.nnz(g::T) where {T<:AbstractGPUSparseArray} = g.nnz
1616SparseArrays. nonzeros (g:: T ) where {T<: AbstractGPUSparseArray } = g. nzVal
1717
1818SparseArrays. nonzeroinds (g:: T ) where {T<: AbstractGPUSparseVector } = g. iPtr
19- SparseArrays. rowvals (g:: T ) where {T<: AbstractGPUSparseVector } = nonzeroinds (g)
19+ SparseArrays. rowvals (g:: T ) where {T<: AbstractGPUSparseVector } = SparseArrays . nonzeroinds (g)
2020
2121SparseArrays. rowvals (g:: AbstractGPUSparseMatrixCSC ) = g. rowVal
2222SparseArrays. getcolptr (S:: AbstractGPUSparseMatrixCSC ) = S. colPtr
2323
2424Base. convert (T:: Type{<:AbstractGPUSparseArray} , m:: AbstractArray ) = m isa T ? m : T (m)
2525
26+ # collect to Array
27+ Base. collect (x:: AbstractGPUSparseVector ) = collect (SparseVector (x))
28+ Base. collect (x:: AbstractGPUSparseMatrixCSC ) = collect (SparseMatrixCSC (x))
29+ Base. collect (x:: AbstractGPUSparseMatrixCSR ) = collect (SparseMatrixCSC (x))
30+ Base. collect (x:: AbstractGPUSparseMatrixBSR ) = collect (SparseMatrixCSC (x))
31+ Base. collect (x:: AbstractGPUSparseMatrixCOO ) = collect (SparseMatrixCSC (x))
32+
33+ Base. Array (x:: AbstractGPUSparseVector ) = collect (SparseVector (x))
34+ Base. Array (x:: AbstractGPUSparseMatrixCSC ) = collect (SparseMatrixCSC (x))
35+ Base. Array (x:: AbstractGPUSparseMatrixCSR ) = collect (SparseMatrixCSC (x))
36+ Base. Array (x:: AbstractGPUSparseMatrixBSR ) = collect (SparseMatrixCSC (x))
37+ Base. Array (x:: AbstractGPUSparseMatrixCOO ) = collect (SparseMatrixCSC (x))
38+
39+ SparseArrays. SparseVector (x:: AbstractGPUSparseVector ) = SparseVector (length (x), Array (SparseArrays. nonzeroinds (x)), Array (SparseArrays. nonzeros (x)))
40+ SparseArrays. SparseMatrixCSC (x:: AbstractGPUSparseMatrixCSC ) = SparseMatrixCSC (size (x)... , Array (SparseArrays. getcolptr (x)), Array (SparseArrays. rowvals (x)), Array (SparseArrays. nonzeros (x)))
41+
42+ # similar
43+ Base. similar (Vec:: V ) where {V<: AbstractGPUSparseVector } = V (copy (SparseArrays. nonzeroinds (Vec)), similar (SparseArrays. nonzeros (Vec)), length (Vec))
44+ Base. similar (Mat:: M ) where {M<: AbstractGPUSparseMatrixCSC } = M (copy (SparseArrays. getcolptr (Mat)), copy (SparseArrays. rowvals (Mat)), similar (SparseArrays. nonzeros (Mat)), size (Mat))
45+
46+ Base. similar (Vec:: V , T:: Type ) where {Tv, Ti, V<: AbstractGPUSparseVector{Tv, Ti} } = sparse_array_type (V){T, Ti}(copy (SparseArrays. nonzeroinds (Vec)), similar (SparseArrays. nonzeros (Vec), T), length (Vec))
47+ Base. similar (Mat:: M , T:: Type ) where {M<: AbstractGPUSparseMatrixCSC } = sparse_array_type (M)(copy (SparseArrays. getcolptr (Mat)), copy (SparseArrays. rowvals (Mat)), similar (SparseArrays. nonzeros (Mat), T), size (Mat))
48+
2649dense_array_type (sa:: SparseVector ) = SparseVector
2750dense_array_type (:: Type{SparseVector} ) = SparseVector
2851sparse_array_type (sa:: SparseVector ) = SparseVector
@@ -207,6 +230,52 @@ Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon) = getindex(A, i, 1:s
207230Base. getindex (A:: AbstractGPUSparseMatrix , :: Colon , i) = getindex (A, 1 : size (A, 1 ), i)
208231Base. getindex (A:: AbstractGPUSparseMatrix , I:: Tuple{Integer,Integer} ) = getindex (A, I[1 ], I[2 ])
209232
233+ function Base. getindex (A:: AbstractGPUSparseVector{Tv, Ti} , i:: Integer ) where {Tv, Ti}
234+ @boundscheck checkbounds (A, i)
235+ ii = searchsortedfirst (SparseArrays. nonzeroinds (A), convert (Ti, i))
236+ (ii > SparseArrays. nnz (A) || SparseArrays. nonzeroinds (A)[ii] != i) && return zero (Tv)
237+ SparseArrays. nonzeros (A)[ii]
238+ end
239+
240+ function Base. getindex (A:: AbstractGPUSparseMatrixCSC{T} , i0:: Integer , i1:: Integer ) where T
241+ @boundscheck checkbounds (A, i0, i1)
242+ r1 = Int (SparseArrays. getcolptr (A)[i1])
243+ r2 = Int (SparseArrays. getcolptr (A)[i1+ 1 ]- 1 )
244+ (r1 > r2) && return zero (T)
245+ r1 = searchsortedfirst (SparseArrays. rowvals (A), i0, r1, r2, Base. Order. Forward)
246+ (r1 > r2 || SparseArrays. rowvals (A)[r1] != i0) && return zero (T)
247+ SparseArrays. nonzeros (A)[r1]
248+ end
249+
250+ # # copying between sparse GPU arrays
251+ Base. copy (Vec:: AbstractGPUSparseVector ) = copyto! (similar (Vec), Vec)
252+
253+ function Base. copyto! (dst:: AbstractGPUSparseVector , src:: AbstractGPUSparseVector )
254+ if length (dst) != length (src)
255+ throw (ArgumentError (" Inconsistent Sparse Vector size" ))
256+ end
257+ resize! (SparseArrays. nonzeroinds (dst), length (SparseArrays. nonzeroinds (src)))
258+ resize! (SparseArrays. nonzeros (dst), length (SparseArrays. nonzeros (src)))
259+ copyto! (SparseArrays. nonzeroinds (dst), SparseArrays. nonzeroinds (src))
260+ copyto! (SparseArrays. nonzeros (dst), SparseArrays. nonzeros (src))
261+ dst. nnz = src. nnz
262+ dst
263+ end
264+
265+ function Base. copyto! (dst:: AbstractGPUSparseMatrixCSC , src:: AbstractGPUSparseMatrixCSC )
266+ if size (dst) != size (src)
267+ throw (ArgumentError (" Inconsistent Sparse Matrix size" ))
268+ end
269+ resize! (SparseArrays. getcolptr (dst), length (SparseArrays. getcolptr (src)))
270+ resize! (SparseArrays. rowvals (dst), length (SparseArrays. rowvals (src)))
271+ resize! (SparseArrays. nonzeros (dst), length (SparseArrays. nonzeros (src)))
272+ copyto! (SparseArrays. getcolptr (dst), SparseArrays. getcolptr (src))
273+ copyto! (SparseArrays. rowvals (dst), SparseArrays. rowvals (src))
274+ copyto! (SparseArrays. nonzeros (dst), SparseArrays. nonzeros (src))
275+ dst. nnz = src. nnz
276+ dst
277+ end
278+
210279# ## BROADCAST
211280
212281# broadcast container type promotion for combinations of sparse arrays and other types
@@ -749,12 +818,12 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
749818 offsets = rowPtr = sparse_arg. rowPtr
750819 colVal = similar (sparse_arg. colVal)
751820 nzVal = similar (sparse_arg. nzVal, Tv)
752- output = _sparse_array_type (sparse_typ)(rowPtr, colVal, nzVal, size (bc))
821+ output = sparse_array_type (sparse_typ)(rowPtr, colVal, nzVal, size (bc))
753822 elseif sparse_typ <: AbstractGPUSparseMatrixCSC
754823 offsets = colPtr = sparse_arg. colPtr
755824 rowVal = similar (sparse_arg. rowVal)
756825 nzVal = similar (sparse_arg. nzVal, Tv)
757- output = _sparse_array_type (sparse_typ)(colPtr, rowVal, nzVal, size (bc))
826+ output = sparse_array_type (sparse_typ)(colPtr, rowVal, nzVal, size (bc))
758827 end
759828 else
760829 # determine the number of non-zero elements per row so that we can create an
@@ -803,15 +872,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
803872 output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC}
804873 ixVal = similar (offsets, Ti, total_nnz)
805874 nzVal = similar (offsets, Tv, total_nnz)
806- output_sparse_typ = _sparse_array_type (sparse_typ)
875+ output_sparse_typ = sparse_array_type (sparse_typ)
807876 output_sparse_typ (offsets, ixVal, nzVal, size (bc))
808877 elseif sparse_typ <: AbstractGPUSparseVector && ! fpreszeros
809878 val_array = bc. args[first (sparse_args)]. nzVal
810879 similar (val_array, Tv, size (bc))
811880 elseif sparse_typ <: AbstractGPUSparseVector && fpreszeros
812881 iPtr = similar (offsets, Ti, total_nnz)
813882 nzVal = similar (offsets, Tv, total_nnz)
814- _sparse_array_type (sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
883+ sparse_array_type (sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
815884 end
816885 if sparse_typ <: AbstractGPUSparseVector && ! fpreszeros
817886 nonsparse_args = map (bc. args) do arg
@@ -932,9 +1001,9 @@ function Base.mapreduce(f, op, A::AbstractGPUSparseMatrix; dims=:, init=nothing)
9321001 in (dims, [Colon (), 1 , 2 ]) || error (" only dims=:, dims=1 or dims=2 is supported" )
9331002
9341003 if A isa AbstractGPUSparseMatrixCSR && dims == 1
935- A = _csc_type (A)(A)
1004+ A = csc_type (A)(A)
9361005 elseif A isa AbstractGPUSparseMatrixCSC && dims == 2
937- A = _csr_type (A)(A)
1006+ A = csr_type (A)(A)
9381007 end
9391008 m, n = size (A)
9401009 val_array = nonzeros (A)
0 commit comments