From 721c2a7007511f45cf59e826d7da2790971e69ef Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 3 May 2024 20:12:45 +0530 Subject: [PATCH 1/2] feat: add parameter timeseries support to `AbstractDiffEqArray` --- src/vector_of_array.jl | 88 +++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 2c13a573..36f95ed2 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -60,11 +60,13 @@ A[1, :] # all time periods for f(t) A.t ``` """ -mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A} +mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <: + AbstractDiffEqArray{T, N, A} u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}} t::B p::F sys::S + discretes::D end ### Abstract Interface struct AllObserved @@ -174,29 +176,32 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p = nothing, - sys = nothing) where {T, N} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec, + sys = nothing; discretes = nothing) where {T, N} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, ts, p, - sys) + sys, + discretes) end # ambiguity resolution function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec, + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec, ts, nothing, + nothing, nothing) end function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, - ::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec, + ::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec, ts, p, - nothing) + nothing, + discretes) end # Assume that the first element is representative of all other elements @@ -206,7 +211,8 @@ function DiffEqArray(vec::AbstractVector, sys = nothing; variables = nothing, parameters = nothing, - independent_variables = nothing) + independent_variables = nothing, + discretes = nothing) sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -219,11 +225,13 @@ function DiffEqArray(vec::AbstractVector, typeof(vec), typeof(ts), typeof(p), - typeof(sys) + typeof(sys), + typeof(discretes) }(vec, ts, p, - sys) + sys, + discretes) end function DiffEqArray(vec::AbstractVector{VT}, @@ -232,7 +240,8 @@ function DiffEqArray(vec::AbstractVector{VT}, sys = nothing; variables = nothing, parameters = nothing, - independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} + independent_variables = nothing, + discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -243,18 +252,30 @@ function DiffEqArray(vec::AbstractVector{VT}, typeof(vec), typeof(ts), typeof(p), - typeof(sys) + typeof(sys), + typeof(discretes), }(vec, ts, p, - sys) + sys, + discretes) end +has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes) +get_discretes(x) = getfield(x, :discretes) + SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries() +function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B, + F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy} + Timeseries() +end SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys +function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray) + return get_discretes(A) +end Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A)) Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian() @@ -363,39 +384,18 @@ end # Symbolic Indexing Methods for (symtype, elsymtype, valtype, errcheck) in [ - (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), - (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), + (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))), + (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))), (NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray}, - :(all(x -> is_parameter(A, x), sym))), + :(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))), ] -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym)(A) -end -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, arg) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym)(A, arg) -end -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}}) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym).((A,), arg) -end -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, ::Colon) - if $errcheck - throw(ParameterIndexingError(sym)) + @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, arg...) + if $errcheck + throw(ParameterIndexingError(sym)) + end + getu(A, sym)(A, arg...) end - getu(A, sym)(A) -end end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, From 33b2fe33d24f4c4dcd1d8bdaf4c726f548716fae Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Jun 2024 12:41:56 +0530 Subject: [PATCH 2/2] build: bump SII compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 15c2e649..bb434c6e 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ StaticArrays = "1.6" StaticArraysCore = "1.4" Statistics = "1.10" StructArrays = "0.6.11" -SymbolicIndexingInterface = "0.3.20" +SymbolicIndexingInterface = "0.3.23" Tables = "1.11" Test = "1" Tracker = "0.2.15"