Skip to content

Commit 87fc28f

Browse files
committed
Less allocations
1 parent f7743e0 commit 87fc28f

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
122122
nothing,
123123
isdde = false,
124124
has_difference = false,
125+
split_parameters = false,
125126
kwargs...)
126127
if isdde
127128
eqs = delay_to_function(sys)
@@ -151,10 +152,12 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
151152
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
152153
states = sol_states,
153154
kwargs...)
155+
elseif split_parameters
156+
build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states,
157+
kwargs...)
154158
else
155-
fun = build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states,
159+
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
156160
kwargs...)
157-
fun[1], :((out, u, p, t)->$(fun[2])(out, u, p..., t))
158161
end
159162
end
160163
end
@@ -326,15 +329,23 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
326329
checkbounds = false,
327330
sparsity = false,
328331
analytic = nothing,
332+
split_parameters = false,
329333
kwargs...) where {iip, specialize}
330334
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
331-
expression_module = eval_module, checkbounds = checkbounds,
335+
expression_module = eval_module, checkbounds = checkbounds, split_parameters,
332336
kwargs...)
333337
f_oop, f_iip = eval_expression ?
334338
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
335339
f_gen
336-
f(u, p, t) = f_oop(u, p, t)
337-
f(du, u, p, t) = f_iip(du, u, p, t)
340+
if split_parameters
341+
g(u, p, t) = f_oop(u, p..., t)
342+
g(du, u, p, t) = f_iip(du, u, p..., t)
343+
f = g
344+
else
345+
k(u, p, t) = f_oop(u, p, t)
346+
k(du, u, p, t) = f_iip(du, u, p, t)
347+
f = k
348+
end
338349

339350
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
340351
if u0 === nothing || p === nothing || t === nothing
@@ -687,6 +698,7 @@ function get_u0_p(sys,
687698
parammap;
688699
use_union = false,
689700
tofloat = !use_union,
701+
split_parameters = false,
690702
symbolic_u0 = false)
691703
eqs = equations(sys)
692704
dvs = states(sys)
@@ -699,7 +711,7 @@ function get_u0_p(sys,
699711
if symbolic_u0
700712
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
701713
else
702-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
714+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = !split_parameters)
703715
end
704716
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
705717
p = p === nothing ? SciMLBase.NullParameters() : p
@@ -717,15 +729,24 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
717729
use_union = false,
718730
tofloat = !use_union,
719731
symbolic_u0 = false,
732+
split_parameters = false,
720733
kwargs...)
721734
eqs = equations(sys)
722735
dvs = states(sys)
723736
ps = parameters(sys)
724737
iv = get_iv(sys)
725738

726-
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
727-
split_ps, split_idxs = split_parameters_by_type(p)
728-
split_sym_ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
739+
u0, p, defs = get_u0_p(sys,
740+
u0map,
741+
parammap;
742+
tofloat,
743+
use_union,
744+
symbolic_u0,
745+
split_parameters)
746+
if split_parameters
747+
p, split_idxs = split_parameters_by_type(p)
748+
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
749+
end
729750

730751
if implicit_dae && du0map !== nothing
731752
ddvs = map(Differential(iv), dvs)
@@ -739,12 +760,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
739760

740761
check_eqs_u0(eqs, dvs, u0; kwargs...)
741762

742-
f = constructor(sys, dvs, split_sym_ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
743-
checkbounds = checkbounds, p = split_ps,
763+
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
764+
checkbounds = checkbounds, p = p,
744765
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
745-
sparse = sparse, eval_expression = eval_expression,
766+
sparse = sparse, eval_expression = eval_expression, split_parameters,
746767
kwargs...)
747-
implicit_dae ? (f, du0, u0, split_ps) : (f, u0, split_ps)
768+
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
748769
end
749770

750771
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)

0 commit comments

Comments
 (0)