diff --git a/src/remake.jl b/src/remake.jl index bed4a8b5d..9e059329e 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -543,40 +543,76 @@ end function fill_u0(prob, u0; defs = nothing, use_defaults = false) vsyms = variable_symbols(prob) - if length(u0) == length(vsyms) - return u0 + sym_to_idx = anydict() + idx_to_sym = anydict() + idx_to_val = anydict() + for (k, v) in u0 + idx = variable_index(prob, k) + idx === nothing && continue + sym_to_idx[k] = idx + idx_to_sym[idx] = k + idx_to_val[idx] = v end - newvals = anydict() for sym in vsyms - varmap_has_var(u0, sym) && continue - def = if defs === nothing || (defval = varmap_get(defs, sym)) === nothing || - (symbolic_type(defval) === NotSymbolic() && !use_defaults) - nothing - else + haskey(sym_to_idx, sym) && continue + idx = variable_index(prob, sym) + haskey(idx_to_val, idx) && continue + sym_to_idx[sym] = idx + idx_to_sym[idx] = sym + idx_to_val[idx] = if defs !== nothing && + (defval = varmap_get(defs, sym)) !== nothing && + (symbolic_type(defval) != NotSymbolic() || use_defaults) defval + else + getu(prob, sym)(prob) end - newvals[sym] = @something def getu(prob, sym)(prob) end - return merge(u0, newvals) + newvals = anydict() + for (idx, val) in idx_to_val + newvals[idx_to_sym[idx]] = val + end + for (k, v) in u0 + haskey(sym_to_idx, k) && continue + newvals[k] = v + end + return newvals end function fill_p(prob, p; defs = nothing, use_defaults = false) - psyms = parameter_symbols(prob)::Vector - if length(p) == length(psyms) - return p + psyms = parameter_symbols(prob) + sym_to_idx = anydict() + idx_to_sym = anydict() + idx_to_val = anydict() + for (k, v) in p + idx = parameter_index(prob, k) + idx === nothing && continue + sym_to_idx[k] = idx + idx_to_sym[idx] = k + idx_to_val[idx] = v end - newvals = anydict() for sym in psyms - varmap_has_var(p, sym) && continue - def = if defs === nothing || (defval = varmap_get(defs, sym)) === nothing || - (symbolic_type(defval) === NotSymbolic() && !use_defaults) - nothing - else + haskey(sym_to_idx, sym) && continue + idx = parameter_index(prob, sym) + haskey(idx_to_val, idx) && continue + sym_to_idx[sym] = idx + idx_to_sym[idx] = sym + idx_to_val[idx] = if defs !== nothing && + (defval = varmap_get(defs, sym)) !== nothing && + (symbolic_type(defval) != NotSymbolic() || use_defaults) defval + else + getp(prob, sym)(prob) end - newvals[sym] = @something def getp(prob, sym)(prob) end - return merge(p, newvals) + newvals = anydict() + for (idx, val) in idx_to_val + newvals[idx_to_sym[idx]] = val + end + for (k, v) in p + haskey(sym_to_idx, k) && continue + newvals[k] = v + end + return newvals end function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}) diff --git a/test/remake_tests.jl b/test/remake_tests.jl index ad679bc61..e65f91bd6 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -295,3 +295,22 @@ a = Remake_Test1(p = 1) @test @inferred remake(a, p = 2) == Remake_Test1(p = 2) @test @inferred remake(a, args = 1) == Remake_Test1(1, p = 1) @test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1) + +@testset "fill_u0 and fill_p ignore identical variables with different names" begin + sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2), + :t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4)) + function foo(du, u, p, t) + du .= u .* p + end + prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5]) + u0 = Dict(:x2 => 2) + newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys)) + @test length(newu0) == 2 + @test get(newu0, :x2, 0) == 2 + @test get(newu0, :y, 0) == 2.5 + p = Dict(:a2 => 3) + newp = SciMLBase.fill_p(prob, p; defs = default_values(sys)) + @test length(newp) == 2 + @test get(newp, :a2, 0) == 3 + @test get(newp, :b, 0) == 4.5 +end