From f7743e085fd190000cf3db23a68255722d8af4e2 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 15 Aug 2023 17:24:45 -0400 Subject: [PATCH 01/15] WIP: tuple parameters --- src/parameters.jl | 26 ++++++++++++++++++++++++ src/systems/diffeqs/abstractodesystem.jl | 14 ++++++++----- src/utils.jl | 4 ++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/parameters.jl b/src/parameters.jl index 9174ac454f..598f40f1d6 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -61,3 +61,29 @@ macro parameters(xs...) xs, toparam) |> esc end + +function split_parameters_by_type(ps) + by = let set = Dict{Any, Int}(), counter = Ref(1) + x -> begin + t = typeof(x) + get!(set, typeof(x)) do + if t == Float64 + 1 + else + counter[] += 1 + end + end + end + end + idxs = by.(ps) + split_idxs = [Int[]] + for (i, idx) in enumerate(idxs) + if idx > length(split_idxs) + push!(split_idxs, Int[]) + end + push!(split_idxs[idx], i) + end + tighten_types = x -> identity.(x) + split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) + (split_ps...,), split_idxs +end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 30b237c807..def4f212af 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -152,8 +152,9 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param states = sol_states, kwargs...) else - build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, + fun = build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, kwargs...) + fun[1], :((out, u, p, t)->$(fun[2])(out, u, p..., t)) end end end @@ -723,6 +724,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; iv = get_iv(sys) u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0) + split_ps, split_idxs = split_parameters_by_type(p) + split_sym_ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) if implicit_dae && du0map !== nothing ddvs = map(Differential(iv), dvs) @@ -736,11 +739,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; check_eqs_u0(eqs, dvs, u0; kwargs...) - f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, - checkbounds = checkbounds, p = p, + f = constructor(sys, dvs, split_sym_ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, + checkbounds = checkbounds, p = split_ps, linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, kwargs...) - implicit_dae ? (f, du0, u0, p) : (f, u0, p) + sparse = sparse, eval_expression = eval_expression, + kwargs...) + implicit_dae ? (f, du0, u0, split_ps) : (f, u0, split_ps) end function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...) diff --git a/src/utils.jl b/src/utils.jl index 4dc2a636df..80059d211b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -243,6 +243,10 @@ end hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue) getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue)) +function getdefaulttype(v) + def = value(getmetadata(unwrap(v), Symbolics.VariableDefaultValue, nothing)) + def === nothing ? Float64 : typeof(def) +end function setdefault(v, val) val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val)) end From 87fc28f6abba79a9be82375748e89f3e5abb66e3 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 15 Aug 2023 18:05:06 -0400 Subject: [PATCH 02/15] Less allocations --- src/systems/diffeqs/abstractodesystem.jl | 47 +++++++++++++++++------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index def4f212af..6b3421d59a 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -122,6 +122,7 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param nothing, isdde = false, has_difference = false, + split_parameters = false, kwargs...) if isdde eqs = delay_to_function(sys) @@ -151,10 +152,12 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states, kwargs...) + elseif split_parameters + build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, + kwargs...) else - fun = build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, + build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, kwargs...) - fun[1], :((out, u, p, t)->$(fun[2])(out, u, p..., t)) end end end @@ -326,15 +329,23 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s checkbounds = false, sparsity = false, analytic = nothing, + split_parameters = false, kwargs...) where {iip, specialize} f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, - expression_module = eval_module, checkbounds = checkbounds, + expression_module = eval_module, checkbounds = checkbounds, split_parameters, kwargs...) f_oop, f_iip = eval_expression ? (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) : f_gen - f(u, p, t) = f_oop(u, p, t) - f(du, u, p, t) = f_iip(du, u, p, t) + if split_parameters + g(u, p, t) = f_oop(u, p..., t) + g(du, u, p, t) = f_iip(du, u, p..., t) + f = g + else + k(u, p, t) = f_oop(u, p, t) + k(du, u, p, t) = f_iip(du, u, p, t) + f = k + end if specialize === SciMLBase.FunctionWrapperSpecialize && iip if u0 === nothing || p === nothing || t === nothing @@ -687,6 +698,7 @@ function get_u0_p(sys, parammap; use_union = false, tofloat = !use_union, + split_parameters = false, symbolic_u0 = false) eqs = equations(sys) dvs = states(sys) @@ -699,7 +711,7 @@ function get_u0_p(sys, if symbolic_u0 u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) else - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = !split_parameters) end p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) p = p === nothing ? SciMLBase.NullParameters() : p @@ -717,15 +729,24 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; use_union = false, tofloat = !use_union, symbolic_u0 = false, + split_parameters = false, kwargs...) eqs = equations(sys) dvs = states(sys) ps = parameters(sys) iv = get_iv(sys) - u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0) - split_ps, split_idxs = split_parameters_by_type(p) - split_sym_ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + u0, p, defs = get_u0_p(sys, + u0map, + parammap; + tofloat, + use_union, + symbolic_u0, + split_parameters) + if split_parameters + p, split_idxs = split_parameters_by_type(p) + ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + end if implicit_dae && du0map !== nothing ddvs = map(Differential(iv), dvs) @@ -739,12 +760,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; check_eqs_u0(eqs, dvs, u0; kwargs...) - f = constructor(sys, dvs, split_sym_ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, - checkbounds = checkbounds, p = split_ps, + f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, + checkbounds = checkbounds, p = p, linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, + sparse = sparse, eval_expression = eval_expression, split_parameters, kwargs...) - implicit_dae ? (f, du0, u0, split_ps) : (f, u0, split_ps) + implicit_dae ? (f, du0, u0, p) : (f, u0, p) end function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...) From f9b0fa96e662defdd43179d522f7f28164b67ba2 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 15 Aug 2023 18:09:39 -0400 Subject: [PATCH 03/15] Add split parameters tests Needs https://github.com/SciML/ModelingToolkitStandardLibrary.jl/pull/211 --- test/runtests.jl | 1 + test/split_parameters.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 test/split_parameters.jl diff --git a/test/runtests.jl b/test/runtests.jl index b3c017006d..bf68341b53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ using SafeTestsets, Test @safetestset "JumpSystem Test" include("jumpsystem.jl") @safetestset "Constraints Test" include("constraints.jl") @safetestset "Reduction Test" include("reduction.jl") +@safetestset "Split Parameters Test" include("split_parameters.jl") @safetestset "ODAEProblem Test" include("odaeproblem.jl") @safetestset "Components Test" include("components.jl") @safetestset "Model Parsing Test" include("model_parsing.jl") diff --git a/test/split_parameters.jl b/test/split_parameters.jl new file mode 100644 index 0000000000..194a1e92c4 --- /dev/null +++ b/test/split_parameters.jl @@ -0,0 +1,31 @@ +using ModelingToolkit, Test +using ModelingToolkitStandardLibrary.Blocks +using OrdinaryDiffEq + +dt = 4e-4 +t_end = 10.0 +time = 0:dt:t_end +x = @. time^2 + 1.0 + +@parameters t +D = Differential(t) + +vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 +@named src = SampledData(; data = Float64[], dt) +@named int = Integrator() + +eqs = [y ~ src.output.u + D(y) ~ dy + D(dy) ~ ddy + connect(src.output, int.input)] + +@named sys = ODESystem(eqs, t, vars, []; systems = [int, src]) +s = complete(sys) +sys = structural_simplify(sys) + +prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; split_parameters = true) +@test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}} +@time sol = solve(prob, ImplicitEuler()); +prob2 = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; to_float = false) +@test prob2.p isa Vector{Union{Float64, Int64, Vector{Float64}}} +@time sol2 = solve(prob2, ImplicitEuler()); From d20f68d5d411ec149eeed2c7472dd81e31ea9171 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Fri, 18 Aug 2023 17:08:02 -0400 Subject: [PATCH 04/15] working default Tuple type --- src/parameters.jl | 56 +++++++++++++------ src/systems/diffeqs/abstractodesystem.jl | 44 ++++++++------- src/variables.jl | 9 +-- test/odesystem.jl | 19 ++++--- test/split_parameters.jl | 70 +++++++++++++++++++++--- 5 files changed, 142 insertions(+), 56 deletions(-) diff --git a/src/parameters.jl b/src/parameters.jl index 598f40f1d6..a64e4262fe 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -62,28 +62,50 @@ macro parameters(xs...) toparam) |> esc end -function split_parameters_by_type(ps) - by = let set = Dict{Any, Int}(), counter = Ref(1) +function find_types(array) + by = let set = Dict{Any, Int}(), counter = Ref(0) x -> begin - t = typeof(x) + # t = typeof(x) + get!(set, typeof(x)) do - if t == Float64 - 1 - else - counter[] += 1 - end + # if t == Float64 + # 1 + # else + counter[] += 1 + # end end end end - idxs = by.(ps) - split_idxs = [Int[]] - for (i, idx) in enumerate(idxs) - if idx > length(split_idxs) - push!(split_idxs, Int[]) + return by.(array) +end + + +function split_parameters_by_type(ps) + + if ps === SciMLBase.NullParameters() + return Float64[],[] #use Float64 to avoid Any type warning + else + by = let set = Dict{Any, Int}(), counter = Ref(0) + x -> begin + get!(set, typeof(x)) do + counter[] += 1 + end + end + end + idxs = by.(ps) + split_idxs = [Int[]] + for (i, idx) in enumerate(idxs) + if idx > length(split_idxs) + push!(split_idxs, Int[]) + end + push!(split_idxs[idx], i) + end + tighten_types = x -> identity.(x) + split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) + if length(split_ps) == 1 #Tuple not needed, only 1 type + return split_ps[1], split_idxs + else + return (split_ps...,), split_idxs end - push!(split_idxs[idx], i) end - tighten_types = x -> identity.(x) - split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) - (split_ps...,), split_idxs end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 6b3421d59a..3c7aa0f236 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -122,7 +122,6 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param nothing, isdde = false, has_difference = false, - split_parameters = false, kwargs...) if isdde eqs = delay_to_function(sys) @@ -152,12 +151,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states, kwargs...) - elseif split_parameters - build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, - kwargs...) else - build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, + if p isa Tuple + build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, + kwargs...) + else + build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, kwargs...) + end end end end @@ -329,15 +330,14 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s checkbounds = false, sparsity = false, analytic = nothing, - split_parameters = false, kwargs...) where {iip, specialize} f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, - expression_module = eval_module, checkbounds = checkbounds, split_parameters, + expression_module = eval_module, checkbounds = checkbounds, kwargs...) f_oop, f_iip = eval_expression ? (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) : f_gen - if split_parameters + if p isa Tuple g(u, p, t) = f_oop(u, p..., t) g(du, u, p, t) = f_iip(du, u, p..., t) f = g @@ -696,9 +696,8 @@ Take dictionaries with initial conditions and parameters and convert them to num function get_u0_p(sys, u0map, parammap; - use_union = false, - tofloat = !use_union, - split_parameters = false, + use_union = true, + tofloat = true, symbolic_u0 = false) eqs = equations(sys) dvs = states(sys) @@ -711,7 +710,7 @@ function get_u0_p(sys, if symbolic_u0 u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) else - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = !split_parameters) + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) end p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) p = p === nothing ? SciMLBase.NullParameters() : p @@ -726,10 +725,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; simplify = false, linenumbers = true, parallel = SerialForm(), eval_expression = true, - use_union = false, - tofloat = !use_union, + use_union = true, + tofloat = true, symbolic_u0 = false, - split_parameters = false, + # split_parameters = true, kwargs...) eqs = equations(sys) dvs = states(sys) @@ -741,12 +740,15 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; parammap; tofloat, use_union, - symbolic_u0, - split_parameters) - if split_parameters + symbolic_u0) + + # if split_parameters p, split_idxs = split_parameters_by_type(p) - ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) - end + if p isa Tuple + ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + ps = (ps...,) #if p is Tuple, ps should be Tuple + end + # end if implicit_dae && du0map !== nothing ddvs = map(Differential(iv), dvs) @@ -763,7 +765,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, checkbounds = checkbounds, p = p, linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, split_parameters, + sparse = sparse, eval_expression = eval_expression, kwargs...) implicit_dae ? (f, du0, u0, p) : (f, u0, p) end diff --git a/src/variables.jl b/src/variables.jl index 4d11193462..9af5b1f6c4 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -58,7 +58,7 @@ applicable. """ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, toterm = default_toterm, promotetoconcrete = nothing, - tofloat = true, use_union = false) + tofloat = true, use_union = true) varlist = collect(map(unwrap, varlist)) # Edge cases where one of the arguments is effectively empty. @@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, end end - T = typeof(varmap) - # We respect the input type - container_type = T <: Dict ? Array : T + # T = typeof(varmap) + # We respect the input type (feature removed, not needed with Tuple support) + # container_type = T <: Union{Dict,Tuple} ? Array : T + container_type = Array vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs varmap = todict(varmap) diff --git a/test/odesystem.jl b/test/odesystem.jl index a45a45b1b0..e6d168197f 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -734,18 +734,23 @@ let u0map = [A => 1.0] pmap = (k1 => 1.0, k2 => 1) tspan = (0.0, 1.0) - prob = ODEProblem(sys, u0map, tspan, pmap) - @test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))]) + prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false) + + @test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))]) pmap = [k1 => 1, k2 => 1] tspan = (0.0, 1.0) prob = ODEProblem(sys, u0map, tspan, pmap) @test eltype(prob.p) === Float64 - - pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0] - tspan = (0.0, 1.0) - prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true) - @test eltype(prob.p) === Union{Float64, Int} + + prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false) + @test eltype(prob.p) === Int + + # No longer supported, Tuple used instead + # pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0] + # tspan = (0.0, 1.0) + # prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true) + # @test eltype(prob.p) === Union{Float64, Int} end let diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 194a1e92c4..09be7d55c0 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -2,6 +2,10 @@ using ModelingToolkit, Test using ModelingToolkitStandardLibrary.Blocks using OrdinaryDiffEq + + +# ------------------------ Mixed Single Values and Vector + dt = 4e-4 t_end = 10.0 time = 0:dt:t_end @@ -10,8 +14,31 @@ x = @. time^2 + 1.0 @parameters t D = Differential(t) +get_value(data, t, dt) = data[round(Int, t/dt+1)] +@register_symbolic get_value(data, t, dt) + + +function Sampled(; name, data=Float64[], dt=0.0) + pars = @parameters begin + data = data + dt = dt + end + + vars = [] + systems = @named begin + output = RealOutput() + end + + eqs = [ + output.u ~ get_value(data, t, dt) + ] + + return ODESystem(eqs, t, vars, pars; name, systems, + defaults = [output.u => data[1]]) +end + vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 -@named src = SampledData(; data = Float64[], dt) +@named src = Sampled(; data = Float64[], dt) @named int = Integrator() eqs = [y ~ src.output.u @@ -22,10 +49,39 @@ eqs = [y ~ src.output.u @named sys = ODESystem(eqs, t, vars, []; systems = [int, src]) s = complete(sys) sys = structural_simplify(sys) - -prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; split_parameters = true) +prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]) @test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}} -@time sol = solve(prob, ImplicitEuler()); -prob2 = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; to_float = false) -@test prob2.p isa Vector{Union{Float64, Int64, Vector{Float64}}} -@time sol2 = solve(prob2, ImplicitEuler()); +sol = solve(prob, ImplicitEuler()); +@test sol.retcode == ReturnCode.Success + + +# ------------------------ Mixed Type Converted to float (default behavior) + +vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 +pars = @parameters a=1.0 b=2.0 c=3 +eqs = [ + D(y) ~ dy*a + D(dy) ~ ddy*b + ddy ~ sin(t)*c] + +@named sys = ODESystem(eqs, t, vars, pars) +sys = structural_simplify(sys) + +tspan = (0.0, t_end) +prob = ODEProblem(sys, [], tspan, []) + +@test prob.p isa Vector{Float64} +sol = solve(prob, ImplicitEuler()); +@test sol.retcode == ReturnCode.Success + + +# ------------------------ Mixed Type Conserved + +prob = ODEProblem(sys, [], tspan, []; tofloat=false) + +@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} +sol = solve(prob, ImplicitEuler()); +@test sol.retcode == ReturnCode.Success + + + From 10ccf8ab03fac7225fd05bf6443208523efe5a85 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Fri, 18 Aug 2023 17:38:18 -0400 Subject: [PATCH 05/15] format --- src/parameters.jl | 12 +++++------- src/structural_transformation/codegen.jl | 6 +++--- src/systems/diffeqs/abstractodesystem.jl | 21 +++++++++++---------- src/utils.jl | 2 +- test/odesystem.jl | 13 +++++++++---- test/split_parameters.jl | 23 +++++++---------------- 6 files changed, 36 insertions(+), 41 deletions(-) diff --git a/src/parameters.jl b/src/parameters.jl index a64e4262fe..4339cb7acf 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -66,7 +66,7 @@ function find_types(array) by = let set = Dict{Any, Int}(), counter = Ref(0) x -> begin # t = typeof(x) - + get!(set, typeof(x)) do # if t == Float64 # 1 @@ -79,16 +79,14 @@ function find_types(array) return by.(array) end - function split_parameters_by_type(ps) - if ps === SciMLBase.NullParameters() - return Float64[],[] #use Float64 to avoid Any type warning + return Float64[], [] #use Float64 to avoid Any type warning else by = let set = Dict{Any, Int}(), counter = Ref(0) - x -> begin + x -> begin get!(set, typeof(x)) do - counter[] += 1 + counter[] += 1 end end end @@ -103,7 +101,7 @@ function split_parameters_by_type(ps) tighten_types = x -> identity.(x) split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) if length(split_ps) == 1 #Tuple not needed, only 1 type - return split_ps[1], split_idxs + return split_ps[1], split_idxs else return (split_ps...,), split_idxs end diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index 65939859c9..c6a957a6e5 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys, tspan, parammap = DiffEqBase.NullParameters(); callback = nothing, - use_union = false, + use_union = true, + tofloat = true, check = true, kwargs...) where {iip} eqs = equations(sys) @@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys, defs = ModelingToolkit.mergedefaults(defs, parammap, ps) defs = ModelingToolkit.mergedefaults(defs, u0map, dvs) u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) - p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union, - use_union) + p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) has_difference = any(isdifferenceeq, eqs) cbs = process_events(sys; callback, has_difference, kwargs...) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 3c7aa0f236..78e2d30a29 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -153,11 +153,12 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param kwargs...) else if p isa Tuple - build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, + build_function(rhss, u, p..., t; postprocess_fbody = pre, + states = sol_states, kwargs...) else build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, - kwargs...) + kwargs...) end end end @@ -332,7 +333,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s analytic = nothing, kwargs...) where {iip, specialize} f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, - expression_module = eval_module, checkbounds = checkbounds, + expression_module = eval_module, checkbounds = checkbounds, kwargs...) f_oop, f_iip = eval_expression ? (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) : @@ -689,7 +690,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys), end """ - u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union) + u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true) Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point. """ @@ -743,11 +744,11 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; symbolic_u0) # if split_parameters - p, split_idxs = split_parameters_by_type(p) - if p isa Tuple - ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) - ps = (ps...,) #if p is Tuple, ps should be Tuple - end + p, split_idxs = split_parameters_by_type(p) + if p isa Tuple + ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) + ps = (ps...,) #if p is Tuple, ps should be Tuple + end # end if implicit_dae && du0map !== nothing @@ -765,7 +766,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, checkbounds = checkbounds, p = p, linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, + sparse = sparse, eval_expression = eval_expression, kwargs...) implicit_dae ? (f, du0, u0, p) : (f, u0, p) end diff --git a/src/utils.jl b/src/utils.jl index 80059d211b..cf16c2bf61 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -670,7 +670,7 @@ end throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list.")) end -function promote_to_concrete(vs; tofloat = true, use_union = false) +function promote_to_concrete(vs; tofloat = true, use_union = true) if isempty(vs) return vs end diff --git a/test/odesystem.jl b/test/odesystem.jl index e6d168197f..caa68eab4c 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -734,18 +734,23 @@ let u0map = [A => 1.0] pmap = (k1 => 1.0, k2 => 1) tspan = (0.0, 1.0) - prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false) - + prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false) @test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))]) + prob = ODEProblem(sys, u0map, tspan, pmap) + @test prob.p isa Vector{Float64} + pmap = [k1 => 1, k2 => 1] tspan = (0.0, 1.0) prob = ODEProblem(sys, u0map, tspan, pmap) @test eltype(prob.p) === Float64 - - prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false) + + prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false) @test eltype(prob.p) === Int + prob = ODEProblem(sys, u0map, tspan, pmap) + @test prob.p isa Vector{Float64} + # No longer supported, Tuple used instead # pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0] # tspan = (0.0, 1.0) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 09be7d55c0..c144bfa7c7 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -2,8 +2,6 @@ using ModelingToolkit, Test using ModelingToolkitStandardLibrary.Blocks using OrdinaryDiffEq - - # ------------------------ Mixed Single Values and Vector dt = 4e-4 @@ -14,11 +12,10 @@ x = @. time^2 + 1.0 @parameters t D = Differential(t) -get_value(data, t, dt) = data[round(Int, t/dt+1)] +get_value(data, t, dt) = data[round(Int, t / dt + 1)] @register_symbolic get_value(data, t, dt) - -function Sampled(; name, data=Float64[], dt=0.0) +function Sampled(; name, data = Float64[], dt = 0.0) pars = @parameters begin data = data dt = dt @@ -30,7 +27,7 @@ function Sampled(; name, data=Float64[], dt=0.0) end eqs = [ - output.u ~ get_value(data, t, dt) + output.u ~ get_value(data, t, dt), ] return ODESystem(eqs, t, vars, pars; name, systems, @@ -54,15 +51,13 @@ prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]) sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success - # ------------------------ Mixed Type Converted to float (default behavior) vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 pars = @parameters a=1.0 b=2.0 c=3 -eqs = [ - D(y) ~ dy*a - D(dy) ~ ddy*b - ddy ~ sin(t)*c] +eqs = [D(y) ~ dy * a + D(dy) ~ ddy * b + ddy ~ sin(t) * c] @named sys = ODESystem(eqs, t, vars, pars) sys = structural_simplify(sys) @@ -74,14 +69,10 @@ prob = ODEProblem(sys, [], tspan, []) sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success - # ------------------------ Mixed Type Conserved -prob = ODEProblem(sys, [], tspan, []; tofloat=false) +prob = ODEProblem(sys, [], tspan, []; tofloat = false) @test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success - - - From 9998635a3cefd50dcede319858945f455f1c8458 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Fri, 18 Aug 2023 17:46:40 -0400 Subject: [PATCH 06/15] clean up --- src/systems/diffeqs/abstractodesystem.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 78e2d30a29..e41278fc0c 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -729,7 +729,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; use_union = true, tofloat = true, symbolic_u0 = false, - # split_parameters = true, kwargs...) eqs = equations(sys) dvs = states(sys) @@ -743,13 +742,11 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; use_union, symbolic_u0) - # if split_parameters p, split_idxs = split_parameters_by_type(p) if p isa Tuple ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) ps = (ps...,) #if p is Tuple, ps should be Tuple end - # end if implicit_dae && du0map !== nothing ddvs = map(Differential(iv), dvs) From 84b66a70c6bb2a41f246419b0fe824e67397bb28 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Sat, 19 Aug 2023 06:58:54 -0400 Subject: [PATCH 07/15] add observables test --- test/split_parameters.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index c144bfa7c7..ba4896f8b1 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -76,3 +76,23 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false) @test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success + + + +# ------------------------- Observables + +@named c = Sine(; frequency = 1) +@named absb = Abs(;) +@named int = Integrator(; k = 1) +@named model = ODESystem([ + connect(c.output, absb.input), + connect(absb.output, int.input), + ], + t, + systems = [int, absb, c]) +sys = structural_simplify(model) +prob = ODEProblem(sys, Pair[int.x => 0.0], (0.0, 1.0)) +sol = solve(prob, Rodas4()) +@test isequal(unbound_inputs(sys), []) +@test sol.retcode == Success +sol[absb.output.u] \ No newline at end of file From 510833c686e501457c7abd1230476524303a0a2f Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Thu, 7 Sep 2023 16:57:24 -0400 Subject: [PATCH 08/15] fixed split_parameters.jl test --- test/split_parameters.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index ba4896f8b1..c03749a473 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -93,6 +93,5 @@ sol = solve(prob, ImplicitEuler()); sys = structural_simplify(model) prob = ODEProblem(sys, Pair[int.x => 0.0], (0.0, 1.0)) sol = solve(prob, Rodas4()) -@test isequal(unbound_inputs(sys), []) -@test sol.retcode == Success +@test sol.retcode == ReturnCode.Success sol[absb.output.u] \ No newline at end of file From d9a401f477a1eb9f1b901ae17e3b38df2791ffaf Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 8 Sep 2023 14:13:47 -0400 Subject: [PATCH 09/15] Support observed function building for split parameters --- src/systems/diffeqs/abstractodesystem.jl | 46 ++++++++++++++++++++---- src/systems/diffeqs/odesystem.jl | 11 ++++-- test/split_parameters.jl | 9 +++-- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 6ce72d8d88..ebc612a774 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -397,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s obs = observed(sys) observedfun = if steady_state - let sys = sys, dict = Dict() + let sys = sys, dict = Dict(), ps = ps function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do build_explicit_observed_function(sys, obsvar) end if args === () let obs = obs - (u, p, t = Inf) -> obs(u, p, t) + (u, p, t = Inf) -> if ps isa Tuple + obs(u, p..., t) + else + obs(u, p, t) + end end else - length(args) == 2 ? obs(args..., Inf) : obs(args...) + if ps isa Tuple + if length(args) == 2 + u, p = args + obs(u, p..., Inf) + else + u, p, t = args + obs(u, p..., t) + end + else + if length(args) == 2 + u, p = args + obs(u, p, Inf) + else + u, p, t = args + obs(u, p, t) + end + end end end end else - let sys = sys, dict = Dict() + let sys = sys, dict = Dict(), ps = ps function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) + build_explicit_observed_function(sys, + obsvar; + checkbounds = checkbounds, + ps) end if args === () let obs = obs - (u, p, t) -> obs(u, p, t) + (u, p, t) -> if ps isa Tuple + obs(u, p..., t) + else + obs(u, p, t) + end end else - obs(args...) + if ps isa Tuple # split parameters + u, p, t = args + obs(u, p..., t) + else + obs(args...) + end end end end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e368df4f1b..559103b91a 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts; output_type = Array, checkbounds = true, drop_expr = drop_expr, + ps = paramteres(sys), throw = true) if (isscalar = !(ts isa AbstractVector)) ts = [ts] @@ -389,13 +390,17 @@ function build_explicit_observed_function(sys, ts; if inputs !== nothing pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list end - ps = DestructuredArgs(pars, inbounds = !checkbounds) + if ps isa Tuple + ps = DestructuredArgs.(ps, inbounds = !checkbounds) + else + ps = (DestructuredArgs(ps, inbounds = !checkbounds),) + end dvs = DestructuredArgs(states(sys), inbounds = !checkbounds) if inputs === nothing - args = [dvs, ps, ivs...] + args = [dvs, ps..., ivs...] else ipts = DestructuredArgs(inputs, inbounds = !checkbounds) - args = [dvs, ipts, ps, ivs...] + args = [dvs, ipts, ps..., ivs...] end pre = get_postprocess_fbody(sys) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index c03749a473..f457e9e22e 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -59,8 +59,8 @@ eqs = [D(y) ~ dy * a D(dy) ~ ddy * b ddy ~ sin(t) * c] -@named sys = ODESystem(eqs, t, vars, pars) -sys = structural_simplify(sys) +@named model = ODESystem(eqs, t, vars, pars) +sys = structural_simplify(model) tspan = (0.0, t_end) prob = ODEProblem(sys, [], tspan, []) @@ -76,8 +76,7 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false) @test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success - - +sol[states(model)] # ------------------------- Observables @@ -94,4 +93,4 @@ sys = structural_simplify(model) prob = ODEProblem(sys, Pair[int.x => 0.0], (0.0, 1.0)) sol = solve(prob, Rodas4()) @test sol.retcode == ReturnCode.Success -sol[absb.output.u] \ No newline at end of file +sol[absb.output.u] From da5010501e0e7f9cfc362ea6cbf3d9edc0ca13a2 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 8 Sep 2023 14:30:48 -0400 Subject: [PATCH 10/15] Fix typo --- src/systems/diffeqs/odesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 559103b91a..0abf12d26e 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -314,7 +314,7 @@ function build_explicit_observed_function(sys, ts; output_type = Array, checkbounds = true, drop_expr = drop_expr, - ps = paramteres(sys), + ps = parameters(sys), throw = true) if (isscalar = !(ts isa AbstractVector)) ts = [ts] From d6cc7bf9896e7af2ac58fc1ef2807c2b744f45a7 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 12 Sep 2023 15:40:04 -0400 Subject: [PATCH 11/15] Fix input output handling --- src/ModelingToolkit.jl | 2 +- src/systems/diffeqs/odesystem.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 89dee632f4..16e9ab6548 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -3,7 +3,7 @@ $(DocStringExtensions.README) """ module ModelingToolkit using PrecompileTools, Reexport -@recompile_invalidations begin +@recompile_invalidations begin using DocStringExtensions using Compat using AbstractTrees diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 0abf12d26e..76c6c4a58a 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -386,9 +386,8 @@ function build_explicit_observed_function(sys, ts; push!(obsexprs, lhs ← rhs) end - pars = parameters(sys) if inputs !== nothing - pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list + ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list end if ps isa Tuple ps = DestructuredArgs.(ps, inbounds = !checkbounds) From 871efe70eb900d1e3f4123fc12c4e2f42c335d72 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 12 Sep 2023 15:42:43 -0400 Subject: [PATCH 12/15] split_parameters tests --- test/split_parameters.jl | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index f457e9e22e..031ed39606 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -50,6 +50,8 @@ prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]) @test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success +@test sol[y][end] == x[end] + # ------------------------ Mixed Type Converted to float (default behavior) @@ -76,21 +78,6 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false) @test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success -sol[states(model)] - -# ------------------------- Observables - -@named c = Sine(; frequency = 1) -@named absb = Abs(;) -@named int = Integrator(; k = 1) -@named model = ODESystem([ - connect(c.output, absb.input), - connect(absb.output, int.input), - ], - t, - systems = [int, absb, c]) -sys = structural_simplify(model) -prob = ODEProblem(sys, Pair[int.x => 0.0], (0.0, 1.0)) -sol = solve(prob, Rodas4()) -@test sol.retcode == ReturnCode.Success -sol[absb.output.u] + + + From 77a8792aa2a2a4eb6c3ec5b00ff5affebf8be099 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 12 Sep 2023 15:58:56 -0400 Subject: [PATCH 13/15] Format --- test/split_parameters.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 031ed39606..ef8f434dca 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -50,8 +50,7 @@ prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]) @test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success -@test sol[y][end] == x[end] - +@test sol[y][end] == x[end] # ------------------------ Mixed Type Converted to float (default behavior) @@ -78,6 +77,3 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false) @test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success - - - From d53907e8906f3f8dfdbe49e7196416b0be078f74 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Wed, 13 Sep 2023 15:24:44 -0400 Subject: [PATCH 14/15] add bool to promote_to_concrete --- src/utils.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 5b3bfdc26a..1d64d95a4e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -660,6 +660,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true) I = Int8 has_int = false has_array = false + has_bool = false array_T = nothing for v in vs if v isa AbstractArray @@ -672,6 +673,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = true) has_int = true I = promote_type(I, E) end + if E <: Bool + has_bool = true + end end if tofloat && !has_array C = float(C) @@ -682,6 +686,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = true) if has_int C = Union{C, I} end + if has_bool + C = Union{C, Bool} + end return copyto!(similar(vs, C), vs) end convert.(C, vs) From 6acee61eaceb230c6ffdfcbf6f21de6bdd9026a4 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Thu, 14 Sep 2023 06:05:44 -0400 Subject: [PATCH 15/15] fix for Catalyst.jl --- src/utils.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 1d64d95a4e..abea65ab21 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -650,6 +650,11 @@ function promote_to_concrete(vs; tofloat = true, use_union = true) if isempty(vs) return vs end + if vs isa Tuple #special rule, if vs is a Tuple, preserve types, container converted to Array + tofloat = false + use_union = true + vs = Any[vs...] + end T = eltype(vs) if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do vs