@@ -122,6 +122,7 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
122122 nothing ,
123123 isdde = false ,
124124 has_difference = false ,
125+ split_parameters = false ,
125126 kwargs... )
126127 if isdde
127128 eqs = delay_to_function (sys)
@@ -151,10 +152,12 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
151152 build_function (rhss, ddvs, u, p, t; postprocess_fbody = pre,
152153 states = sol_states,
153154 kwargs... )
155+ elseif split_parameters
156+ build_function (rhss, u, p... , t; postprocess_fbody = pre, states = sol_states,
157+ kwargs... )
154158 else
155- fun = build_function (rhss, u, p... , t; postprocess_fbody = pre, states = sol_states,
159+ build_function (rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
156160 kwargs... )
157- fun[1 ], :((out, u, p, t)-> $ (fun[2 ])(out, u, p... , t))
158161 end
159162 end
160163end
@@ -326,15 +329,23 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
326329 checkbounds = false ,
327330 sparsity = false ,
328331 analytic = nothing ,
332+ split_parameters = false ,
329333 kwargs... ) where {iip, specialize}
330334 f_gen = generate_function (sys, dvs, ps; expression = Val{eval_expression},
331- expression_module = eval_module, checkbounds = checkbounds,
335+ expression_module = eval_module, checkbounds = checkbounds, split_parameters,
332336 kwargs... )
333337 f_oop, f_iip = eval_expression ?
334338 (drop_expr (@RuntimeGeneratedFunction (eval_module, ex)) for ex in f_gen) :
335339 f_gen
336- f (u, p, t) = f_oop (u, p, t)
337- f (du, u, p, t) = f_iip (du, u, p, t)
340+ if split_parameters
341+ g (u, p, t) = f_oop (u, p... , t)
342+ g (du, u, p, t) = f_iip (du, u, p... , t)
343+ f = g
344+ else
345+ k (u, p, t) = f_oop (u, p, t)
346+ k (du, u, p, t) = f_iip (du, u, p, t)
347+ f = k
348+ end
338349
339350 if specialize === SciMLBase. FunctionWrapperSpecialize && iip
340351 if u0 === nothing || p === nothing || t === nothing
@@ -687,6 +698,7 @@ function get_u0_p(sys,
687698 parammap;
688699 use_union = false ,
689700 tofloat = ! use_union,
701+ split_parameters = false ,
690702 symbolic_u0 = false )
691703 eqs = equations (sys)
692704 dvs = states (sys)
@@ -699,7 +711,7 @@ function get_u0_p(sys,
699711 if symbolic_u0
700712 u0 = varmap_to_vars (u0map, dvs; defaults = defs, tofloat = false , use_union = false )
701713 else
702- u0 = varmap_to_vars (u0map, dvs; defaults = defs, tofloat = true )
714+ u0 = varmap_to_vars (u0map, dvs; defaults = defs, tofloat = ! split_parameters )
703715 end
704716 p = varmap_to_vars (parammap, ps; defaults = defs, tofloat, use_union)
705717 p = p === nothing ? SciMLBase. NullParameters () : p
@@ -717,15 +729,24 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
717729 use_union = false ,
718730 tofloat = ! use_union,
719731 symbolic_u0 = false ,
732+ split_parameters = false ,
720733 kwargs... )
721734 eqs = equations (sys)
722735 dvs = states (sys)
723736 ps = parameters (sys)
724737 iv = get_iv (sys)
725738
726- u0, p, defs = get_u0_p (sys, u0map, parammap; tofloat, use_union, symbolic_u0)
727- split_ps, split_idxs = split_parameters_by_type (p)
728- split_sym_ps = Base. Fix1 (getindex, parameters (sys)).(split_idxs)
739+ u0, p, defs = get_u0_p (sys,
740+ u0map,
741+ parammap;
742+ tofloat,
743+ use_union,
744+ symbolic_u0,
745+ split_parameters)
746+ if split_parameters
747+ p, split_idxs = split_parameters_by_type (p)
748+ ps = Base. Fix1 (getindex, parameters (sys)).(split_idxs)
749+ end
729750
730751 if implicit_dae && du0map != = nothing
731752 ddvs = map (Differential (iv), dvs)
@@ -739,12 +760,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
739760
740761 check_eqs_u0 (eqs, dvs, u0; kwargs... )
741762
742- f = constructor (sys, dvs, split_sym_ps , u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
743- checkbounds = checkbounds, p = split_ps ,
763+ f = constructor (sys, dvs, ps , u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
764+ checkbounds = checkbounds, p = p ,
744765 linenumbers = linenumbers, parallel = parallel, simplify = simplify,
745- sparse = sparse, eval_expression = eval_expression,
766+ sparse = sparse, eval_expression = eval_expression, split_parameters,
746767 kwargs... )
747- implicit_dae ? (f, du0, u0, split_ps ) : (f, u0, split_ps )
768+ implicit_dae ? (f, du0, u0, p ) : (f, u0, p )
748769end
749770
750771function ODEFunctionExpr (sys:: AbstractODESystem , args... ; kwargs... )
0 commit comments