@@ -152,8 +152,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
152152 states = sol_states,
153153 kwargs... )
154154 else
155- build_function (rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
156- kwargs... )
155+ if p isa Tuple
156+ build_function (rhss, u, p... , t; postprocess_fbody = pre,
157+ states = sol_states,
158+ kwargs... )
159+ else
160+ build_function (rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
161+ kwargs... )
162+ end
157163 end
158164 end
159165end
@@ -332,8 +338,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
332338 f_oop, f_iip = eval_expression ?
333339 (drop_expr (@RuntimeGeneratedFunction (eval_module, ex)) for ex in f_gen) :
334340 f_gen
335- f (u, p, t) = f_oop (u, p, t)
336- f (du, u, p, t) = f_iip (du, u, p, t)
341+ if p isa Tuple
342+ g (u, p, t) = f_oop (u, p... , t)
343+ g (du, u, p, t) = f_iip (du, u, p... , t)
344+ f = g
345+ else
346+ k (u, p, t) = f_oop (u, p, t)
347+ k (du, u, p, t) = f_iip (du, u, p, t)
348+ f = k
349+ end
337350
338351 if specialize === SciMLBase. FunctionWrapperSpecialize && iip
339352 if u0 === nothing || p === nothing || t === nothing
@@ -384,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
384397
385398 obs = observed (sys)
386399 observedfun = if steady_state
387- let sys = sys, dict = Dict ()
400+ let sys = sys, dict = Dict (), ps = ps
388401 function generated_observed (obsvar, args... )
389402 obs = get! (dict, value (obsvar)) do
390403 build_explicit_observed_function (sys, obsvar)
391404 end
392405 if args === ()
393406 let obs = obs
394- (u, p, t = Inf ) -> obs (u, p, t)
407+ (u, p, t = Inf ) -> if ps isa Tuple
408+ obs (u, p... , t)
409+ else
410+ obs (u, p, t)
411+ end
395412 end
396413 else
397- length (args) == 2 ? obs (args... , Inf ) : obs (args... )
414+ if ps isa Tuple
415+ if length (args) == 2
416+ u, p = args
417+ obs (u, p... , Inf )
418+ else
419+ u, p, t = args
420+ obs (u, p... , t)
421+ end
422+ else
423+ if length (args) == 2
424+ u, p = args
425+ obs (u, p, Inf )
426+ else
427+ u, p, t = args
428+ obs (u, p, t)
429+ end
430+ end
398431 end
399432 end
400433 end
401434 else
402- let sys = sys, dict = Dict ()
435+ let sys = sys, dict = Dict (), ps = ps
403436 function generated_observed (obsvar, args... )
404437 obs = get! (dict, value (obsvar)) do
405- build_explicit_observed_function (sys, obsvar; checkbounds = checkbounds)
438+ build_explicit_observed_function (sys,
439+ obsvar;
440+ checkbounds = checkbounds,
441+ ps)
406442 end
407443 if args === ()
408444 let obs = obs
409- (u, p, t) -> obs (u, p, t)
445+ (u, p, t) -> if ps isa Tuple
446+ obs (u, p... , t)
447+ else
448+ obs (u, p, t)
449+ end
410450 end
411451 else
412- obs (args... )
452+ if ps isa Tuple # split parameters
453+ u, p, t = args
454+ obs (u, p... , t)
455+ else
456+ obs (args... )
457+ end
413458 end
414459 end
415460 end
@@ -677,15 +722,15 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
677722end
678723
679724"""
680- u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false , tofloat=!use_union )
725+ u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true , tofloat=true )
681726
682727Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
683728"""
684729function get_u0_p (sys,
685730 u0map,
686731 parammap;
687- use_union = false ,
688- tofloat = ! use_union ,
732+ use_union = true ,
733+ tofloat = true ,
689734 symbolic_u0 = false )
690735 dvs = states (sys)
691736 ps = parameters (sys)
@@ -712,16 +757,27 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
712757 simplify = false ,
713758 linenumbers = true , parallel = SerialForm (),
714759 eval_expression = true ,
715- use_union = false ,
716- tofloat = ! use_union ,
760+ use_union = true ,
761+ tofloat = true ,
717762 symbolic_u0 = false ,
718763 kwargs... )
719764 eqs = equations (sys)
720765 dvs = states (sys)
721766 ps = parameters (sys)
722767 iv = get_iv (sys)
723768
724- u0, p, defs = get_u0_p (sys, u0map, parammap; tofloat, use_union, symbolic_u0)
769+ u0, p, defs = get_u0_p (sys,
770+ u0map,
771+ parammap;
772+ tofloat,
773+ use_union,
774+ symbolic_u0)
775+
776+ p, split_idxs = split_parameters_by_type (p)
777+ if p isa Tuple
778+ ps = Base. Fix1 (getindex, parameters (sys)).(split_idxs)
779+ ps = (ps... ,) # if p is Tuple, ps should be Tuple
780+ end
725781
726782 if implicit_dae && du0map != = nothing
727783 ddvs = map (Differential (iv), dvs)
@@ -738,7 +794,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
738794 f = constructor (sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
739795 checkbounds = checkbounds, p = p,
740796 linenumbers = linenumbers, parallel = parallel, simplify = simplify,
741- sparse = sparse, eval_expression = eval_expression, kwargs... )
797+ sparse = sparse, eval_expression = eval_expression,
798+ kwargs... )
742799 implicit_dae ? (f, du0, u0, p) : (f, u0, p)
743800end
744801
0 commit comments