Skip to content

Commit c4af0d7

Browse files
Alexey Stukalovalyst
authored andcommitted
EM: optimize mean handling
1 parent af9cb2f commit c4af0d7

File tree

1 file changed

+46
-26
lines changed

1 file changed

+46
-26
lines changed

src/observed/EM.jl

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ via expectation maximization (EM) for `observed`.
1717
1818
Returns the tuple of the EM covariance matrix and the EM mean vector.
1919
20-
Uses the EM algorithm for MVN-distributed data with missing values
20+
Based on the EM algorithm for MVN-distributed data with missing values
2121
adapted from the supplementary material to the book *Machine Learning: A Probabilistic Perspective*,
2222
copyright (2010) Kevin Murphy and Matt Dunham: see
2323
[*gaussMissingFitEm.m*](https://github.com/probml/pmtk3/blob/master/toolbox/BasicModels/gauss/sub/gaussMissingFitEm.m) and
@@ -106,6 +106,8 @@ function em_step!(
106106
copy!(μ, 𝔼x_full)
107107
copy!(Σ, 𝔼xxᵀ_full)
108108
nobs_used = nobs_full
109+
mul!(Σ, μ₀, μ₀', -nobs_used, 1)
110+
axpy!(-nobs_used, μ₀, μ)
109111

110112
# Compute the expected sufficient statistics
111113
for pat in patterns
@@ -115,50 +117,68 @@ function em_step!(
115117
u = pat.miss_mask
116118
o = pat.obs_mask
117119

118-
# precompute for pattern
119-
Σoo_chol = cholesky(Symmetric(Σ₀[o, o]))
120-
Σuo = Σ₀[u, o]
121-
μu = μ₀[u]
122-
μo = μ₀[o]
120+
# compute cholesky to speed-up ldiv!()
121+
Σ₀oo_chol = cholesky(Symmetric(Σ₀[o, o]))
122+
Σ₀uo = Σ₀[u, o]
123+
μ₀u = μ₀[u]
124+
μ₀o = μ₀[o]
123125

124126
# get pattern observations
125127
nobs = !isnothing(max_nobs_em) ? min(max_nobs_em, n_obs(pat)) : n_obs(pat)
126-
pat_data =
128+
zo =
127129
nobs < n_obs(pat) ?
128-
view(pat.data, :, sort!(sample(1:n_obs(pat), nobs, replace = false))) : pat.data
130+
pat.data[:, sort!(sample(1:n_obs(pat), nobs, replace = false))] : copy(pat.data)
131+
zo .-= μ₀o # subtract current mean from observations
129132

130-
𝔼xu = fill!(similar(μu), 0)
131-
𝔼xo = fill!(similar(μo), 0)
132-
𝔼xᵢu = similar(μu)
133+
𝔼zo = sum(zo, dims = 2)
134+
𝔼zu = fill!(similar(μ₀u), 0)
133135

134-
𝔼xxᵀuo = fill!(similar(Σuo), 0)
135-
𝔼xxᵀuu = n_obs(pat) * (Σ₀[u, u] - Σuo * (Σoo_chol \ Σuo'))
136+
𝔼zzᵀuo = fill!(similar(Σ₀uo), 0)
137+
𝔼zzᵀuu = nobs * Σ₀[u, u]
138+
mul!(𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo', -nobs, 1)
136139

137140
# loop through observations
138-
@inbounds for obsdata in eachcol(pat_data)
139-
mul!(𝔼xᵢu, Σuo, Σoo_chol \ (obsdata - μo))
140-
𝔼xᵢu .+= μu
141-
mul!(𝔼xxᵀuu, 𝔼xᵢu, 𝔼xᵢu', 1, 1)
142-
mul!(𝔼xxᵀuo, 𝔼xᵢu, obsdata', 1, 1)
143-
𝔼xu .+= 𝔼xᵢu
144-
𝔼xo .+= obsdata
141+
yᵢo = similar(μ₀o)
142+
𝔼zᵢu = similar(μ₀u)
143+
@inbounds for zᵢo in eachcol(zo)
144+
ldiv!(yᵢo, Σ₀oo_chol, zᵢo)
145+
mul!(𝔼zᵢu, Σ₀uo, yᵢo)
146+
mul!(𝔼zzᵀuu, 𝔼zᵢu, 𝔼zᵢu', 1, 1)
147+
mul!(𝔼zzᵀuo, 𝔼zᵢu, zᵢo', 1, 1)
148+
𝔼zu .+= 𝔼zᵢu
145149
end
150+
# correct 𝔼zzᵀ by adding back μ₀×𝔼z' + 𝔼z'×μ₀
151+
mul!(𝔼zzᵀuo, μ₀u, 𝔼zo', 1, 1)
152+
mul!(𝔼zzᵀuo, 𝔼zu, μ₀o', 1, 1)
146153

147-
Σ[o, o] .+= pat_data * pat_data'
148-
Σ[u, o] .+= 𝔼xxᵀuo
149-
Σ[o, u] .+= 𝔼xxᵀuo'
150-
Σ[u, u] .+= 𝔼xxᵀuu
154+
mul!(𝔼zzᵀuu, μ₀u, 𝔼zu', 1, 1)
155+
mul!(𝔼zzᵀuu, 𝔼zu, μ₀u', 1, 1)
151156

152-
μ[o] .+= 𝔼xo
153-
μ[u] .+= 𝔼xu
157+
𝔼zzᵀoo = zo * zo'
158+
mul!(𝔼zzᵀoo, μ₀o, 𝔼zo', 1, 1)
159+
mul!(𝔼zzᵀoo, 𝔼zo, μ₀o', 1, 1)
160+
161+
# update Σ and μ
162+
Σ[o, o] .+= 𝔼zzᵀoo
163+
Σ[u, o] .+= 𝔼zzᵀuo
164+
Σ[o, u] .+= 𝔼zzᵀuo'
165+
Σ[u, u] .+= 𝔼zzᵀuu
166+
167+
μ[o] .+= 𝔼zo
168+
μ[u] .+= 𝔼zu
154169

155170
nobs_used += nobs
156171
end
157172

158173
# M step, update em_model
159174
lmul!(1 / nobs_used, Σ)
160175
lmul!(1 / nobs_used, μ)
176+
# at this point μ = μ - μ₀
177+
# and Σ = Σ + (μ - μ₀)×(μ - μ₀)' - μ₀×μ₀'
178+
mul!(Σ, μ, μ₀', -1, 1)
179+
mul!(Σ, μ₀, μ', -1, 1)
161180
mul!(Σ, μ, μ', -1, 1)
181+
μ .+= μ₀
162182

163183
# ridge Σ
164184
# while !isposdef(Σ)

0 commit comments

Comments
 (0)