Skip to content

Commit 7f11083

Browse files
authored
Another mul specialization for diagonals (#627)
* Another mul specialization for diagonals * Fix dimension check
1 parent 8272e39 commit 7f11083

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/host/linalg.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,21 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
258258
return C
259259
end
260260

261+
function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
262+
A::AbstractGPUArray,
263+
B::AbstractGPUArray)
264+
dc = C.diag
265+
d = length(dc)
266+
m, n = size(A, 1), size(A, 2)
267+
m′, n′ = size(B, 1), size(B, 2)
268+
m == d || throw(DimensionMismatch("left hand side has $m rows but output is $d by $d"))
269+
n′ == d || throw(DimensionMismatch("right hand side has $n′ cols but output is $d by $d"))
270+
C_ = A * B
271+
isdiag(C_) || throw(ErrorException("output matrix must be diagonal"))
272+
dc .= diag(C_)
273+
return C
274+
end
275+
261276
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
262277
D::Diagonal{<:Any, <:AbstractGPUArray},
263278
A::AbstractGPUVecOrMat)

test/testsuite/linalg.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@
250250
A = Diagonal(a)
251251
mul!(C, A, B)
252252
@test collect(C.diag) collect(A.diag) .* collect(B.diag)
253+
a = AT(diagm(rand(elty, n)))
254+
b = AT(diagm(rand(elty, n)))
255+
C = Diagonal(d)
256+
mul!(C, a, b)
257+
@test collect(C) Diagonal(collect(a) * collect(b))
253258
end
254259
end
255260

0 commit comments

Comments
 (0)