diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d0474f1ba8..a350211386 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys end end +function wrap_assignments(isscalar, assignments; let_block = false) + function wrapper(expr) + Func(expr.args, [], Let(assignments, expr.body, let_block)) + end + if isscalar + wrapper + else + wrapper, wrapper + end +end + function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys)) isscalar = !(exprs isa AbstractArray) array_vars = Dict{Any, AbstractArray{Int}}() @@ -505,7 +516,7 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) ts_idx = nothing end rawobs = build_explicit_observed_function( - sys, sym; param_only = true, return_inplace = true) + sys, sym; param_only = true, return_inplace = true) if rawobs isa Tuple obsfn = let oop = rawobs[1], iip = rawobs[2] f1(p::MTKParameters, t) = oop(p..., t) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 12e2ff16b2..b6ed52e49d 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -487,17 +487,23 @@ function build_explicit_observed_function(sys, ts; if inputs === nothing args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...] else - ipts = DestructuredArgs(inputs, inbounds = !checkbounds) + ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds) args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...] end pre = get_postprocess_fbody(sys) - res = build_function(isscalar ? ts[1] : ts, args...; get_postprocess_fbody = pre, wrap_code = wrap_array_vars(sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)), expression = Val{expression}) + res = build_function(isscalar ? ts[1] : ts, + args...; + postprocess_fbody = pre, + wrap_code = wrap_array_vars( + sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)) .∘ + wrap_assignments(isscalar, obsexprs), + expression = Val{expression}) if isscalar || return_inplace return res else return res[1] end - + ex = Func(args, [], pre(Let(obsexprs, isscalar ? ts[1] : MakeArray(ts, output_type),