@@ -28,6 +28,7 @@ function em_mvn(
2828 start_em = start_em_observed,
2929 max_iter_em:: Integer = 100 ,
3030 rtol_em:: Number = 1e-4 ,
31+ max_nobs_em:: Union{Integer, Nothing} = nothing ,
3132 kwargs... ,
3233)
3334 n_man = SEM. n_man (patterns[1 ])
@@ -58,7 +59,7 @@ function em_mvn(
5859 Δμ_rel = NaN
5960 ΔΣ_rel = NaN
6061 while ! converged && (iter < max_iter_em)
61- em_step! (Σ, μ, Σ_prev, μ_prev, patterns, 𝔼x_full, 𝔼xxᵀ_full )
62+ em_step! (Σ, μ, Σ_prev, μ_prev, patterns, 𝔼xxᵀ_full, 𝔼x_full, nobs_full; max_nobs_em )
6263
6364 if iter > 0
6465 Δμ = norm (μ - μ_prev)
@@ -96,16 +97,19 @@ function em_step!(
9697 Σ₀:: AbstractMatrix ,
9798 μ₀:: AbstractVector ,
9899 patterns:: AbstractVector{<:SemObservedMissingPattern} ,
99- 𝔼x_full,
100- 𝔼xxᵀ_full,
100+ 𝔼xxᵀ_full:: AbstractMatrix ,
101+ 𝔼x_full:: AbstractVector ,
102+ nobs_full:: Integer ;
103+ max_nobs_em:: Union{Integer, Nothing} = nothing ,
101104)
102105 # E step, update 𝔼x and 𝔼xxᵀ
103106 copy! (μ, 𝔼x_full)
104107 copy! (Σ, 𝔼xxᵀ_full)
108+ nobs_used = nobs_full
105109
106110 # Compute the expected sufficient statistics
107111 for pat in patterns
108- (nmissed_vars (pat) == 0 ) && continue # skip full cases
112+ (nmissed_vars (pat) == 0 ) && continue # full cases already accounted for
109113
110114 # observed and unobserved vars
111115 u = pat. miss_mask
@@ -117,6 +121,12 @@ function em_step!(
117121 μu = μ₀[u]
118122 μo = μ₀[o]
119123
124+ # get pattern observations
125+ nobs = ! isnothing (max_nobs_em) ? min (max_nobs_em, n_obs (pat)) : n_obs (pat)
126+ pat_data =
127+ nobs < n_obs (pat) ?
128+ view (pat. data, :, sort! (sample (1 : n_obs (pat), nobs, replace = false ))) : pat. data
129+
120130 𝔼xu = fill! (similar (μu), 0 )
121131 𝔼xo = fill! (similar (μo), 0 )
122132 𝔼xᵢu = similar (μu)
@@ -125,28 +135,29 @@ function em_step!(
125135 𝔼xxᵀuu = n_obs (pat) * (Σ₀[u, u] - Σuo * (Σoo_chol \ Σuo' ))
126136
127137 # loop through observations
128- @inbounds for rowdata in eachcol (pat . data )
129- mul! (𝔼xᵢu, Σuo, Σoo_chol \ (rowdata - μo))
138+ @inbounds for obsdata in eachcol (pat_data )
139+ mul! (𝔼xᵢu, Σuo, Σoo_chol \ (obsdata - μo))
130140 𝔼xᵢu .+ = μu
131141 mul! (𝔼xxᵀuu, 𝔼xᵢu, 𝔼xᵢu' , 1 , 1 )
132- mul! (𝔼xxᵀuo, 𝔼xᵢu, rowdata ' , 1 , 1 )
142+ mul! (𝔼xxᵀuo, 𝔼xᵢu, obsdata ' , 1 , 1 )
133143 𝔼xu .+ = 𝔼xᵢu
134- 𝔼xo .+ = rowdata
144+ 𝔼xo .+ = obsdata
135145 end
136146
137- Σ[o, o] .+ = pat . data ' * pat . data
147+ Σ[o, o] .+ = pat_data * pat_data '
138148 Σ[u, o] .+ = 𝔼xxᵀuo
139149 Σ[o, u] .+ = 𝔼xxᵀuo'
140150 Σ[u, u] .+ = 𝔼xxᵀuu
141151
142152 μ[o] .+ = 𝔼xo
143153 μ[u] .+ = 𝔼xu
154+
155+ nobs_used += nobs
144156 end
145157
146158 # M step, update em_model
147- k = inv (sum (n_obs, patterns))
148- lmul! (k, Σ)
149- lmul! (k, μ)
159+ lmul! (1 / nobs_used, Σ)
160+ lmul! (1 / nobs_used, μ)
150161 mul! (Σ, μ, μ' , - 1 , 1 )
151162
152163 # ridge Σ
0 commit comments