diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 0a2ce6ea1..1f09d55c3 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -4,10 +4,10 @@ using Zygote using Zygote: @adjoint, pullback import Zygote: literal_getproperty using SciMLBase -using SciMLBase: ODESolution, sym_to_index, remake, +using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution -using SymbolicIndexingInterface: symbolic_type, NotSymbolic +using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in @@ -34,7 +34,7 @@ end @adjoint function getindex(VA::ODESolution, sym, j::Int) function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym du, dprob = if i === nothing getter = getobserved(VA) grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) @@ -96,7 +96,7 @@ end @adjoint function getindex(VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) - i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym + i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if i === nothing throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 07d73172e..57e596270 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -209,7 +209,7 @@ end end end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon) +Base.@propagate_inbounds function RecursiveArrayTools._getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon) return [xi[s] for xi in x.u] end diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index ebfdf0c73..8aad874e5 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -450,18 +450,17 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol) end end -Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int}, +Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, BitArray, AbstractArray{Bool}}...) A.u[I...] end -Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym) +Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym) if is_variable(A, sym) return A[variable_index(A, sym)] elseif is_parameter(A, sym) - Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex) - return getp(A, sym)(A) + error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing") elseif is_independent_variable(A, sym) return A.t elseif is_observed(A, sym) @@ -471,11 +470,11 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymboli end end -Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym) +Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym) return A[collect(sym)] end -Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}) +Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}) return getindex.((A,), sym) end @@ -484,12 +483,20 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym) elsymtype = symbolic_type(eltype(sym)) if symtype != NotSymbolic() - return getindex(A, symtype, sym) + return _getindex(A, symtype, sym) else - return getindex(A, elsymtype, sym) + return _getindex(A, elsymtype, sym) end end +Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.SolvedVariables) + return getindex(A, variable_symbols(A)) +end + +Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.AllVariables) + return getindex(A, all_variable_symbols(A)) +end + function observed(A::DEIntegrator, sym) getobserved(A)(sym, A.u, A.p, A.t) end @@ -500,8 +507,7 @@ function Base.setindex!(A::DEIntegrator, val, sym) if is_variable(A, sym) A.u[variable_index(A, sym)] = val elseif is_parameter(A, sym) - Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex) - setp(A, sym)(A, val) + error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.") else error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.") end diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index f8eb35a6e..32f912055 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -1,13 +1,20 @@ SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p +Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables) + return getindex(prob, variable_symbols(prob)) +end + +Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.AllVariables) + return getindex(prob, all_variable_symbols(prob)) +end + Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym) if symbolic_type(sym) == ScalarSymbolic() if is_variable(prob.f, sym) return prob.u0[variable_index(prob.f, sym)] elseif is_parameter(prob.f, sym) - Base.depwarn("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.", :parameter_getindex) - return getp(prob, sym)(prob) + error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.") elseif is_independent_variable(prob.f, sym) return getindepsym(prob) elseif is_observed(prob.f, sym) @@ -37,8 +44,7 @@ function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym) if is_variable(prob.f, sym) prob.u0[variable_index(prob.f, sym)] = val elseif is_parameter(prob.f, sym) - Base.depwarn("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex) - setp(prob, sym)(prob, val) + error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.") else error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.") end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 33f3d1d8f..c1d291da3 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -179,22 +179,30 @@ end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") - augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1] + if is_parameter(sol, idxs) + return getp(sol, idxs)(sol) + else + return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1] + end end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector, continuity) where {deriv} all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol) - [first(interp_sol[idx]) for idx in idxs] + [is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs] end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") - interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) - p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - return DiffEqArray(interp_sol[idxs], t, p, sol) + if is_parameter(sol, idxs) + return getp(sol, idxs)(sol) + else + interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) + p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + return DiffEqArray(interp_sol[idxs], t, p, sol) + end end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 5e1bf1b72..95b19417b 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -73,8 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) if is_variable(A, sym) return A[variable_index(A, sym)] elseif is_parameter(A, sym) - Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.", :parameter_getindex) - return getp(A, sym)(A) + error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.") elseif is_observed(A, sym) return SymbolicIndexingInterface.observed(A, sym)(A.u, A.prob.p) else @@ -88,6 +87,14 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) end end +Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.SolvedVariables) + return getindex(A, variable_symbols(A)) +end + +Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.AllVariables) + return getindex(A, all_variable_symbols(A)) +end + function observed(A::AbstractTimeseriesSolution, sym, i::Int) getobserved(A)(sym, A[i], A.prob.p, A.t[i]) end diff --git a/test/downstream/ensemble_nondes.jl b/test/downstream/ensemble_nondes.jl index 789bdb4fc..970512ef0 100644 --- a/test/downstream/ensemble_nondes.jl +++ b/test/downstream/ensemble_nondes.jl @@ -17,7 +17,7 @@ sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistr @test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5]) -ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2))) +ensembleprob = Optimization.EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2))) sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5) @test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective @@ -35,4 +35,4 @@ ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ ra sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100) -sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100) \ No newline at end of file +sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100) diff --git a/test/downstream/integrator_indexing.jl b/test/downstream/integrator_indexing.jl index f5f180b03..c7dcf73fa 100644 --- a/test/downstream/integrator_indexing.jl +++ b/test/downstream/integrator_indexing.jl @@ -18,9 +18,9 @@ tspan = (0.0, 1000000.0) oprob = ODEProblem(population_model, u0, tspan, p) integrator = init(oprob, Rodas4()) -@test_deprecated integrator[a] -@test_deprecated integrator[population_model.a] -@test_deprecated integrator[:a] +@test_throws Exception integrator[a] +@test_throws Exception integrator[population_model.a] +@test_throws Exception integrator[:a] @test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0 @test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0 @test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0 @@ -28,7 +28,8 @@ integrator = init(oprob, Rodas4()) @test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0 @test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0 - +@test integrator[solvedvariables] == integrator.u +@test integrator[allvariables] == integrator.u step!(integrator, 100.0, true) @test getp(population_model, a)(integrator) == getp(population_model, population_model.a)(integrator) == getp(population_model, :a)(integrator) == 2.0 @@ -299,6 +300,6 @@ eqs = [collect(D.(x) .~ x) D(y) ~ norm(x) * y - x[1]] @named sys = ODESystem(eqs, t, [sts...;], [ps...;]) prob = ODEProblem(sys, [], (0, 1.0)) -@test_broken local integrator = init(prob, Tsit5()) -@test_broken integrator[x] isa Vector{<:Vector} -@test_broken integrator[@nonamespace sys.x] isa Vector{<:Vector} +integrator = init(prob, Tsit5()) +@test integrator[x] isa Vector{Float64} +@test integrator[@nonamespace sys.x] isa Vector{Float64} diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 3e4cb9cef..67c10c8a0 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -1,4 +1,5 @@ using ModelingToolkit, OrdinaryDiffEq, Test +using SymbolicIndexingInterface @parameters σ ρ β @variables t x(t) y(t) z(t) @@ -26,20 +27,37 @@ tspan = (0.0, 100.0) # ODEProblem. oprob = ODEProblem(sys, u0, tspan, p, jac = true) -@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 28.0 -@test oprob[ρ] == oprob[sys.ρ] == oprob[:ρ] == 10.0 -@test oprob[β] == oprob[sys.β] == oprob[:β] == 8 / 3 +@test_throws Exception oprob[σ] +@test_throws Exception oprob[sys.σ] +@test_throws Exception oprob[:σ] +getσ1 = getp(sys, σ) +getσ2 = getp(sys, sys.σ) +getσ3 = getp(sys, :σ) +@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0 +getρ1 = getp(sys, ρ) +getρ2 = getp(sys, sys.ρ) +getρ3 = getp(sys, :ρ) +@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0 +getβ1 = getp(sys, β) +getβ2 = getp(sys, sys.β) +getβ3 = getp(sys, :β) +@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3 @test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0 @test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0 @test oprob[z] == oprob[sys.z] == oprob[:z] == 0.0 - -oprob[σ] = 10.0 -@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 10.0 -oprob[sys.ρ] = 20.0 -@test oprob[ρ] == oprob[sys.ρ] == oprob[:ρ] == 20.0 -oprob[σ] = 30.0 -@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 30.0 +@test oprob[solvedvariables] == oprob[variable_symbols(sys)] +@test oprob[allvariables] == oprob[all_variable_symbols(sys)] + +setσ = setp(sys, σ) +setσ(oprob, 10.0) +@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0 +setρ = setp(sys, sys.ρ) +setρ(oprob, 20.0) +@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0 +setβ = setp(sys, :β) +setβ(oprob, 30.0) +@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0 oprob[x] = 10.0 @test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0 @@ -56,20 +74,22 @@ noiseeqs = [0.1 * x, sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p) u0 -@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 28.0 -@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[:ρ] == 10.0 -@test sprob[β] == sprob[noise_sys.β] == sprob[:β] == 8 / 3 +@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0 +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0 +@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3 @test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0 @test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0 @test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0 -sprob[σ] = 10.0 -@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 10.0 -sprob[noise_sys.ρ] = 20.0 -@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[:ρ] == 20.0 -sprob[σ] = 30.0 -@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 30.0 +setσ(sprob, 10.0) +@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0 +setρ(sprob, 20.0) +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0 +setp(noise_sys, noise_sys.ρ)(sprob, 25.0) +@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0 +setβ(sprob, 30.0) +@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0 sprob[x] = 10.0 @test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0 diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 4e89a968d..88e312a92 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -21,9 +21,9 @@ sol = solve(oprob, Rodas4()) @test sol[s1] == sol[population_model.s1] == sol[:s1] @test sol[s2] == sol[population_model.s2] == sol[:s2] @test sol[s1][end] ≈ 1.0 -@test_deprecated sol[a] -@test_deprecated sol[population_model.a] -@test_deprecated sol[:a] +@test_throws Exception sol[a] +@test_throws Exception sol[population_model.a] +@test_throws Exception sol[:a] # Tests on SDEProblem noiseeqs = [0.1 * s1, @@ -34,9 +34,9 @@ sol = solve(sprob, ImplicitEM()) @test sol[s1] == sol[noisy_population_model.s1] == sol[:s1] @test sol[s2] == sol[noisy_population_model.s2] == sol[:s2] -@test_deprecated sol[a] -@test_deprecated sol[noisy_population_model.a] -@test_deprecated sol[:a] +@test_throws Exception sol[a] +@test_throws Exception sol[noisy_population_model.a] +@test_throws Exception sol[:a] ### Tests on layered model (some things should not work). ### @parameters t σ ρ β diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index f9fc43c23..c74f928ef 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -49,9 +49,13 @@ sol = solve(prob, Rodas4()) @test_throws Any sol['a', [1, 2, 3]] @test sol[a] isa AbstractVector +@test sol[:a] == sol[a] @test sol[a, 1] isa Real +@test sol[:a, 1] == sol[a, 1] @test sol[a, 1:5] isa AbstractVector +@test sol[:a, 1:5] == sol[a, 1:5] @test sol[a, [1, 2, 3]] isa AbstractVector +@test sol[:a, [1, 2, 3]] == sol[a, [1, 2, 3]] @test sol[:, 1] isa AbstractVector @test sol[:, 1:2] isa AbstractDiffEqArray @@ -68,7 +72,7 @@ sol = solve(prob, Rodas4()) @test sol[α, 3] isa Float64 @test length(sol[α, 5:10]) == 6 @test getp(prob, γ)(sol) isa Real -@test getp(prob, γ)(sol) == 2.0 +@test getp(prob, γ)(sol) == getp(prob, :γ)(sol) == 2.0 @test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple @test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}} @@ -182,16 +186,16 @@ plot(sol,idxs=(t,α)) using LinearAlgebra @variables t -sts = @variables x[1:3](t)=[1, 2, 3.0] y(t)=1.0 +sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0 ps = @parameters p[1:3] = [1, 2, 3] D = Differential(t) eqs = [collect(D.(x) .~ x) D(y) ~ norm(x) * y - x[1]] @named sys = ODESystem(eqs, t, [sts...;], [ps...;]) prob = ODEProblem(sys, [], (0, 1.0)) -@test_broken sol = solve(prob, Tsit5()) -@test_broken sol[x] isa Vector{<:Vector} -@test_broken sol[@nonamespace sys.x] isa Vector{<:Vector} +sol = solve(prob, Tsit5()) +@test sol[x] isa Vector{<:Vector} +@test sol[@nonamespace sys.x] isa Vector{<:Vector} # accessing parameters @variables t x(t)