Skip to content

Commit

Permalink
Merge pull request #734 from SciML/downstream
Browse files Browse the repository at this point in the history
Fix downstream test guess
  • Loading branch information
ChrisRackauckas authored Jul 20, 2024
2 parents f76fb4a + 868eb9e commit 0e5d5cd
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ext/SciMLBaseMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ end
function Makie.convert_arguments(::Type{<:Lines},
sim::SciMLBase.EnsembleSummary;
idxs = sim.u.u[1] isa AbstractArray ? eachindex(sim.u.u[1]) :
1,
1,
error_style = :ribbon, ci_type = :quantile,
kwargs...)
if ci_type == :SEM
Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ end

@recipe function f(sim::EnsembleSummary;
idxs = sim.u.u[1] isa AbstractArray ? eachindex(sim.u.u[1]) :
1,
1,
error_style = :ribbon, ci_type = :quantile)
if ci_type == :SEM
if sim.u.u[1] isa AbstractArray
Expand Down
4 changes: 1 addition & 3 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ For specifying Jacobians and mass matrices, see the
* `kwargs`: The keyword arguments passed on to the solvers.
"""
mutable struct NonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
Expand Down Expand Up @@ -322,8 +322,6 @@ function Base.setproperty!(prob::NonlinearProblem, s::Symbol, v, order::Symbol)
Base.setfield!(prob, s, v, order)
end



@doc doc"""
Defines a nonlinear least squares problem.
Expand Down
5 changes: 3 additions & 2 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,16 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointSecondOrderBVProblem{iip}() : StandardSecondOrderBVProblem()
prob_type = TP ? TwoPointSecondOrderBVProblem{iip}() :
StandardSecondOrderBVProblem()

# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
else
@assert prob_type===problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end

return new{typeof(_u0), typeof(_tspan), iip, typeof(nlls), typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
end
Expand Down
30 changes: 19 additions & 11 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ function remake(prob::AbstractSciMLProblem; u0 = missing,
_remake_internal(prob; kwargs..., u0, p)
end

function remake(prob::AbstractIntervalNonlinearProblem; p = missing, interpret_symbolicmap = true, use_defaults = false, kwargs...)
function remake(prob::AbstractIntervalNonlinearProblem; p = missing,
interpret_symbolicmap = true, use_defaults = false, kwargs...)
_, p = updated_u0_p(prob, [], p; interpret_symbolicmap, use_defaults)
_remake_internal(prob; kwargs..., p)
end
Expand Down Expand Up @@ -102,19 +103,25 @@ function remake(prob::ODEProblem; f = missing,
iip = isinplace(prob)

if f === missing
initializeprob, initializeprobmap = remake_initializeprob(prob.f.sys, prob.f, u0 === missing ? newu0 : u0, tspan[1], p === missing ? newp : p)
initializeprob, initializeprobmap = remake_initializeprob(
prob.f.sys, prob.f, u0 === missing ? newu0 : u0,
tspan[1], p === missing ? newp : p)
if specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(
unwrapped_f(prob.f.f),
(newu0, newu0, newp,
ptspan[1])); initializeprob, initializeprobmap)
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_iip(
unwrapped_f(prob.f.f),
(newu0, newu0, newp,
ptspan[1]));
initializeprob, initializeprobmap)
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(
unwrapped_f(prob.f.f),
(newu0, newp,
ptspan[1])); initializeprob, initializeprobmap)
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_oop(
unwrapped_f(prob.f.f),
(newu0, newp,
ptspan[1]));
initializeprob, initializeprobmap)
end
else
_f = prob.f
Expand All @@ -138,7 +145,8 @@ function remake(prob::ODEProblem; f = missing,
end

if kwargs === missing
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
ODEProblem{isinplace(prob)}(
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
BoundaryValueDiffEq = "5"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "8.37, 9"
ModelingToolkit = "9"
ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3"
Optimization = "3"
Expand Down
5 changes: 3 additions & 2 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ k = ShiftIndex(t)
# Roundabout method to avoid having to specify values for previous timestep
fn = DiscreteFunction(discsys)
ps = ModelingToolkit.MTKParameters(discsys, p)
discu0 = Dict([u0..., x(k-1) => 0.0, y(k-1) => 0.0, z(k-1) => 0.0])
discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0])
push!(syss, discsys)
push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps))

Expand Down Expand Up @@ -145,7 +145,8 @@ for (sys, prob) in zip(syss, probs)
prob2 = @inferred baseType remake(prob; u0 = [x => 0.5σ + 1], p ==> 0.5x + 1])
@test ugetter(prob2) [15.0, 0.0, 0.0]
@test pgetter(prob2) [28.0, 8.5, 10.0]
prob2 = @inferred baseType remake(prob; u0 = [sys.x => 0.5σ + 1], p = [sys.β => 0.5x + 1])
prob2 = @inferred baseType remake(
prob; u0 = [sys.x => 0.5σ + 1], p = [sys.β => 0.5x + 1])
@test ugetter(prob2) [15.0, 0.0, 0.0]
@test pgetter(prob2) [28.0, 8.5, 10.0]
# Not testing `Symbol => expr` since nested substitution doesn't work with that
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/observables_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function create_model(; C₁ = 3e-5, C₂ = 1e-6)
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter
])
], defaults = [resistor1.n.v => 0.0])
end

@testset "DAE Observable function AD" begin
Expand All @@ -96,7 +96,7 @@ end
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
# sum(sol[sys.ampermeter.i])
# end
#
#
# @test isnothing(gs_mtkp)
# @test length(gs_p_new) == length(p_new)
# end

0 comments on commit 0e5d5cd

Please sign in to comment.