Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ $(DocStringExtensions.README)
"""
module ModelingToolkit
using PrecompileTools, Reexport
@recompile_invalidations begin
@recompile_invalidations begin
using DocStringExtensions
using Compat
using AbstractTrees
Expand Down
46 changes: 46 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,49 @@ macro parameters(xs...)
xs,
toparam) |> esc
end

function find_types(array)
by = let set = Dict{Any, Int}(), counter = Ref(0)
x -> begin
# t = typeof(x)

get!(set, typeof(x)) do
# if t == Float64
# 1
# else
counter[] += 1
# end
end
end
end
return by.(array)
end

function split_parameters_by_type(ps)
if ps === SciMLBase.NullParameters()
return Float64[], [] #use Float64 to avoid Any type warning
else
by = let set = Dict{Any, Int}(), counter = Ref(0)
x -> begin
get!(set, typeof(x)) do
counter[] += 1
end
end
end
idxs = by.(ps)
split_idxs = [Int[]]
for (i, idx) in enumerate(idxs)
if idx > length(split_idxs)
push!(split_idxs, Int[])
end
push!(split_idxs[idx], i)
end
tighten_types = x -> identity.(x)
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
if length(split_ps) == 1 #Tuple not needed, only 1 type
return split_ps[1], split_idxs
else
return (split_ps...,), split_idxs
end
end
end
6 changes: 3 additions & 3 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys,
tspan,
parammap = DiffEqBase.NullParameters();
callback = nothing,
use_union = false,
use_union = true,
tofloat = true,
check = true,
kwargs...) where {iip}
eqs = equations(sys)
Expand All @@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys,
defs = ModelingToolkit.mergedefaults(defs, parammap, ps)
defs = ModelingToolkit.mergedefaults(defs, u0map, dvs)
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union,
use_union)
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)

has_difference = any(isdifferenceeq, eqs)
cbs = process_events(sys; callback, has_difference, kwargs...)
Expand Down
93 changes: 75 additions & 18 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
if p isa Tuple
build_function(rhss, u, p..., t; postprocess_fbody = pre,
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
end
end
end
end
Expand Down Expand Up @@ -332,8 +338,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
if p isa Tuple
g(u, p, t) = f_oop(u, p..., t)
g(du, u, p, t) = f_iip(du, u, p..., t)
f = g
else
k(u, p, t) = f_oop(u, p, t)
k(du, u, p, t) = f_iip(du, u, p, t)
f = k
end

if specialize === SciMLBase.FunctionWrapperSpecialize && iip
if u0 === nothing || p === nothing || t === nothing
Expand Down Expand Up @@ -384,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s

obs = observed(sys)
observedfun = if steady_state
let sys = sys, dict = Dict()
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
end
if args === ()
let obs = obs
(u, p, t = Inf) -> obs(u, p, t)
(u, p, t = Inf) -> if ps isa Tuple
obs(u, p..., t)
else
obs(u, p, t)
end
end
else
length(args) == 2 ? obs(args..., Inf) : obs(args...)
if ps isa Tuple
if length(args) == 2
u, p = args
obs(u, p..., Inf)
else
u, p, t = args
obs(u, p..., t)
end
else
if length(args) == 2
u, p = args
obs(u, p, Inf)
else
u, p, t = args
obs(u, p, t)
end
end
end
end
end
else
let sys = sys, dict = Dict()
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
build_explicit_observed_function(sys,
obsvar;
checkbounds = checkbounds,
ps)
end
if args === ()
let obs = obs
(u, p, t) -> obs(u, p, t)
(u, p, t) -> if ps isa Tuple
obs(u, p..., t)
else
obs(u, p, t)
end
end
else
obs(args...)
if ps isa Tuple # split parameters
u, p, t = args
obs(u, p..., t)
else
obs(args...)
end
end
end
end
Expand Down Expand Up @@ -677,15 +722,15 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
end

"""
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)

Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
"""
function get_u0_p(sys,
u0map,
parammap;
use_union = false,
tofloat = !use_union,
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)
Expand All @@ -712,16 +757,27 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
simplify = false,
linenumbers = true, parallel = SerialForm(),
eval_expression = true,
use_union = false,
tofloat = !use_union,
use_union = true,
tofloat = true,
symbolic_u0 = false,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
iv = get_iv(sys)

u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
u0, p, defs = get_u0_p(sys,
u0map,
parammap;
tofloat,
use_union,
symbolic_u0)

p, split_idxs = split_parameters_by_type(p)
if p isa Tuple
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
end

if implicit_dae && du0map !== nothing
ddvs = map(Differential(iv), dvs)
Expand All @@ -738,7 +794,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression, kwargs...)
sparse = sparse, eval_expression = eval_expression,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end

Expand Down
14 changes: 9 additions & 5 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts;
output_type = Array,
checkbounds = true,
drop_expr = drop_expr,
ps = parameters(sys),
throw = true)
if (isscalar = !(ts isa AbstractVector))
ts = [ts]
Expand Down Expand Up @@ -385,17 +386,20 @@ function build_explicit_observed_function(sys, ts;
push!(obsexprs, lhs ← rhs)
end

pars = parameters(sys)
if inputs !== nothing
pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
end
if ps isa Tuple
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
else
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
end
ps = DestructuredArgs(pars, inbounds = !checkbounds)
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
if inputs === nothing
args = [dvs, ps, ivs...]
args = [dvs, ps..., ivs...]
else
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
args = [dvs, ipts, ps, ivs...]
args = [dvs, ipts, ps..., ivs...]
end
pre = get_postprocess_fbody(sys)

Expand Down
18 changes: 17 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ end

hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
function getdefaulttype(v)
def = value(getmetadata(unwrap(v), Symbolics.VariableDefaultValue, nothing))
def === nothing ? Float64 : typeof(def)
end
function setdefault(v, val)
val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
end
Expand Down Expand Up @@ -642,10 +646,15 @@ end
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list."))
end

function promote_to_concrete(vs; tofloat = true, use_union = false)
function promote_to_concrete(vs; tofloat = true, use_union = true)
if isempty(vs)
return vs
end
if vs isa Tuple #special rule, if vs is a Tuple, preserve types, container converted to Array
tofloat = false
use_union = true
vs = Any[vs...]
end
T = eltype(vs)
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
vs
Expand All @@ -656,6 +665,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
I = Int8
has_int = false
has_array = false
has_bool = false
array_T = nothing
for v in vs
if v isa AbstractArray
Expand All @@ -668,6 +678,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
has_int = true
I = promote_type(I, E)
end
if E <: Bool
has_bool = true
end
end
if tofloat && !has_array
C = float(C)
Expand All @@ -678,6 +691,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
if has_int
C = Union{C, I}
end
if has_bool
C = Union{C, Bool}
end
return copyto!(similar(vs, C), vs)
end
convert.(C, vs)
Expand Down
9 changes: 5 additions & 4 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ applicable.
"""
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = default_toterm, promotetoconcrete = nothing,
tofloat = true, use_union = false)
tofloat = true, use_union = true)
varlist = collect(map(unwrap, varlist))

# Edge cases where one of the arguments is effectively empty.
Expand All @@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
end
end

T = typeof(varmap)
# We respect the input type
container_type = T <: Dict ? Array : T
# T = typeof(varmap)
# We respect the input type (feature removed, not needed with Tuple support)
# container_type = T <: Union{Dict,Tuple} ? Array : T
container_type = Array

vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
varmap = todict(varmap)
Expand Down
20 changes: 15 additions & 5 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -734,18 +734,28 @@ let
u0map = [A => 1.0]
pmap = (k1 => 1.0, k2 => 1)
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
@test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])

prob = ODEProblem(sys, u0map, tspan, pmap)
@test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
@test prob.p isa Vector{Float64}

pmap = [k1 => 1, k2 => 1]
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap)
@test eltype(prob.p) === Float64

pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
@test eltype(prob.p) === Union{Float64, Int}
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
@test eltype(prob.p) === Int

prob = ODEProblem(sys, u0map, tspan, pmap)
@test prob.p isa Vector{Float64}

# No longer supported, Tuple used instead
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
# tspan = (0.0, 1.0)
# prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
# @test eltype(prob.p) === Union{Float64, Int}
end

let
Expand Down
Loading