Skip to content

Commit 8d3ee63

Browse files
committed
fix: tril/triu working again
1 parent f29548a commit 8d3ee63

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,26 +273,56 @@ function overloaded_mul!(
273273
return C
274274
end
275275

276-
function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
276+
if isdefined(LinearAlgebra, :_triu)
277+
function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
278+
return overloaded_triu(materialize_traced_array(A), k)
279+
end
280+
function LinearAlgebra._triu(
281+
A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer
282+
) where {T}
283+
return overloaded_triu(materialize_traced_array(A), k)
284+
end
285+
end
286+
287+
if isdefined(LinearAlgebra, :_tril)
288+
function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
289+
return overloaded_tril(materialize_traced_array(A), k)
290+
end
291+
function LinearAlgebra._tril(
292+
A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer
293+
) where {T}
294+
return overloaded_tril(materialize_traced_array(A), k)
295+
end
296+
end
297+
298+
function LinearAlgebra.triu!(X::AnyTracedRArray{T,2}, k::Integer) where {T}
299+
set_mlir_data!(X, overloaded_triu(materialize_traced_array(X), k))
300+
return X
301+
end
302+
303+
function LinearAlgebra.tril!(X::AnyTracedRArray{T,2}, k::Integer) where {T}
304+
set_mlir_data!(X, overloaded_tril(materialize_traced_array(X), k))
305+
return X
306+
end
307+
308+
function overloaded_triu(X::TracedRArray{T,2}, k::Integer) where {T}
277309
iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
278310
iota_2 = @opcall subtract(
279311
@opcall(iota(Int64, [size(X)...]; iota_dimension=2)),
280312
Reactant.broadcast_to_size(k, size(X)),
281313
)
282314
idxs = @opcall compare(iota_1, iota_2; comparison_direction="LE")
283-
X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data
284-
return X
315+
return @opcall select(idxs, X, zero(X))
285316
end
286317

287-
function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
318+
function overloaded_tril(X::TracedRArray{T,2}, k::Integer) where {T}
288319
iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
289320
iota_2 = @opcall subtract(
290321
@opcall(iota(Int64, [size(X)...]; iota_dimension=2)),
291322
Reactant.broadcast_to_size(k, size(X)),
292323
)
293324
idxs = @opcall compare(iota_1, iota_2; comparison_direction="GE")
294-
X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data
295-
return X
325+
return @opcall select(idxs, X, zero(X))
296326
end
297327

298328
# LinearAlgebra defines norm with some conditionals which cannot be traced directly

0 commit comments

Comments
 (0)