diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 61f16fd926..b31f39ec22 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -202,6 +202,7 @@ struct ODESystem <: AbstractODESystem check_parameters(ps, iv) check_equations(deqs, iv) check_equations(equations(cevents), iv) + check_var_types(ODESystem, dvs) end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(dvs, ps, iv) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 366e8a6d09..eeedce1d46 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -170,6 +170,7 @@ struct SDESystem <: AbstractODESystem check_parameters(ps, iv) check_equations(deqs, iv) check_equations(neqs, dvs) + check_var_types(SDESystem, dvs) if size(neqs, 1) != length(deqs) throw(ArgumentError("Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size(neqs,1)) != length(deqs) = $(length(deqs))")) end diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index bd5c72eec7..fb6dfe0236 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -122,6 +122,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem check_independent_variables([iv]) check_variables(dvs, iv) check_parameters(ps, iv) + check_var_types(DiscreteSystem, dvs) end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(dvs, ps, iv) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 5c0a61771a..e5dbb185e2 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -147,6 +147,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem check_independent_variables([iv]) check_variables(unknowns, iv) check_parameters(ps, iv) + check_var_types(JumpSystem, unknowns) end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(unknowns, ps, iv) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 6b0b0cc759..687d55b9bb 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -115,6 +115,9 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, parent = nothing, isscheduled = false; checks::Union{Bool, Int} = true) + if checks == true || (checks & CheckComponents) > 0 + check_var_types(NonlinearSystem, unknowns) + end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(unknowns, ps) check_units(u, eqs) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 860e063e35..de4696bfa2 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -73,6 +73,9 @@ struct OptimizationSystem <: AbstractOptimizationSystem gui_metadata = nothing, complete = false, index_cache = nothing, parent = nothing, isscheduled = false; checks::Union{Bool, Int} = true) + if checks == true || (checks & CheckComponents) > 0 + check_var_types(OptimizationSystem, unknowns) + end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(unknowns, ps) unwrap(op) isa Symbolic && check_units(u, op) diff --git a/src/systems/pde/pdesystem.jl b/src/systems/pde/pdesystem.jl index eac540e401..e067f0e149 100644 --- a/src/systems/pde/pdesystem.jl +++ b/src/systems/pde/pdesystem.jl @@ -102,6 +102,9 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem checks::Union{Bool, Int} = true, description = "", name) + if checks == true || (checks & CheckComponents) > 0 + check_var_types(PDESystem, dvs) + end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(dvs, ivs, ps) check_units(u, eqs) diff --git a/src/utils.jl b/src/utils.jl index 5e4f0b52d2..736a4cc03e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,6 +155,17 @@ function check_variables(dvs, iv) end end +function check_var_types(sys_type::Type{T}, dvs) where T <: AbstractSystem + if any(u -> !(symtype(u) <: Number || eltype(symtype(u)) <: Number), dvs) + error("The type of unknown variables must be a numeric type.") + elseif any(u -> (symtype(u) !== symtype(dvs[1])), dvs) + error("The type of all the unknown variables in a system must all be the same.") + elseif sys_type == ODESystem || sys_type == SDESystem || sys_type == PDESystem + any(u -> !(symtype(u) == Real || eltype(symtype(u)) == Real), dvs) && error("The type of unknown variables for an SDESystem, PDESystem, or ODESystem should not be a concrete numeric type.") + end + nothing +end + function check_lhs(eq::Equation, op, dvs::Set) v = unwrap(eq.lhs) _iszero(v) && return @@ -1182,3 +1193,4 @@ function guesses_from_metadata!(guesses, vars) guesses[vars[i]] = varguesses[i] end end + diff --git a/test/odesystem.jl b/test/odesystem.jl index 76c47a8f1d..b31b58002a 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1553,6 +1553,13 @@ end @test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops) end +@testset "Validate input types" begin + @parameters p d + @variables X(t)::Int64 + eq = D(X) ~ p - d*X + @test_throws Exception @mtkbuild osys = ODESystem([eq], t) +end + @testset "dae_order_lowering basic test" begin @parameters a @variables x(t) y(t) z(t) @@ -1608,4 +1615,4 @@ end prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...)) @test prob.u0 isa SVector -end \ No newline at end of file +end diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 8069581dcc..f638eaed9d 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -868,3 +868,12 @@ end @test length(ModelingToolkit.get_noiseeqs(sys)) == 1 @test length(observed(sys)) == 1 end + +# Test validating types of states +@testset "Validate input types" begin + @parameters p d + @variables X(t)::Int64 + @brownian z + eq2 = D(X) ~ p - d*X + z + @test_throws Exception @mtkbuild ssys = System([eq2], t) +end