Skip to content

Commit 80685e9

Browse files
committed
ML/FIML: workaround generic_matmul issue
1 parent 5305ca8 commit 80685e9

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

src/loss/ML/FIML.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,18 @@ end
164164
function ∇F_fiml_outer!(G, JΣ, Jμ, fiml::SemFIML)
165165
implied = fiml.implied
166166

167+
I_A⁻¹ = parent(implied.I_A⁻¹)
168+
F⨉I_A⁻¹ = parent(implied.F * I_A⁻¹)
169+
S = parent(implied.S)
170+
167171
Iₙ = sparse(1.0I, size(implied.A)...)
168-
P = kron(implied.F⨉I_A⁻¹, implied.F⨉I_A⁻¹)
169-
Q = kron(implied.S * implied.I_A⁻¹', Iₙ)
172+
P = kron(F⨉I_A⁻¹, F⨉I_A⁻¹)
173+
Q = kron(S * I_A⁻¹', Iₙ)
170174
Q .+= fiml.commutator * Q
171175

172176
∇Σ = P * (implied.∇S + Q * implied.∇A)
173177

174-
∇μ =
175-
implied.F⨉I_A⁻¹ * implied.∇M +
176-
kron((implied.I_A⁻¹ * implied.M)', implied.F⨉I_A⁻¹) * implied.∇A
178+
∇μ = F⨉I_A⁻¹ * implied.∇M + kron((I_A⁻¹ * implied.M)', F⨉I_A⁻¹) * implied.∇A
177179

178180
mul!(G, ∇Σ', JΣ) # actually transposed
179181
mul!(G, ∇μ', Jμ, -1, 1)

src/loss/ML/ML.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ function evaluate!(objective, gradient, hessian, ml::SemML, par)
184184
end
185185

186186
if !isnothing(gradient)
187-
S = implied.S
188-
F⨉I_A⁻¹ = implied.F⨉I_A⁻¹
189-
I_A⁻¹ = implied.I_A⁻¹
187+
S = parent(implied.S)
188+
F⨉I_A⁻¹ = parent(implied.F⨉I_A⁻¹)
189+
I_A⁻¹ = parent(implied.I_A⁻¹)
190190
∇A = implied.∇A
191191
∇S = implied.∇S
192192

@@ -198,15 +198,9 @@ function evaluate!(objective, gradient, hessian, ml::SemML, par)
198198
C = mul!(
199199
ml.varXvar_1,
200200
F⨉I_A⁻¹',
201-
mul!(ml.obsXvar_1, Symmetric(mul!(ml.obsXobs_3, one_Σ⁻¹Σₒ, Σ⁻¹)), F⨉I_A⁻¹),
202-
)
203-
mul!(
204-
gradient,
205-
∇A',
206-
vec(mul!(ml.varXvar_3, Symmetric(C), mul!(ml.varXvar_2, S, I_A⁻¹'))),
207-
2,
208-
0,
201+
mul!(ml.obsXvar_1, mul!(ml.obsXobs_3, one_Σ⁻¹Σₒ, Σ⁻¹), F⨉I_A⁻¹),
209202
)
203+
mul!(gradient, ∇A', vec(mul!(ml.varXvar_3, C, mul!(ml.varXvar_2, S, I_A⁻¹'))), 2, 0)
210204
mul!(gradient, ∇S', vec(C), 1, 1)
211205

212206
if MeanStruct(implied) === HasMeanStruct

0 commit comments

Comments
 (0)