@@ -16,13 +16,36 @@ SparseArrays.nnz(g::T) where {T<:AbstractGPUSparseArray} = g.nnz
1616SparseArrays. nonzeros (g:: T ) where  {T<: AbstractGPUSparseArray } =  g. nzVal
1717
1818SparseArrays. nonzeroinds (g:: T ) where  {T<: AbstractGPUSparseVector } =  g. iPtr
19- SparseArrays. rowvals (g:: T ) where  {T<: AbstractGPUSparseVector } =  nonzeroinds (g)
19+ SparseArrays. rowvals (g:: T ) where  {T<: AbstractGPUSparseVector } =  SparseArrays . nonzeroinds (g)
2020
2121SparseArrays. rowvals (g:: AbstractGPUSparseMatrixCSC ) =  g. rowVal
2222SparseArrays. getcolptr (S:: AbstractGPUSparseMatrixCSC ) =  S. colPtr
2323
2424Base. convert (T:: Type{<:AbstractGPUSparseArray} , m:: AbstractArray ) =  m isa  T ?  m :  T (m)
2525
26+ #  collect to Array
27+ Base. collect (x:: AbstractGPUSparseVector ) =  collect (SparseVector (x))
28+ Base. collect (x:: AbstractGPUSparseMatrixCSC ) =  collect (SparseMatrixCSC (x))
29+ Base. collect (x:: AbstractGPUSparseMatrixCSR ) =  collect (SparseMatrixCSC (x))
30+ Base. collect (x:: AbstractGPUSparseMatrixBSR ) =  collect (SparseMatrixCSC (x))
31+ Base. collect (x:: AbstractGPUSparseMatrixCOO ) =  collect (SparseMatrixCSC (x))
32+ 
33+ Base. Array (x:: AbstractGPUSparseVector )    =  collect (SparseVector (x))
34+ Base. Array (x:: AbstractGPUSparseMatrixCSC ) =  collect (SparseMatrixCSC (x))
35+ Base. Array (x:: AbstractGPUSparseMatrixCSR ) =  collect (SparseMatrixCSC (x))
36+ Base. Array (x:: AbstractGPUSparseMatrixBSR ) =  collect (SparseMatrixCSC (x))
37+ Base. Array (x:: AbstractGPUSparseMatrixCOO ) =  collect (SparseMatrixCSC (x))
38+ 
39+ SparseArrays. SparseVector (x:: AbstractGPUSparseVector ) =  SparseVector (length (x), Array (SparseArrays. nonzeroinds (x)), Array (SparseArrays. nonzeros (x)))
40+ SparseArrays. SparseMatrixCSC (x:: AbstractGPUSparseMatrixCSC ) =  SparseMatrixCSC (size (x)... , Array (SparseArrays. getcolptr (x)), Array (SparseArrays. rowvals (x)), Array (SparseArrays. nonzeros (x)))
41+ 
42+ #  similar
43+ Base. similar (Vec:: V ) where  {V<: AbstractGPUSparseVector } =  V (copy (SparseArrays. nonzeroinds (Vec)), similar (SparseArrays. nonzeros (Vec)), length (Vec))
44+ Base. similar (Mat:: M ) where  {M<: AbstractGPUSparseMatrixCSC } =  M (copy (SparseArrays. getcolptr (Mat)), copy (SparseArrays. rowvals (Mat)), similar (SparseArrays. nonzeros (Mat)), size (Mat))
45+ 
46+ Base. similar (Vec:: V , T:: Type ) where  {Tv, Ti, V<: AbstractGPUSparseVector{Tv, Ti} } =  sparse_array_type (V){T, Ti}(copy (SparseArrays. nonzeroinds (Vec)), similar (SparseArrays. nonzeros (Vec), T), length (Vec))
47+ Base. similar (Mat:: M , T:: Type ) where  {M<: AbstractGPUSparseMatrixCSC } =  sparse_array_type (M)(copy (SparseArrays. getcolptr (Mat)), copy (SparseArrays. rowvals (Mat)), similar (SparseArrays. nonzeros (Mat), T), size (Mat))
48+ 
2649dense_array_type (sa:: SparseVector )     =  SparseVector
2750dense_array_type (:: Type{SparseVector} ) =  SparseVector
2851sparse_array_type (sa:: SparseVector ) =  SparseVector
@@ -207,6 +230,52 @@ Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon)       = getindex(A, i, 1:s
207230Base. getindex (A:: AbstractGPUSparseMatrix , :: Colon , i)       =  getindex (A, 1 : size (A, 1 ), i)
208231Base. getindex (A:: AbstractGPUSparseMatrix , I:: Tuple{Integer,Integer} ) =  getindex (A, I[1 ], I[2 ])
209232
233+ function  Base. getindex (A:: AbstractGPUSparseVector{Tv, Ti} , i:: Integer ) where  {Tv, Ti}
234+     @boundscheck  checkbounds (A, i)
235+     ii =  searchsortedfirst (SparseArrays. nonzeroinds (A), convert (Ti, i))
236+     (ii >  SparseArrays. nnz (A) ||  SparseArrays. nonzeroinds (A)[ii] !=  i) &&  return  zero (Tv)
237+     SparseArrays. nonzeros (A)[ii]
238+ end 
239+ 
240+ function  Base. getindex (A:: AbstractGPUSparseMatrixCSC{T} , i0:: Integer , i1:: Integer ) where  T
241+     @boundscheck  checkbounds (A, i0, i1)
242+     r1 =  Int (SparseArrays. getcolptr (A)[i1])
243+     r2 =  Int (SparseArrays. getcolptr (A)[i1+ 1 ]- 1 )
244+     (r1 >  r2) &&  return  zero (T)
245+     r1 =  searchsortedfirst (SparseArrays. rowvals (A), i0, r1, r2, Base. Order. Forward)
246+     (r1 >  r2 ||  SparseArrays. rowvals (A)[r1] !=  i0) &&  return  zero (T)
247+     SparseArrays. nonzeros (A)[r1]
248+ end 
249+ 
250+ # # copying between sparse GPU arrays
251+ Base. copy (Vec:: AbstractGPUSparseVector ) =  copyto! (similar (Vec), Vec)
252+ 
253+ function  Base. copyto! (dst:: AbstractGPUSparseVector , src:: AbstractGPUSparseVector )
254+     if  length (dst) !=  length (src)
255+         throw (ArgumentError (" Inconsistent Sparse Vector size" 
256+     end 
257+     resize! (SparseArrays. nonzeroinds (dst), length (SparseArrays. nonzeroinds (src)))
258+     resize! (SparseArrays. nonzeros (dst), length (SparseArrays. nonzeros (src)))
259+     copyto! (SparseArrays. nonzeroinds (dst), SparseArrays. nonzeroinds (src))
260+     copyto! (SparseArrays. nonzeros (dst), SparseArrays. nonzeros (src))
261+     dst. nnz =  src. nnz
262+     dst
263+ end 
264+ 
265+ function  Base. copyto! (dst:: AbstractGPUSparseMatrixCSC , src:: AbstractGPUSparseMatrixCSC )
266+     if  size (dst) !=  size (src)
267+         throw (ArgumentError (" Inconsistent Sparse Matrix size" 
268+     end 
269+     resize! (SparseArrays. getcolptr (dst), length (SparseArrays. getcolptr (src)))
270+     resize! (SparseArrays. rowvals (dst), length (SparseArrays. rowvals (src)))
271+     resize! (SparseArrays. nonzeros (dst), length (SparseArrays. nonzeros (src)))
272+     copyto! (SparseArrays. getcolptr (dst), SparseArrays. getcolptr (src))
273+     copyto! (SparseArrays. rowvals (dst), SparseArrays. rowvals (src))
274+     copyto! (SparseArrays. nonzeros (dst), SparseArrays. nonzeros (src))
275+     dst. nnz =  src. nnz
276+     dst
277+ end 
278+ 
210279# ## BROADCAST
211280
212281#  broadcast container type promotion for combinations of sparse arrays and other types
@@ -749,12 +818,12 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
749818            offsets =  rowPtr =  sparse_arg. rowPtr
750819            colVal  =  similar (sparse_arg. colVal)
751820            nzVal   =  similar (sparse_arg. nzVal, Tv)
752-             output  =  _sparse_array_type (sparse_typ)(rowPtr, colVal, nzVal, size (bc))
821+             output  =  sparse_array_type (sparse_typ)(rowPtr, colVal, nzVal, size (bc))
753822        elseif  sparse_typ <:  AbstractGPUSparseMatrixCSC 
754823            offsets =  colPtr =  sparse_arg. colPtr
755824            rowVal  =  similar (sparse_arg. rowVal)
756825            nzVal   =  similar (sparse_arg. nzVal, Tv)
757-             output  =  _sparse_array_type (sparse_typ)(colPtr, rowVal, nzVal, size (bc))
826+             output  =  sparse_array_type (sparse_typ)(colPtr, rowVal, nzVal, size (bc))
758827        end 
759828    else 
760829        #  determine the number of non-zero elements per row so that we can create an
@@ -803,15 +872,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
803872        output =  if  sparse_typ <:  Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC} 
804873            ixVal =  similar (offsets, Ti, total_nnz)
805874            nzVal =  similar (offsets, Tv, total_nnz)
806-             output_sparse_typ =  _sparse_array_type (sparse_typ) 
875+             output_sparse_typ =  sparse_array_type (sparse_typ) 
807876            output_sparse_typ (offsets, ixVal, nzVal, size (bc))
808877        elseif  sparse_typ <:  AbstractGPUSparseVector  &&  ! fpreszeros
809878            val_array =  bc. args[first (sparse_args)]. nzVal
810879            similar (val_array, Tv, size (bc))
811880        elseif  sparse_typ <:  AbstractGPUSparseVector  &&  fpreszeros
812881            iPtr   =  similar (offsets, Ti, total_nnz)
813882            nzVal  =  similar (offsets, Tv, total_nnz)
814-             _sparse_array_type (sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
883+             sparse_array_type (sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
815884        end 
816885        if  sparse_typ <:  AbstractGPUSparseVector  &&  ! fpreszeros
817886            nonsparse_args =  map (bc. args) do  arg
@@ -932,9 +1001,9 @@ function Base.mapreduce(f, op, A::AbstractGPUSparseMatrix; dims=:, init=nothing)
9321001    in (dims, [Colon (), 1 , 2 ]) ||  error (" only dims=:, dims=1 or dims=2 is supported" 
9331002
9341003    if  A isa  AbstractGPUSparseMatrixCSR &&  dims ==  1 
935-         A =  _csc_type (A)(A)
1004+         A =  csc_type (A)(A)
9361005    elseif  A isa  AbstractGPUSparseMatrixCSC &&  dims ==  2 
937-         A =  _csr_type (A)(A)
1006+         A =  csr_type (A)(A)
9381007    end 
9391008    m, n      =  size (A)
9401009    val_array =  nonzeros (A)
0 commit comments