Skip to content

Commit 3160987

Browse files
code cleanup and docstring fixes
1 parent 8cbd504 commit 3160987

File tree

35 files changed

+134
-302
lines changed

35 files changed

+134
-302
lines changed

docs/src/tutorials/regularization/regularization.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,12 @@ It can be used as
3030
```julia
3131
SemOptimizerProximal(
3232
algorithm = ProximalAlgorithms.PANOC(),
33-
options = Dict{Symbol, Any}(),
3433
operator_g,
3534
operator_h = nothing
3635
)
3736
```
3837

39-
The proximal operator (aka the regularization function) can be passed as `operator_g`, available options are listed [here](https://juliafirstorder.github.io/ProximalOperators.jl/stable/functions/).
38+
The proximal operator (aka the regularization function) can be passed as `operator_g`.
4039
The available Algorithms are listed [here](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/).
4140

4241
## First example - lasso

src/additional_functions/helper.jl

Lines changed: 6 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -14,84 +14,29 @@ function neumann_series(mat::SparseMatrixCSC; maxiter::Integer = size(mat, 1))
1414
return inverse
1515
end
1616

17-
#=
18-
function make_onelement_array(A)
19-
isa(A, Array) ? nothing : (A = [A])
20-
return A
21-
end
22-
=#
23-
24-
function semvec(observed, implied, loss, optimizer)
25-
observed = make_onelement_array(observed)
26-
implied = make_onelement_array(implied)
27-
loss = make_onelement_array(loss)
28-
optimizer = make_onelement_array(optimizer)
29-
30-
#sem_vec = Array{AbstractSem}(undef, maximum(length.([observed, implied, loss, optimizer])))
31-
sem_vec = Sem.(observed, implied, loss, optimizer)
32-
33-
return sem_vec
34-
end
35-
36-
skipmissing_mean(mat::AbstractMatrix) =
37-
[mean(skipmissing(coldata)) for coldata in eachcol(mat)]
38-
39-
function F_one_person(imp_mean, meandiff, inverse, data, logdet)
40-
F = logdet
41-
@. meandiff = data - imp_mean
42-
F += dot(meandiff, inverse, meandiff)
43-
return F
44-
end
45-
46-
function remove_all_missing(data::AbstractMatrix)
47-
keep = Vector{Int64}()
48-
for (i, coldata) in zip(axes(data, 1), eachrow(data))
49-
if any(!ismissing, coldata)
50-
push!(keep, i)
51-
end
52-
end
53-
return data[keep, :], keep
54-
end
55-
5617
function batch_inv!(fun, model)
5718
for i in 1:size(fun.inverses, 1)
5819
fun.inverses[i] .= LinearAlgebra.inv!(fun.choleskys[i])
5920
end
6021
end
6122

62-
#=
63-
function batch_sym_inv_update!(fun::Union{LossFunction, DiffFunction}, model)
64-
M_inv = inv(fun.choleskys[1])
65-
for i = 1:size(fun.inverses, 1)
66-
if size(model.observed.patterns_not[i]) == 0
67-
fun.inverses[i] .= M_inv
68-
else
69-
ind_not = model.observed.patterns_not[i]
70-
ind = model.observed.patterns[i]
71-
72-
A = M_inv[ind_not, ind]
73-
H = cholesky(M_inv[ind_not, ind_not])
74-
D = H \ A
75-
out = M_inv[ind, ind] - LinearAlgebra.BLAS.gemm('T', 'N', 1.0, A, D)
76-
fun.inverses[i] .= out
77-
end
78-
end
79-
end =#
80-
81-
function sparse_outer_mul!(C, A, B, ind) #computes A*S*B -> C, where ind gives the entries of S that are 1
23+
# computes A*S*B -> C, where ind gives the entries of S that are 1
24+
function sparse_outer_mul!(C, A, B, ind)
8225
fill!(C, 0.0)
8326
for i in 1:length(ind)
8427
BLAS.ger!(1.0, A[:, ind[i][1]], B[ind[i][2], :], C)
8528
end
8629
end
8730

88-
function sparse_outer_mul!(C, A, ind) #computes A*∇m, where ∇m ind gives the entries of ∇m that are 1
31+
# computes A*∇m, where ∇m ind gives the entries of ∇m that are 1
32+
function sparse_outer_mul!(C, A, ind)
8933
fill!(C, 0.0)
9034
@views C .= sum(A[:, ind], dims = 2)
9135
return C
9236
end
9337

94-
function sparse_outer_mul!(C, A, B::Vector, ind) #computes A*S*B -> C, where ind gives the entries of S that are 1
38+
# computes A*S*B -> C, where ind gives the entries of S that are 1
39+
function sparse_outer_mul!(C, A, B::Vector, ind)
9540
fill!(C, 0.0)
9641
@views @inbounds for i in 1:length(ind)
9742
C .+= B[ind[i][2]] .* A[:, ind[i][1]]

src/additional_functions/simulation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Return a new model with swaped observed part.
77
88
# Arguments
99
- `model::AbstractSemSingle`: model to swap the observed part of.
10-
- `kwargs`: additional keyword arguments; typically includes `data = ...`
10+
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
1111
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
1212
1313
# Examples

src/additional_functions/start_val/start_fabin3.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818
# SemObservedMissing
1919
function start_fabin3(observed::SemObservedMissing, implied, args...; kwargs...)
2020
if !observed.em_model.fitted
21-
em_mvn(observed; kwargs...)
21+
em_mvn!(observed; kwargs...)
2222
end
2323

2424
return start_fabin3(implied.ram_matrices, observed.em_model.Σ, observed.em_model.μ)
@@ -45,24 +45,6 @@ function start_fabin3(
4545
)
4646
@assert length(F_var2obs) == size(F, 1)
4747

48-
# check in which matrix each parameter appears
49-
50-
#= in_S = length.(S_ind) .!= 0
51-
in_A = length.(A_ind) .!= 0
52-
A_ind_c = [linear2cartesian(ind, (n_var, n_var)) for ind in A_ind]
53-
in_Λ = [any(ind[2] .∈ F_ind) for ind in A_ind_c]
54-
55-
if !isnothing(M)
56-
in_M = length.(M_ind) .!= 0
57-
in_any = in_A .| in_S .| in_M
58-
else
59-
in_any = in_A .| in_S
60-
end
61-
62-
if !all(in_any)
63-
@warn "Could not determine fabin3 starting values for some parameters, default to 0."
64-
end =#
65-
6648
# set undirected parameters in S
6749
S_indices = CartesianIndices(S)
6850
for j in 1:nparams(S)
@@ -79,7 +61,6 @@ function start_fabin3(
7961

8062
# set loadings
8163
A_indices = CartesianIndices(A)
82-
# ind_Λ = findall([is_in_Λ(ind_vec, F_ind) for ind_vec in A_ind_c])
8364

8465
# collect latent variable indicators in A
8566
# maps latent parameter to the vector of dependent vars

src/additional_functions/start_val/start_simple.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ function start_simple(
6363
nparams(ram_matrices)
6464

6565
start_val = zeros(n_par)
66-
n_obs = nobserved_vars(ram_matrices)
6766
n_var = nvars(ram_matrices)
6867

6968
C_indices = CartesianIndices((n_var, n_var))

src/frontend/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# API methods supported by multiple SEM.jl types
22

33
"""
4-
params(semobj) -> Vector{Symbol}
4+
params(partable::ParameterTable) -> Vector{Symbol}
55
66
Return the vector of SEM model parameter identifiers.
77
"""
@@ -42,7 +42,7 @@ nlatent_vars(semobj) = length(latent_vars(semobj))
4242
"""
4343
param_indices(semobj)
4444
45-
Returns a dict of parameter names and their indices in `semobj`.
45+
Returns a dict of parameter labels and their indices in `semobj`.
4646
4747
# Examples
4848
```julia

src/frontend/fit/fitmeasures/fit_measures.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ end
1414
"""
1515
fit_measures(sem_fit, args...)
1616
17-
Return a default set of fit measures or the fit measures passed as `arg...`.
17+
Return a default set of fit measures or the fit measures passed as `args...`.
1818
"""
1919
function fit_measures end

src/frontend/fit/fitmeasures/minus2ll.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141
function minus2ll(observed::SemObservedMissing)
4242
# fit EM-based mean and cov if not yet fitted
4343
# FIXME EM could be very computationally expensive
44-
observed.em_model.fitted || em_mvn(observed)
44+
observed.em_model.fitted || em_mvn!(observed)
4545

4646
Σ = observed.em_model.Σ
4747
μ = observed.em_model.μ

src/frontend/fit/summary.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ function details(sem_fit::SemFit; show_fitmeasures = false, color = :light_cyan,
1515
println("Number of data samples: $(nsamples(sem_fit))")
1616
print("\n")
1717
printstyled(
18-
"----------------------------------- Model ----------------------------------- \n";
18+
"----------------------------------- Model ------------------------------------ \n";
1919
color = color,
2020
)
2121
print("\n")
2222
print(sem_fit.model)
2323
print("\n")
2424
if show_fitmeasures
2525
printstyled(
26-
"--------------------------------- Fitmeasures --------------------------------- \n";
26+
"-------------------------------- Fitmeasures --------------------------------- \n";
2727
color = color,
2828
)
2929
print("\n")
@@ -51,7 +51,7 @@ function details(
5151
if show_variables
5252
print("\n")
5353
printstyled(
54-
"--------------------------------- Variables --------------------------------- \n";
54+
"---------------------------------- Variables --------------------------------- \n";
5555
color = color,
5656
)
5757
print("\n")
@@ -242,9 +242,6 @@ function details(
242242
)
243243
print("\n")
244244
end
245-
246-
#printstyled("""No need to copy and paste results, you can use CSV.write(DataFrame(my_partable), "myfile.csv")"""; hidden = true)
247-
248245
end
249246

250247
function details(
@@ -297,9 +294,6 @@ function details(
297294
show_columns = show_columns,
298295
)
299296
end
300-
301-
# printstyled("""No need to copy and paste results, you can use CSV.write(DataFrame(my_partable), "myfile.csv")"""; hidden = true)
302-
303297
end
304298

305299
function check_round(vec; digits)

src/frontend/pretty_printing.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
##############################################################
2+
# Some helpers to implement show methods for SEM.jl objects
3+
##############################################################
4+
15
function print_field_types(io::IO, struct_instance)
26
fields = fieldnames(typeof(struct_instance))
37
types = [typeof(getproperty(struct_instance, field)) for field in fields]
@@ -25,7 +29,7 @@ function print_type(io::IO, struct_instance)
2529
end
2630

2731
##############################################################
28-
# Loss Functions, Implied,
32+
# Loss Function, Implied, Observed, Optimizer
2933
##############################################################
3034

3135
function Base.show(io::IO, struct_inst::SemLossFunction)

0 commit comments

Comments
 (0)