Skip to content

Commit

Permalink
Merge pull request #782 from AayushSabharwal/as/fix-remake-buffer
Browse files Browse the repository at this point in the history
fix: update to new `remake_buffer` signature
  • Loading branch information
ChrisRackauckas authored Sep 5, 2024
2 parents a121f72 + 0326659 commit ecb8f25
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.26"
SymbolicIndexingInterface = "0.3.30"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
24 changes: 12 additions & 12 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,56 +617,56 @@ end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), u0), p
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
for (k, v) in u0)

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), u0), p
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

# FIXME: need to provide `u` since the observed function expects it.
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
# used, since any state symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = state_values(prob), p = p)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in u0)
return remake_buffer(prob, state_values(prob), u0), p
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
end

function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
for (k, v) in p)

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

# FIXME: need to provide `p` since the observed function expects an `MTKParameters`
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
# used, since any parameter symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = u0, p = parameter_values(prob))
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in p)
return u0, remake_buffer(prob, parameter_values(prob), p)
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)

if !isu0dep && !ispdep
return remake_buffer(prob, state_values(prob), u0),
remake_buffer(prob, parameter_values(prob), p)
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end
if !isu0dep
u0 = remake_buffer(prob, state_values(prob), u0)
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
end
if !ispdep
p = remake_buffer(prob, parameter_values(prob), p)
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
end

Expand All @@ -675,8 +675,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
for (k, v) in u0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in p)
return remake_buffer(prob, state_values(prob), u0),
remake_buffer(prob, parameter_values(prob), p)
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
Expand Down

0 comments on commit ecb8f25

Please sign in to comment.