Skip to content

Commit d9884ae

Browse files
Alexey Stukalovalyst
authored andcommitted
matrix_gradient(): refactor
* rename from get_matrix_derivative(): it's not a getter, and gradient is a better term * construct sparse matrix directly, which is much more efficient * parameters arg is not needed
1 parent 115ccfb commit d9884ae

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

src/additional_functions/parameters.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,19 @@ function check_constants(M)
8686
return false
8787
end
8888

89-
function get_matrix_derivative(M_indices, parameters, n_long)
90-
∇M = [
91-
sparsevec(M_indices[i], ones(length(M_indices[i])), n_long) for
92-
i in 1:length(parameters)
93-
]
94-
95-
∇M = reduce(hcat, ∇M)
96-
97-
return ∇M
89+
# construct length(M)×length(parameters) sparse matrix of 1s at the positions,
90+
# where the corresponding parameter occurs in the M matrix
91+
function matrix_gradient(M_indices::ArrayParamsMap, M_length::Integer)
92+
rowval = reduce(vcat, M_indices)
93+
colptr =
94+
pushfirst!(accumulate((ptr, M_ind) -> ptr + length(M_ind), M_indices, init = 1), 1)
95+
return SparseMatrixCSC(
96+
M_length,
97+
length(M_indices),
98+
colptr,
99+
rowval,
100+
ones(length(rowval)),
101+
)
98102
end
99103

100104
# fill M with parameters

src/imply/RAM/generic.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ function RAM(;
158158
I_A = similar(A_pre)
159159

160160
if gradient
161-
∇A = get_matrix_derivative(A_indices, parameters, n_nod^2)
162-
∇S = get_matrix_derivative(S_indices, parameters, n_nod^2)
161+
∇A = matrix_gradient(A_indices, n_nod^2)
162+
∇S = matrix_gradient(S_indices, n_nod^2)
163163
else
164164
∇A = nothing
165165
∇S = nothing
@@ -168,15 +168,8 @@ function RAM(;
168168
# μ
169169
if meanstructure
170170
has_meanstructure = Val(true)
171-
172-
if gradient
173-
∇M = get_matrix_derivative(M_indices, parameters, n_nod)
174-
else
175-
∇M = nothing
176-
end
177-
171+
∇M = gradient ? matrix_gradient(M_indices, n_nod) : nothing
178172
μ = zeros(n_var)
179-
180173
else
181174
has_meanstructure = Val(false)
182175
M_indices = nothing

0 commit comments

Comments
 (0)