Skip to content

Commit 8b0322f

Browse files
Merge pull request #225 from StructuralEquationModels/simulation-wrapper
Add data simulation function
2 parents 54354cd + b7c111d commit 8b0322f

File tree

3 files changed

+123
-2
lines changed

3 files changed

+123
-2
lines changed

src/additional_functions/simulation.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
Return a new model with swapped observed part.
77
88
# Arguments
9-
- `model::AbstractSemSingle`: optimization algorithm.
9+
- `model::AbstractSemSingle`: model to swap the observed part of.
1010
- `kwargs`: additional keyword arguments; typically includes `data = ...`
1111
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
1212
@@ -98,3 +98,44 @@ function update_observed(loss::SemLoss, new_observed; kwargs...)
9898
)
9999
return SemLoss(new_functions, loss.weights)
100100
end
101+
102+
############################################################################################
103+
# simulate data
104+
############################################################################################
"""
    (1) rand(model::AbstractSemSingle, params, n)

    (2) rand(model::AbstractSemSingle, n)

Draw `n` multivariate normal observations from the covariance matrix (and, when the model
has a mean structure, the mean vector) implied by `model`.

Method (1) first calls `update!` on `model.imply` with `params` and then delegates to
method (2). Method (2) samples from whatever values are currently stored in `model.imply`.

# Arguments
- `model::AbstractSemSingle`: model to simulate from.
- `params`: parameter values to simulate from.
- `n::Integer`: Number of samples.

# Returns
A matrix with one row per simulated observation and one column per variable.

# Examples
```julia
rand(model, start_simple(model), 100)
```
"""
function Distributions.rand(
    model::AbstractSemSingle{O, I, L, D},
    params,
    n::Integer,
) where {O, I <: Union{RAM, RAMSymbolic}, L, D}
    # NOTE(review): the three flags presumably select objective / gradient / hessian,
    # i.e. only the first target is requested here -- confirm against `EvaluationTargets`.
    targets = EvaluationTargets{true, false, false}()
    # Refresh the implied moments held by `model.imply` for the supplied parameters.
    update!(targets, model.imply, model, params)
    # Sampling itself is shared with the parameter-free method.
    return rand(model, n)
end
130+
# Sample `n` observations from the moments currently stored in `model.imply`.
#
# `model.imply.Σ` is wrapped in `Symmetric` before it is handed to `MvNormal`; when the
# model has a mean structure, `model.imply.μ` is used as the mean vector, otherwise the
# mean-free `MvNormal` constructor is used. The draw is passed through `permutedims`, so
# the returned matrix has one row per sample.
#
# Throws an `ArgumentError` if the mean-structure trait of `model.imply` is neither
# `NoMeanStruct` nor `HasMeanStruct` (previously this fell through to an `UndefVarError`
# on the unassigned result variable).
function Distributions.rand(
    model::AbstractSemSingle{O, I, L, D},
    n::Integer,
) where {O, I <: Union{RAM, RAMSymbolic}, L, D}
    imply = model.imply
    meanstruct = MeanStruct(imply)
    dist = if meanstruct === NoMeanStruct
        MvNormal(Symmetric(imply.Σ))
    elseif meanstruct === HasMeanStruct
        MvNormal(imply.μ, Symmetric(imply.Σ))
    else
        throw(
            ArgumentError(
                "cannot simulate data: unsupported mean structure trait " *
                "$(meanstruct) for imply type $(typeof(imply))",
            ),
        )
    end
    # `rand(dist, n)` yields one column per draw; transpose to one row per draw.
    return permutedims(rand(dist, n))
end

test/examples/political_democracy/constructor.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Statistics: cov, mean
2+
using Random
23

34
############################################################################################
45
### models w.o. meanstructure
@@ -161,6 +162,44 @@ end
161162
)
162163
end
163164

165+
############################################################################################
166+
### data simulation
167+
############################################################################################
168+
@testset "data_simulation_wo_mean" begin
    # Simulate a large sample from known parameter values, refit both model variants
    # on the simulated data, and check that the estimates land close to the truth.
    n_samples = 1_000_000
    observed_names = map(Symbol, names(example_data("political_democracy")))

    # population values the fits have to recover
    true_params = start_simple(
        model_ml;
        start_loadings = 0.5,
        start_regressions = 0.5,
        start_variances_observed = 0.5,
        start_variances_latent = 1.0,
        start_covariances_observed = 0.2,
    )

    # fixed seed so the simulated data (and hence the tolerance check) is reproducible
    Random.seed!(83472834)

    # simulate and swap in the new data, one model at a time (keeps the RNG draw order)
    data_ml = rand(model_ml, true_params, n_samples)
    model_ml_new = swap_observed(
        model_ml;
        data = data_ml,
        specification = spec,
        obs_colnames = observed_names,
    )
    data_ml_sym = rand(model_ml_sym, true_params, n_samples)
    model_ml_sym_new = swap_observed(
        model_ml_sym;
        data = data_ml_sym,
        specification = spec,
        obs_colnames = observed_names,
    )

    # refit on the simulated data
    fit_ml = sem_fit(model_ml_new)
    fit_ml_sym = sem_fit(model_ml_sym_new)
    estimates_ml = solution(fit_ml)
    estimates_ml_sym = solution(fit_ml_sym)

    # largest absolute deviation from the true values must stay below the tolerance
    @test maximum(abs, estimates_ml - true_params) < 0.01
    @test maximum(abs, estimates_ml_sym - true_params) < 0.01
end
202+
164203
############################################################################################
165204
### test hessians
166205
############################################################################################
@@ -332,6 +371,47 @@ end
332371
)
333372
end
334373

374+
############################################################################################
375+
### data simulation
376+
############################################################################################
377+
@testset "data_simulation_with_mean" begin
    # Same recovery check as for the models without mean structure, but the simulated
    # data also carries the model-implied means and the models are refit with
    # `meanstructure = true`.
    n_samples = 1_000_000
    observed_names = map(Symbol, names(example_data("political_democracy")))

    # population values the fits have to recover (including the mean parameters)
    true_params = start_simple(
        model_ml;
        start_loadings = 0.5,
        start_regressions = 0.5,
        start_variances_observed = 0.5,
        start_variances_latent = 1.0,
        start_covariances_observed = 0.2,
        start_means = 0.5,
    )

    # fixed seed so the simulated data (and hence the tolerance check) is reproducible
    Random.seed!(83472834)

    # simulate and swap in the new data, one model at a time (keeps the RNG draw order)
    data_ml = rand(model_ml, true_params, n_samples)
    model_ml_new = swap_observed(
        model_ml;
        data = data_ml,
        specification = spec,
        obs_colnames = observed_names,
        meanstructure = true,
    )
    data_ml_sym = rand(model_ml_sym, true_params, n_samples)
    model_ml_sym_new = swap_observed(
        model_ml_sym;
        data = data_ml_sym,
        specification = spec,
        obs_colnames = observed_names,
        meanstructure = true,
    )

    # refit on the simulated data
    fit_ml = sem_fit(model_ml_new)
    fit_ml_sym = sem_fit(model_ml_sym_new)
    estimates_ml = solution(fit_ml)
    estimates_ml_sym = solution(fit_ml_sym)

    # largest absolute deviation from the true values must stay below the tolerance
    @test maximum(abs, estimates_ml - true_params) < 0.01
    @test maximum(abs, estimates_ml_sym - true_params) < 0.01
end
414+
335415
############################################################################################
336416
### fiml
337417
############################################################################################

test/examples/recover_parameters/recover_parameters_twofact.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ imply_ml.Σ_function(imply_ml.Σ, true_val)
6060
true_dist = MultivariateNormal(imply_ml.Σ)
6161

6262
Random.seed!(1234)
63-
x = transpose(rand(true_dist, 100000))
63+
x = transpose(rand(true_dist, 100_000))
6464
semobserved = SemObservedData(data = x, specification = nothing)
6565

6666
loss_ml = SemLoss(SemML(; observed = semobserved, nparams = length(start)))

0 commit comments

Comments
 (0)