diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index b10b7125a2..5df951a566 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -51,7 +51,7 @@ function _model_macro(mod, name, expr, isconnector) c_evts = [] d_evts = [] kwargs = OrderedCollections.OrderedSet() - where_types = Expr[] + where_types = Union{Symbol, Expr}[] push!(exprs.args, :(variables = [])) push!(exprs.args, :(parameters = [])) @@ -143,9 +143,15 @@ end pop_structure_dict!(dict, key) = length(dict[key]) == 0 && pop!(dict, key) function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var, - varclass, where_types) + varclass, where_types, meta) if indices isa Nothing - push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing)) + if !isnothing(meta) && haskey(meta, VariableUnit) + uvar = gensym() + push!(where_types, uvar) + push!(kwargs, Expr(:kw, :($a::Union{Nothing, $uvar}), nothing)) + else + push!(kwargs, Expr(:kw, :($a::Union{Nothing, $type}), nothing)) + end dict[:kwargs][getname(var)] = Dict(:value => def, :type => type) else vartype = gensym(:T) @@ -154,7 +160,11 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var, Expr(:(::), a, Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))), nothing)) - push!(where_types, :($vartype <: $type)) + if !isnothing(meta) && haskey(meta, VariableUnit) + push!(where_types, vartype) + else + push!(where_types, :($vartype <: $type)) + end dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type}) end if dict[varclass] isa Vector @@ -166,7 +176,7 @@ end function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types; def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing, - type::Type = Real) + type::Type = Real, meta = Dict{DataType, Expr}()) metatypes = [(:connection_type, VariableConnectType), (:description, VariableDescription), (:unit, VariableUnit), @@ -186,29 +196,31 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types; a::Symbol => begin var = generate_var!(dict, a, varclass; indices, type) update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var, - varclass, where_types) + varclass, where_types, meta) return var, def, Dict() end Expr(:(::), a, type) => begin type = getfield(mod, type) - parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type) + parse_variable_def!( + dict, mod, a, varclass, kwargs, where_types; def, type, meta) end Expr(:(::), Expr(:call, a, b), type) => begin type = getfield(mod, type) def = _type_check!(def, a, type, varclass) - parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type) + parse_variable_def!( + dict, mod, a, varclass, kwargs, where_types; def, type, meta) end Expr(:call, a, b) => begin var = generate_var!(dict, a, b, varclass, mod; indices, type) update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var, - varclass, where_types) + varclass, where_types, meta) return var, def, Dict() end Expr(:(=), a, b) => begin Base.remove_linenums!(b) def, meta = parse_default(mod, b) var, def, _ = parse_variable_def!( - dict, mod, a, varclass, kwargs, where_types; def, type) + dict, mod, a, varclass, kwargs, where_types; def, type, meta) if dict[varclass] isa Vector dict[varclass][1][getname(var)][:default] = def else @@ -231,9 +243,9 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types; return var, def, Dict() end Expr(:tuple, a, b) => begin - var, def, _ = parse_variable_def!( - dict, mod, a, varclass, kwargs, where_types; type) meta = parse_metadata(mod, b) + var, def, _ = parse_variable_def!( + dict, mod, a, varclass, kwargs, where_types; type, meta) if meta !== nothing for (type, key) in metatypes if (mt = get(meta, key, nothing)) !== nothing @@ -253,7 +265,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types; Expr(:ref, a, b...) => begin indices = map(i -> UnitRange(i.args[2], i.args[end]), b) parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; - def, indices, type) + def, indices, type, meta) end _ => error("$arg cannot be parsed") end @@ -611,16 +623,58 @@ function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs, where_ push!(exprs, ex) end +function convert_units(varunits::DynamicQuantities.Quantity, value) + DynamicQuantities.ustrip(DynamicQuantities.uconvert( + DynamicQuantities.SymbolicUnits.as_quantity(varunits), value)) +end + +function convert_units( + varunits::DynamicQuantities.Quantity, value::AbstractArray{T}) where {T} + DynamicQuantities.ustrip.(DynamicQuantities.uconvert.( + DynamicQuantities.SymbolicUnits.as_quantity(varunits), value)) +end + +function convert_units(varunits::Unitful.FreeUnits, value) + Unitful.ustrip(varunits, value) +end + +function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T} + Unitful.ustrip.(varunits, value) +end + function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types) vv, def, metadata_with_exprs = parse_variable_def!( dict, mod, arg, varclass, kwargs, where_types) name = getname(vv) - varexpr = quote - $name = if $name === nothing - $setdefault($vv, $def) - else - $setdefault($vv, $name) + varexpr = if haskey(metadata_with_exprs, VariableUnit) + unit = metadata_with_exprs[VariableUnit] + quote + $name = if $name === nothing + $setdefault($vv, $def) + else + try + $setdefault($vv, $convert_units($unit, $name)) + catch e + if isa(e, $(DynamicQuantities.DimensionError)) || + isa(e, $(Unitful.DimensionError)) + error("Unable to convert units for \'" * string(:($$vv)) * "\'") + elseif isa(e, MethodError) + error("No or invalid units provided for \'" * string(:($$vv)) * + "\'") + else + rethrow(e) + end + end + end + end + else + quote + $name = if $name === nothing + $setdefault($vv, $def) + else + $setdefault($vv, $name) + end end end diff --git a/test/dq_units.jl b/test/dq_units.jl index 75bbd4c4a9..76a9c5aa11 100644 --- a/test/dq_units.jl +++ b/test/dq_units.jl @@ -157,3 +157,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1]) maj1 = MassActionJump(2.0, [0 => 1], [S => 1]) maj2 = MassActionJump(γ, [S => 1], [S => -1]) @named js4 = JumpSystem([maj1, maj2], ModelingToolkit.t_nounits, [S], [β, γ]) + +@mtkmodel ParamTest begin + @parameters begin + a, [unit = u"m"] + end + @variables begin + b(t), [unit = u"kg"] + end +end + +@named sys = ParamTest() + +@named sys = ParamTest(a = 3.0u"cm") +@test ModelingToolkit.getdefault(sys.a) ≈ 0.03 + +@test_throws ErrorException ParamTest(; name = :t, a = 1.0) +@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s") + +@mtkmodel ArrayParamTest begin + @parameters begin + a[1:2], [unit = u"m"] + end +end + +@named sys = ArrayParamTest() + +@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm") +@test ModelingToolkit.getdefault(sys.a) ≈ [0.01, 0.03] diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 9236dfc975..352c2ee3d4 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -1,7 +1,8 @@ using ModelingToolkit, Test using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata, get_systems, get_ps, getdefault, getname, readable_code, - scalarize, symtype, VariableDescription, RegularConnector + scalarize, symtype, VariableDescription, RegularConnector, + get_unit using URIs: URI using Distributions using DynamicQuantities, OrdinaryDiffEq @@ -53,8 +54,9 @@ end end end -@named p = Pin(; v = π) -@test getdefault(p.v) == π +@named p = Pin(; v = π * u"V") + +@test getdefault(p.v) ≈ π @test Pin.isconnector == true @mtkmodel OnePort begin @@ -76,7 +78,6 @@ end @test OnePort.isconnector == false -resistor_log = "$(@__DIR__)/logo/resistor.svg" @mtkmodel Resistor begin @extend v, i = oneport = OnePort() @parameters begin @@ -105,14 +106,14 @@ end @parameters begin C, [unit = u"F"] end - @extend OnePort(; v = 0.0) + @extend OnePort(; v = 0.0u"V") @icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg" @equations begin D(v) ~ i / C end end -@named capacitor = Capacitor(C = 10, v = 10.0) +@named capacitor = Capacitor(C = 10u"F", v = 10.0u"V") @test getdefault(capacitor.v) == 10.0 @mtkmodel Voltage begin @@ -127,9 +128,9 @@ end @mtkmodel RC begin @structural_parameters begin - R_val = 10 - C_val = 10 - k_val = 10 + R_val = 10u"Ω" + C_val = 10u"F" + k_val = 10u"V" end @components begin resistor = Resistor(; R = R_val) @@ -147,9 +148,9 @@ end end end -C_val = 20 -R_val = 20 -res__R = 100 +C_val = 20u"F" +R_val = 20u"Ω" +res__R = 100u"Ω" @mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R) prob = ODEProblem(rc, [], (0, 1e9)) sol = solve(prob, Rodas5P()) @@ -160,11 +161,12 @@ resistor = getproperty(rc, :resistor; namespace = false) @test getname(rc.resistor.R) === getname(resistor.R) @test getname(rc.resistor.v) === getname(resistor.v) # Test that `resistor.R` overrides `R_val` in the argument. -@test getdefault(rc.resistor.R) == res__R != R_val +@test getdefault(rc.resistor.R) * get_unit(rc.resistor.R) == res__R != R_val # Test that `C_val` passed via argument is set as default of C. -@test getdefault(rc.capacitor.C) == C_val +@test getdefault(rc.capacitor.C) * get_unit(rc.capacitor.C) == C_val # Test that `k`'s default value is unchanged. -@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val][:value] +@test getdefault(rc.constant.k) * get_unit(rc.constant.k) == + eval(RC.structure[:kwargs][:k_val][:value]) @test getdefault(rc.capacitor.v) == 0.0 @test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] == diff --git a/test/units.jl b/test/units.jl index 033a64c0e3..8d7f9e451e 100644 --- a/test/units.jl +++ b/test/units.jl @@ -192,3 +192,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1]) maj1 = MassActionJump(2.0, [0 => 1], [S => 1]) maj2 = MassActionJump(γ, [S => 1], [S => -1]) @named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ]) + +@mtkmodel ParamTest begin + @parameters begin + a, [unit = u"m"] + end + @variables begin + b(t), [unit = u"kg"] + end +end + +@named sys = ParamTest() + +@named sys = ParamTest(a = 3.0u"cm") +@test ModelingToolkit.getdefault(sys.a) ≈ 0.03 + +@test_throws ErrorException ParamTest(; name = :t, a = 1.0) +@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s") + +@mtkmodel ArrayParamTest begin + @parameters begin + a[1:2], [unit = u"m"] + end +end + +@named sys = ArrayParamTest() + +@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm") +@test ModelingToolkit.getdefault(sys.a) ≈ [0.01, 0.03]