Skip to content

Commit

Permalink
Merge pull request #2898 from contradict/variable_value_units
Browse files Browse the repository at this point in the history
Make default value units consistent
  • Loading branch information
ChrisRackauckas authored Aug 7, 2024
2 parents 3f7ad46 + 49cf9ef commit e037dfe
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 33 deletions.
90 changes: 72 additions & 18 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function _model_macro(mod, name, expr, isconnector)
c_evts = []
d_evts = []
kwargs = OrderedCollections.OrderedSet()
where_types = Expr[]
where_types = Union{Symbol, Expr}[]

push!(exprs.args, :(variables = []))
push!(exprs.args, :(parameters = []))
Expand Down Expand Up @@ -143,9 +143,15 @@ end
pop_structure_dict!(dict, key) = length(dict[key]) == 0 && pop!(dict, key)

function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
varclass, where_types)
varclass, where_types, meta)
if indices isa Nothing
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
if !isnothing(meta) && haskey(meta, VariableUnit)
uvar = gensym()
push!(where_types, uvar)
push!(kwargs, Expr(:kw, :($a::Union{Nothing, $uvar}), nothing))
else
push!(kwargs, Expr(:kw, :($a::Union{Nothing, $type}), nothing))
end
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
else
vartype = gensym(:T)
Expand All @@ -154,7 +160,11 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
Expr(:(::), a,
Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))),
nothing))
push!(where_types, :($vartype <: $type))
if !isnothing(meta) && haskey(meta, VariableUnit)
push!(where_types, vartype)
else
push!(where_types, :($vartype <: $type))
end
dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type})
end
if dict[varclass] isa Vector
Expand All @@ -166,7 +176,7 @@ end

function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing,
type::Type = Real)
type::Type = Real, meta = Dict{DataType, Expr}())
metatypes = [(:connection_type, VariableConnectType),
(:description, VariableDescription),
(:unit, VariableUnit),
Expand All @@ -186,29 +196,31 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
a::Symbol => begin
var = generate_var!(dict, a, varclass; indices, type)
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
varclass, where_types)
varclass, where_types, meta)
return var, def, Dict()
end
Expr(:(::), a, type) => begin
type = getfield(mod, type)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
end
Expr(:(::), Expr(:call, a, b), type) => begin
type = getfield(mod, type)
def = _type_check!(def, a, type, varclass)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
end
Expr(:call, a, b) => begin
var = generate_var!(dict, a, b, varclass, mod; indices, type)
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
varclass, where_types)
varclass, where_types, meta)
return var, def, Dict()
end
Expr(:(=), a, b) => begin
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
var, def, _ = parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; def, type)
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
if dict[varclass] isa Vector
dict[varclass][1][getname(var)][:default] = def
else
Expand All @@ -231,9 +243,9 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
return var, def, Dict()
end
Expr(:tuple, a, b) => begin
var, def, _ = parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; type)
meta = parse_metadata(mod, b)
var, def, _ = parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; type, meta)
if meta !== nothing
for (type, key) in metatypes
if (mt = get(meta, key, nothing)) !== nothing
Expand All @@ -253,7 +265,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
Expr(:ref, a, b...) => begin
indices = map(i -> UnitRange(i.args[2], i.args[end]), b)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types;
def, indices, type)
def, indices, type, meta)
end
_ => error("$arg cannot be parsed")
end
Expand Down Expand Up @@ -611,16 +623,58 @@ function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs, where_
push!(exprs, ex)
end

function convert_units(varunits::DynamicQuantities.Quantity, value)
DynamicQuantities.ustrip(DynamicQuantities.uconvert(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(
varunits::DynamicQuantities.Quantity, value::AbstractArray{T}) where {T}
DynamicQuantities.ustrip.(DynamicQuantities.uconvert.(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(varunits::Unitful.FreeUnits, value)
Unitful.ustrip(varunits, value)
end

function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
Unitful.ustrip.(varunits, value)
end

function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
vv, def, metadata_with_exprs = parse_variable_def!(
dict, mod, arg, varclass, kwargs, where_types)
name = getname(vv)

varexpr = quote
$name = if $name === nothing
$setdefault($vv, $def)
else
$setdefault($vv, $name)
varexpr = if haskey(metadata_with_exprs, VariableUnit)
unit = metadata_with_exprs[VariableUnit]
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
try
$setdefault($vv, $convert_units($unit, $name))
catch e
if isa(e, $(DynamicQuantities.DimensionError)) ||
isa(e, $(Unitful.DimensionError))
error("Unable to convert units for \'" * string(:($$vv)) * "\'")
elseif isa(e, MethodError)
error("No or invalid units provided for \'" * string(:($$vv)) *
"\'")
else
rethrow(e)
end
end
end
end
else
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
$setdefault($vv, $name)
end
end
end

Expand Down
28 changes: 28 additions & 0 deletions test/dq_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 1], [S => -1])
@named js4 = JumpSystem([maj1, maj2], ModelingToolkit.t_nounits, [S], [β, γ])

@mtkmodel ParamTest begin
@parameters begin
a, [unit = u"m"]
end
@variables begin
b(t), [unit = u"kg"]
end
end

@named sys = ParamTest()

@named sys = ParamTest(a = 3.0u"cm")
@test ModelingToolkit.getdefault(sys.a) 0.03

@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")

@mtkmodel ArrayParamTest begin
@parameters begin
a[1:2], [unit = u"m"]
end
end

@named sys = ArrayParamTest()

@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]
32 changes: 17 additions & 15 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using ModelingToolkit, Test
using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata,
get_systems, get_ps, getdefault, getname, readable_code,
scalarize, symtype, VariableDescription, RegularConnector
scalarize, symtype, VariableDescription, RegularConnector,
get_unit
using URIs: URI
using Distributions
using DynamicQuantities, OrdinaryDiffEq
Expand Down Expand Up @@ -53,8 +54,9 @@ end
end
end

@named p = Pin(; v = π)
@test getdefault(p.v) == π
@named p = Pin(; v = π * u"V")

@test getdefault(p.v) π
@test Pin.isconnector == true

@mtkmodel OnePort begin
Expand All @@ -76,7 +78,6 @@ end

@test OnePort.isconnector == false

resistor_log = "$(@__DIR__)/logo/resistor.svg"
@mtkmodel Resistor begin
@extend v, i = oneport = OnePort()
@parameters begin
Expand Down Expand Up @@ -105,14 +106,14 @@ end
@parameters begin
C, [unit = u"F"]
end
@extend OnePort(; v = 0.0)
@extend OnePort(; v = 0.0u"V")
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
@equations begin
D(v) ~ i / C
end
end

@named capacitor = Capacitor(C = 10, v = 10.0)
@named capacitor = Capacitor(C = 10u"F", v = 10.0u"V")
@test getdefault(capacitor.v) == 10.0

@mtkmodel Voltage begin
Expand All @@ -127,9 +128,9 @@ end

@mtkmodel RC begin
@structural_parameters begin
R_val = 10
C_val = 10
k_val = 10
R_val = 10u"Ω"
C_val = 10u"F"
k_val = 10u"V"
end
@components begin
resistor = Resistor(; R = R_val)
Expand All @@ -147,9 +148,9 @@ end
end
end

C_val = 20
R_val = 20
res__R = 100
C_val = 20u"F"
R_val = 20u"Ω"
res__R = 100u"Ω"
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
prob = ODEProblem(rc, [], (0, 1e9))
sol = solve(prob, Rodas5P())
Expand All @@ -160,11 +161,12 @@ resistor = getproperty(rc, :resistor; namespace = false)
@test getname(rc.resistor.R) === getname(resistor.R)
@test getname(rc.resistor.v) === getname(resistor.v)
# Test that `resistor.R` overrides `R_val` in the argument.
@test getdefault(rc.resistor.R) == res__R != R_val
@test getdefault(rc.resistor.R) * get_unit(rc.resistor.R) == res__R != R_val
# Test that `C_val` passed via argument is set as default of C.
@test getdefault(rc.capacitor.C) == C_val
@test getdefault(rc.capacitor.C) * get_unit(rc.capacitor.C) == C_val
# Test that `k`'s default value is unchanged.
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val][:value]
@test getdefault(rc.constant.k) * get_unit(rc.constant.k) ==
eval(RC.structure[:kwargs][:k_val][:value])
@test getdefault(rc.capacitor.v) == 0.0

@test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] ==
Expand Down
28 changes: 28 additions & 0 deletions test/units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 1], [S => -1])
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])

@mtkmodel ParamTest begin
@parameters begin
a, [unit = u"m"]
end
@variables begin
b(t), [unit = u"kg"]
end
end

@named sys = ParamTest()

@named sys = ParamTest(a = 3.0u"cm")
@test ModelingToolkit.getdefault(sys.a) 0.03

@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")

@mtkmodel ArrayParamTest begin
@parameters begin
a[1:2], [unit = u"m"]
end
end

@named sys = ArrayParamTest()

@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]

0 comments on commit e037dfe

Please sign in to comment.