Skip to content

Commit af9cb2f

Browse files
Alexey Stukalov (alyst)
authored and committed
EM: max_nobs_em opt to limit obs used
1 parent 0cd1f52 commit af9cb2f

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

src/observed/EM.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ function em_mvn(
2828
start_em = start_em_observed,
2929
max_iter_em::Integer = 100,
3030
rtol_em::Number = 1e-4,
31+
max_nobs_em::Union{Integer, Nothing} = nothing,
3132
kwargs...,
3233
)
3334
n_man = SEM.n_man(patterns[1])
@@ -58,7 +59,7 @@ function em_mvn(
5859
Δμ_rel = NaN
5960
ΔΣ_rel = NaN
6061
while !converged && (iter < max_iter_em)
61-
em_step!(Σ, μ, Σ_prev, μ_prev, patterns, 𝔼x_full, 𝔼xxᵀ_full)
62+
em_step!(Σ, μ, Σ_prev, μ_prev, patterns, 𝔼xxᵀ_full, 𝔼x_full, nobs_full; max_nobs_em)
6263

6364
if iter > 0
6465
Δμ = norm(μ - μ_prev)
@@ -96,16 +97,19 @@ function em_step!(
9697
Σ₀::AbstractMatrix,
9798
μ₀::AbstractVector,
9899
patterns::AbstractVector{<:SemObservedMissingPattern},
99-
𝔼x_full,
100-
𝔼xxᵀ_full,
100+
𝔼xxᵀ_full::AbstractMatrix,
101+
𝔼x_full::AbstractVector,
102+
nobs_full::Integer;
103+
max_nobs_em::Union{Integer, Nothing} = nothing,
101104
)
102105
# E step, update 𝔼x and 𝔼xxᵀ
103106
copy!(μ, 𝔼x_full)
104107
copy!(Σ, 𝔼xxᵀ_full)
108+
nobs_used = nobs_full
105109

106110
# Compute the expected sufficient statistics
107111
for pat in patterns
108-
(nmissed_vars(pat) == 0) && continue # skip full cases
112+
(nmissed_vars(pat) == 0) && continue # full cases already accounted for
109113

110114
# observed and unobserved vars
111115
u = pat.miss_mask
@@ -117,6 +121,12 @@ function em_step!(
117121
μu = μ₀[u]
118122
μo = μ₀[o]
119123

124+
# get pattern observations
125+
nobs = !isnothing(max_nobs_em) ? min(max_nobs_em, n_obs(pat)) : n_obs(pat)
126+
pat_data =
127+
nobs < n_obs(pat) ?
128+
view(pat.data, :, sort!(sample(1:n_obs(pat), nobs, replace = false))) : pat.data
129+
120130
𝔼xu = fill!(similar(μu), 0)
121131
𝔼xo = fill!(similar(μo), 0)
122132
𝔼xᵢu = similar(μu)
@@ -125,28 +135,29 @@ function em_step!(
125135
𝔼xxᵀuu = n_obs(pat) * (Σ₀[u, u] - Σuo * (Σoo_chol \ Σuo'))
126136

127137
# loop through observations
128-
@inbounds for rowdata in eachcol(pat.data)
129-
mul!(𝔼xᵢu, Σuo, Σoo_chol \ (rowdata - μo))
138+
@inbounds for obsdata in eachcol(pat_data)
139+
mul!(𝔼xᵢu, Σuo, Σoo_chol \ (obsdata - μo))
130140
𝔼xᵢu .+= μu
131141
mul!(𝔼xxᵀuu, 𝔼xᵢu, 𝔼xᵢu', 1, 1)
132-
mul!(𝔼xxᵀuo, 𝔼xᵢu, rowdata', 1, 1)
142+
mul!(𝔼xxᵀuo, 𝔼xᵢu, obsdata', 1, 1)
133143
𝔼xu .+= 𝔼xᵢu
134-
𝔼xo .+= rowdata
144+
𝔼xo .+= obsdata
135145
end
136146

137-
Σ[o, o] .+= pat.data' * pat.data
147+
Σ[o, o] .+= pat_data * pat_data'
138148
Σ[u, o] .+= 𝔼xxᵀuo
139149
Σ[o, u] .+= 𝔼xxᵀuo'
140150
Σ[u, u] .+= 𝔼xxᵀuu
141151

142152
μ[o] .+= 𝔼xo
143153
μ[u] .+= 𝔼xu
154+
155+
nobs_used += nobs
144156
end
145157

146158
# M step, update em_model
147-
k = inv(sum(n_obs, patterns))
148-
lmul!(k, Σ)
149-
lmul!(k, μ)
159+
lmul!(1 / nobs_used, Σ)
160+
lmul!(1 / nobs_used, μ)
150161
mul!(Σ, μ, μ', -1, 1)
151162

152163
# ridge Σ

0 commit comments

Comments
 (0)