Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,10 @@ function generate_connection_set!(connectionsets, domain_csets,
else
if lhs isa Number || lhs isa Symbolic
push!(eqs, eq) # split connections and equations
elseif lhs isa Connection
if get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
else
push!(cts, get_systems(rhs))
end
elseif lhs isa Connection && get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
else
error("$eq is not a legal equation!")
push!(cts, get_systems(rhs))
end
end
end
Expand Down
49 changes: 43 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,37 @@ function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = paramete
simplify = false, kwargs...)
tgrad = calculate_tgrad(sys, simplify = simplify)
pre = get_preprocess_constants(tgrad)
return build_function(tgrad, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
if ps isa Tuple
return build_function(tgrad,
dvs,
ps...,
get_iv(sys);
postprocess_fbody = pre,
kwargs...)
else
return build_function(tgrad,
dvs,
ps,
get_iv(sys);
postprocess_fbody = pre,
kwargs...)
end
end

function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
simplify = false, sparse = false, kwargs...)
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
pre = get_preprocess_constants(jac)
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
if ps isa Tuple
return build_function(jac,
dvs,
ps...,
get_iv(sys);
postprocess_fbody = pre,
kwargs...)
else
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
end
end

function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys),
Expand Down Expand Up @@ -364,8 +387,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
tgrad_oop, tgrad_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
tgrad_gen
_tgrad(u, p, t) = tgrad_oop(u, p, t)
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
if p isa Tuple
__tgrad(u, p, t) = tgrad_oop(u, p..., t)
__tgrad(J, u, p, t) = tgrad_iip(J, u, p..., t)
_tgrad = __tgrad
else
___tgrad(u, p, t) = tgrad_oop(u, p, t)
___tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
_tgrad = ___tgrad
end
else
_tgrad = nothing
end
Expand All @@ -379,8 +409,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
jac_oop, jac_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
jac_gen
_jac(u, p, t) = jac_oop(u, p, t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
if p isa Tuple
__jac(u, p, t) = jac_oop(u, p..., t)
__jac(J, u, p, t) = jac_iip(J, u, p..., t)
_jac = __jac
else
___jac(u, p, t) = jac_oop(u, p, t)
___jac(J, u, p, t) = jac_iip(J, u, p, t)
_jac = ___jac
end
else
_jac = nothing
end
Expand Down