Skip to content

Commit

Permalink
fix: fix remake handling identical variables with different names
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 4, 2024
1 parent 8763f7b commit 1d12433
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 21 deletions.
78 changes: 57 additions & 21 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,40 +543,76 @@ end

function fill_u0(prob, u0; defs = nothing, use_defaults = false)
vsyms = variable_symbols(prob)
if length(u0) == length(vsyms)
return u0
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
for (k, v) in u0
idx = variable_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
end
newvals = anydict()
for sym in vsyms
varmap_has_var(u0, sym) && continue
def = if defs === nothing || (defval = varmap_get(defs, sym)) === nothing ||
(symbolic_type(defval) === NotSymbolic() && !use_defaults)
nothing
else
haskey(sym_to_idx, sym) && continue
idx = variable_index(prob, sym)
haskey(idx_to_val, idx) && continue
sym_to_idx[sym] = idx
idx_to_sym[idx] = sym
idx_to_val[idx] = if defs !== nothing &&
(defval = varmap_get(defs, sym)) !== nothing &&
(symbolic_type(defval) != NotSymbolic() || use_defaults)
defval
else
getu(prob, sym)(prob)
end
newvals[sym] = @something def getu(prob, sym)(prob)
end
return merge(u0, newvals)
newvals = anydict()
for (idx, val) in idx_to_val
newvals[idx_to_sym[idx]] = val
end
for (k, v) in u0
haskey(sym_to_idx, k) && continue
newvals[k] = v
end
return newvals
end

function fill_p(prob, p; defs = nothing, use_defaults = false)
psyms = parameter_symbols(prob)::Vector
if length(p) == length(psyms)
return p
psyms = parameter_symbols(prob)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
for (k, v) in p
idx = parameter_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
end
newvals = anydict()
for sym in psyms
varmap_has_var(p, sym) && continue
def = if defs === nothing || (defval = varmap_get(defs, sym)) === nothing ||
(symbolic_type(defval) === NotSymbolic() && !use_defaults)
nothing
else
haskey(sym_to_idx, sym) && continue
idx = parameter_index(prob, sym)
haskey(idx_to_val, idx) && continue
sym_to_idx[sym] = idx
idx_to_sym[idx] = sym
idx_to_val[idx] = if defs !== nothing &&
(defval = varmap_get(defs, sym)) !== nothing &&
(symbolic_type(defval) != NotSymbolic() || use_defaults)
defval
else
getp(prob, sym)(prob)
end
newvals[sym] = @something def getp(prob, sym)(prob)
end
return merge(p, newvals)
newvals = anydict()
for (idx, val) in idx_to_val
newvals[idx_to_sym[idx]] = val
end
for (k, v) in p
haskey(sym_to_idx, k) && continue
newvals[k] = v
end
return newvals
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
Expand Down
19 changes: 19 additions & 0 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,22 @@ a = Remake_Test1(p = 1)
@test @inferred remake(a, p = 2) == Remake_Test1(p = 2)
@test @inferred remake(a, args = 1) == Remake_Test1(1, p = 1)
@test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1)

@testset "fill_u0 and fill_p ignore identical variables with different names" begin
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
function foo(du, u, p, t)
du .= u .* p
end
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
u0 = Dict(:x2 => 2)
newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys))
@test length(newu0) == 2
@test get(newu0, :x2, 0) == 2
@test get(newu0, :y, 0) == 2.5
p = Dict(:a2 => 3)
newp = SciMLBase.fill_p(prob, p; defs = default_values(sys))
@test length(newp) == 2
@test get(newp, :a2, 0) == 3
@test get(newp, :b, 0) == 4.5
end

0 comments on commit 1d12433

Please sign in to comment.