@@ -17,7 +17,7 @@ via expectation maximization (EM) for `observed`.
1717
1818Returns the tuple of the EM covariance matrix and the EM mean vector.
1919
20- Uses the EM algorithm for MVN-distributed data with missing values
20+ Based on the EM algorithm for MVN-distributed data with missing values
2121adapted from the supplementary material to the book *Machine Learning: A Probabilistic Perspective*,
2222copyright (2010) Kevin Murphy and Matt Dunham: see
2323[*gaussMissingFitEm.m*](https://github.com/probml/pmtk3/blob/master/toolbox/BasicModels/gauss/sub/gaussMissingFitEm.m) and
@@ -106,6 +106,8 @@ function em_step!(
106106 copy! (μ, 𝔼x_full)
107107 copy! (Σ, 𝔼xxᵀ_full)
108108 nobs_used = nobs_full
109+ mul! (Σ, μ₀, μ₀' , - nobs_used, 1 )
110+ axpy! (- nobs_used, μ₀, μ)
109111
110112 # Compute the expected sufficient statistics
111113 for pat in patterns
@@ -115,50 +117,68 @@ function em_step!(
115117 u = pat. miss_mask
116118 o = pat. obs_mask
117119
118- # precompute for pattern
119- Σoo_chol = cholesky (Symmetric (Σ₀[o, o]))
120- Σuo = Σ₀[u, o]
121- μu = μ₀[u]
122- μo = μ₀[o]
120+ # compute cholesky to speed-up ldiv!()
121+ Σ₀oo_chol = cholesky (Symmetric (Σ₀[o, o]))
122+ Σ₀uo = Σ₀[u, o]
123+ μ₀u = μ₀[u]
124+ μ₀o = μ₀[o]
123125
124126 # get pattern observations
125127 nobs = ! isnothing (max_nobs_em) ? min (max_nobs_em, n_obs (pat)) : n_obs (pat)
126- pat_data =
128+ zo =
127129 nobs < n_obs (pat) ?
128- view (pat. data, :, sort! (sample (1 : n_obs (pat), nobs, replace = false ))) : pat. data
130+ pat. data[:, sort! (sample (1 : n_obs (pat), nobs, replace = false ))] : copy (pat. data)
131+ zo .- = μ₀o # subtract current mean from observations
129132
130- 𝔼xu = fill! (similar (μu), 0 )
131- 𝔼xo = fill! (similar (μo), 0 )
132- 𝔼xᵢu = similar (μu)
133+ 𝔼zo = sum (zo, dims = 2 )
134+ 𝔼zu = fill! (similar (μ₀u), 0 )
133135
134- 𝔼xxᵀuo = fill! (similar (Σuo), 0 )
135- 𝔼xxᵀuu = n_obs (pat) * (Σ₀[u, u] - Σuo * (Σoo_chol \ Σuo' ))
136+ 𝔼zzᵀuo = fill! (similar (Σ₀uo), 0 )
137+ 𝔼zzᵀuu = nobs * Σ₀[u, u]
138+ mul! (𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo' , - nobs, 1 )
136139
137140 # loop through observations
138- @inbounds for obsdata in eachcol (pat_data)
139- mul! (𝔼xᵢu, Σuo, Σoo_chol \ (obsdata - μo))
140- 𝔼xᵢu .+ = μu
141- mul! (𝔼xxᵀuu, 𝔼xᵢu, 𝔼xᵢu' , 1 , 1 )
142- mul! (𝔼xxᵀuo, 𝔼xᵢu, obsdata' , 1 , 1 )
143- 𝔼xu .+ = 𝔼xᵢu
144- 𝔼xo .+ = obsdata
141+ yᵢo = similar (μ₀o)
142+ 𝔼zᵢu = similar (μ₀u)
143+ @inbounds for zᵢo in eachcol (zo)
144+ ldiv! (yᵢo, Σ₀oo_chol, zᵢo)
145+ mul! (𝔼zᵢu, Σ₀uo, yᵢo)
146+ mul! (𝔼zzᵀuu, 𝔼zᵢu, 𝔼zᵢu' , 1 , 1 )
147+ mul! (𝔼zzᵀuo, 𝔼zᵢu, zᵢo' , 1 , 1 )
148+ 𝔼zu .+ = 𝔼zᵢu
145149 end
150+ # correct 𝔼zzᵀ by adding back μ₀×𝔼z' + 𝔼z'×μ₀
151+ mul! (𝔼zzᵀuo, μ₀u, 𝔼zo' , 1 , 1 )
152+ mul! (𝔼zzᵀuo, 𝔼zu, μ₀o' , 1 , 1 )
146153
147- Σ[o, o] .+ = pat_data * pat_data'
148- Σ[u, o] .+ = 𝔼xxᵀuo
149- Σ[o, u] .+ = 𝔼xxᵀuo'
150- Σ[u, u] .+ = 𝔼xxᵀuu
154+ mul! (𝔼zzᵀuu, μ₀u, 𝔼zu' , 1 , 1 )
155+ mul! (𝔼zzᵀuu, 𝔼zu, μ₀u' , 1 , 1 )
151156
152- μ[o] .+ = 𝔼xo
153- μ[u] .+ = 𝔼xu
157+ 𝔼zzᵀoo = zo * zo'
158+ mul! (𝔼zzᵀoo, μ₀o, 𝔼zo' , 1 , 1 )
159+ mul! (𝔼zzᵀoo, 𝔼zo, μ₀o' , 1 , 1 )
160+
161+ # update Σ and μ
162+ Σ[o, o] .+ = 𝔼zzᵀoo
163+ Σ[u, o] .+ = 𝔼zzᵀuo
164+ Σ[o, u] .+ = 𝔼zzᵀuo'
165+ Σ[u, u] .+ = 𝔼zzᵀuu
166+
167+ μ[o] .+ = 𝔼zo
168+ μ[u] .+ = 𝔼zu
154169
155170 nobs_used += nobs
156171 end
157172
158173 # M step, update em_model
159174 lmul! (1 / nobs_used, Σ)
160175 lmul! (1 / nobs_used, μ)
176+ # at this point μ = μ - μ₀
177+ # and Σ = Σ + (μ - μ₀)×(μ - μ₀)' - μ₀×μ₀'
178+ mul! (Σ, μ, μ₀' , - 1 , 1 )
179+ mul! (Σ, μ₀, μ' , - 1 , 1 )
161180 mul! (Σ, μ, μ' , - 1 , 1 )
181+ μ .+ = μ₀
162182
163183 # ridge Σ
164184 # while !isposdef(Σ)
0 commit comments