diff --git a/Project.toml b/Project.toml index 591f1ba..8874b91 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleExpressions" uuid = "deba94f7-f32a-40ad-b45e-be020a5ded2f" authors = ["jverzani and contributors"] -version = "1.1.8" +version = "1.1.9" [deps] CallableExpressions = "391672e0-bbe4-4ab4-8bc9-b89a79cbc2f0" diff --git a/docs/src/index.md b/docs/src/index.md index 2b74bf5..d759c5a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -104,7 +104,9 @@ The basic syntax for substitution is: The use of `:` to indicate the remaining value is borrowed from Julia's array syntax; it can also be either `nothing` or `missing`. -For evaluation and substitution using positional arguments, all instances of symbolic variables and all instances of symbolic parameters are treated identically. To work with multiple symbolic parameters or variables, `replace` can be used to substitute in values for a specific variable. +For evaluation and substitution using positional arguments, all instances of symbolic variables and all instances of symbolic parameters are treated identically. + +To work with multiple symbolic parameters or variables, `replace` can be used to substitute in values for a specific variable. * `replace(ex, args::Pair...)` to substitute in for either a variable, parameter, expression head, or symbolic expression (possibly with a wildcard). The pairs are specified as `variable_name => replacement_value`. * `ex(args::Pair...)` redirects to `replace(ex, args::Pair...)` @@ -113,7 +115,7 @@ To illustrate, two or more variables can be used, as here: ```@example expressions @symbolic x -@symbolic y # both symbolic variables +@symbolic y # or SimpleExpressions.@symbolic_variables x y u = x^2 - y^2 ``` @@ -131,9 +133,10 @@ v = replace(u, x=>1, y=>2) # the symbolic value ((1^2)-(2^2)) v() # evaluates to -3 ``` -The `replace` method is a bit more involved than illustrate. The `key => value` pairs have different dispatches depending on the value of the key. Above, the key is a `SymbolicVariable`, but the key can be +The `replace` method is a bit more involved than illustrated. The `key => value` pairs have different dispatches depending on the value of the key. Above, the key is a `SymbolicVariable`, but the key can be: * A `SymbolicVariable` or `SymbolicParameter` in which case the simple substitution is applied, as just illustrated. + * A function, like `sin`. In this case, a matching operation head is replaced by the replacement head. Eg. `sin => cos` will replace a `sin` call with a `cos` call. ```@example expressions @@ -148,16 +151,27 @@ v = 1 + (x+1)^1 + 2*(x+1)^2 + 3*(x+1)^3 replace(v, x+1 => x) ``` - -* A symbolic expression *with* a *wildcard*. The **special** symbol `⋯`, when made into a symbolic variable via `@symbolic ⋯` (where ` ⋯` is entered as `\cdots[tab]`) is treated as a wildcard for matching purposes. The `⋯` can be used in the replacement. +* A symbolic expression *with* a *wildcard*. Wildcards have a naming convention using trailing underscores. One matches one value; two matches one or more values; three match 0, 1, or more values. In addition, the **special** symbol `⋯` (entered with `\cdots[tab]` is wild. ```@example expressions v = log(1 + x) + log(1 + x^2/2) -@symbolic ⋯ # create wildcard -replace(v, log(1 + ⋯) => log1p(⋯)) +@symbolic x_ +replace(v, log(1 + x_) => log1p(x_)) # log1p(x) + log1p((x ^ 2) / 2) ``` +Substitution uses `match(pattern, subject)` for expression matching with wildcards: +```@example expressions +subject, pattern = log(1 + x^2/2), log(1+x_) +ms = match(pattern, subject) +``` + +The return value is `nothing` (for no match) or a collection of valid substitutions. Substituting one into the pattern should return the subject: + +```@example expressions +σ = first(ms) +pattern(σ...) +``` ## Symbolic containers @@ -172,7 +186,7 @@ u(2, (1,2,3,4)) # 49 This is relatively untested and almost certainly not fully featured. For example, only evaluation is allowed, not substitution (using `:`): -``` +```@example expressions @symbolic x a u = sum(ai * x^(i-1) for (i,ai) in enumerate(a)) u(2, [1,2,3]) @@ -249,7 +263,7 @@ x0 = 2 x0 - u(x0) / du(x0) ``` -Here the product rule is used: +Here the application of the product rule can be seen: ```@example expressions u = D(exp(x) * (sin(3x) + sin(101x)), x) @@ -257,4 +271,4 @@ u = D(exp(x) * (sin(3x) + sin(101x)), x) #### Simplification -No simplification is done so the expressions can quickly become unwieldy. There is `TermInterface` support, so--in theory--rewriting of expressions, as is possible with the `Metatheory.jl` package. The scaffolding is in place, but waits for the development version to be tagged. +No simplification is done so the expressions can quickly become unwieldy. There is `TermInterface` support, so--in theory--rewriting of expressions, as is possible with the `Metatheory.jl` package, is supported. The scaffolding is in place, but waits for the development version to be tagged. diff --git a/src/SimpleExpressions.jl b/src/SimpleExpressions.jl index 7b4d837..b26c61a 100644 --- a/src/SimpleExpressions.jl +++ b/src/SimpleExpressions.jl @@ -16,26 +16,29 @@ using CallableExpressions import TermInterface import TermInterface: iscall, operation, arguments, sorted_arguments, maketerm, is_operation, metadata +using Combinatorics using CommonEq -# export ≪, ≦, Eq, ⩵, ≶, ≷, ≫, ≧ # \ll, \leqq, \Equal,\lessgtr, \gtrless, \gg,\geqq - export @symbolic include("types.jl") include("constructors.jl") +include("decl.jl") include("equations.jl") include("terminterface.jl") -#include("metatheory.jl") include("ops.jl") include("show.jl") include("introspection.jl") include("call.jl") +include("matchpy.jl") include("replace.jl") include("comparison.jl") include("generators.jl") include("scalar-derivative.jl") -include("simplify.jl") include("solve.jl") +#include("simplify.jl") +#include("metatheory.jl") + + end diff --git a/src/constructors.jl b/src/constructors.jl index def0852..5fdae34 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -18,7 +18,7 @@ To call a symbolic expression regular call notation with positional arguments ar ## Example -```julia +```@example symbolic using SimpleExpressions @symbolic x p u = x^5 - x - 1 @@ -36,7 +36,7 @@ u(2, [1,2]) # 6 call is u(x, p) Calling with `nothing`, `missing`, or `:` in a slot *substitutes* in the specified value leaving a symbolic expression, possibly with no variable or parameter. -```julia +```@example symbolic @symbolic x p u = cos(x) - p*x u(nothing, 2) # cos(x) - 2 * x @@ -55,13 +55,13 @@ A symbolic equation, defined through `~`, may also be used to specify a left- an The main use is as an easier-to-type replacement for anonymous functions, though with differences: -```julia +```@example symbolic 1 |> sin(x) |> x^2 # 0.708… from sin(1)^2 u = cos(x) - p*x 2 |> u(:, 3) # -6.4161…, a alternative to u(2,3) ``` -```julia +```@example symbolic map(x^2, (1, 2)) # (1,4) ``` @@ -115,7 +115,7 @@ Using this is a convenience for *simple* cases. It is easy to run into idiosyncr Unlike functions, expressions are defined with variables at the time of definition, not when called. For example, with a clean environment: -```julia +```@example symbolic @symbolic x u = m*x + b # errors, `m` not defined f(x) = m*x + b # ok @@ -133,7 +133,7 @@ f(3) # computing 3 * 3 + 4, using values of `m` and `b` when called Though one can make different symbolic variables, the basic call notation by position treats them as the same: -```julia +```@example symbolic @symbolic x @symbolic y # both x, y are `SymbolicVariable` type u = x + 2y @@ -142,13 +142,13 @@ u(3) # 9 coming from 3 + 2*(3) However, this is only to simplify the call interface. Using *keyword* arguments allows evaluation with different values: -```julia +```@example symbolic u(;x=3, y=2) # 7 ``` Using `replace`, we have: -```julia +```@example symbolic u(x=>3, y=>2) # 3 + (2 * 2); evaluate with u(x=>3, y=>2)() ``` @@ -158,7 +158,7 @@ The underlying `CallableExpressions` object is directly called in the above mann The variables may be used as placeholders for containers, e.g. -```julia +```@example symbolic u = sum(xi*pi for (xi, pi) in zip(x,p)) u((1,2),(3,4)) # 11 ``` @@ -168,7 +168,7 @@ u((1,2),(3,4)) # 11 Broadcasting a function call works as expected -```julia +```@example symbolic @symbolic x u = x^2 u.((1,2)) # (1, 4) @@ -176,7 +176,7 @@ u.((1,2)) # (1, 4) Symbolic expressions can also be constructed that will broadcast the call -```julia +```@example symbolic u = x.^2 .+ sin.(p) u((1,2),3) diff --git a/src/decl.jl b/src/decl.jl new file mode 100644 index 0000000..8980511 --- /dev/null +++ b/src/decl.jl @@ -0,0 +1,184 @@ +## constructors +## copied from SymPyCore: /src/decl.jl. +## Contributed by @matthieubulte to SymPy pr #419. +""" + @symbolic_variables w x[1:3] y() z=>"𝑧" Ω::isinteger + +Define multiple symbolic variables or symbolic functions. Guards are ignored. + +Not exported. +""" +macro symbolic_variables(xs...) + # If the user separates declaration with commas, the top-level expression is a tuple + if length(xs) == 1 && isa(xs[1], Expr) && xs[1].head == :tuple + _gensyms(xs[1].args...) + elseif length(xs) > 0 + _gensyms(xs...) + end +end + +function _gensyms(xs...) + asstokw(a) = Expr(:kw, esc(a), true) + + # Each declaration is parsed and generates a declaration using `symbols` + symdefs = map(xs) do expr + decl = parsedecl(expr) + symname = sym(decl) + symname, gendecl(decl) + end + syms, defs = collect(zip(symdefs...)) + + # The macro returns a tuple of Symbols that were declared + Expr(:block, defs..., :(tuple($(map(esc,syms)...)))) +end + + +# The map_subscripts function is stolen from Symbolics.jl +const IndexMap = Dict{Char,Char}( + '-' => '₋', + '0' => '₀', + '1' => '₁', + '2' => '₂', + '3' => '₃', + '4' => '₄', + '5' => '₅', + '6' => '₆', + '7' => '₇', + '8' => '₈', + '9' => '₉') + +function map_subscripts(indices) + str = string(indices) + join(IndexMap[c] for c in str) +end + +# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later. +abstract type VarDecl end + +struct SymDecl <: VarDecl + sym :: Symbol +end + +struct NamedDecl <: VarDecl + name :: String + rest :: VarDecl +end + +struct FunctionDecl <: VarDecl + rest :: VarDecl +end + +struct TensorDecl <: VarDecl + ranges :: Vector{AbstractRange} + rest :: VarDecl +end + +struct AssumptionsDecl <: VarDecl + assumptions :: Vector{Symbol} + rest :: VarDecl +end + +# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol +function gendecl(x::VarDecl) + asstokw(a) = Expr(:kw, esc(a), true) + val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...))) + :($(esc(sym(x))) = $(genreshape(val, x))) +end + +# Transform an expression in a Decl struct +function parsedecl(expr) + # @syms x + if isa(expr, Symbol) + return SymDecl(expr) + + elseif isa(expr, NTuple{N, T} where {N,T}) + return SymDecl.(expr) + + # @syms x::assumptions, where assumption = assumptionkw | (assumptionkw...) + elseif isa(expr, Expr) && expr.head == :(::) + symexpr, assumptions = expr.args + assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args + return AssumptionsDecl(assumptions, parsedecl(symexpr)) + + # @syms x=>"name" + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>) + length(expr.args) == 3 || parseerror() + isa(expr.args[3], String) || parseerror() + + expr, strname = expr.args[2:end] + return NamedDecl(strname, parsedecl(expr)) + + # @syms x() + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>) + length(expr.args) == 1 || parseerror() + return FunctionDecl(parsedecl(expr.args[1])) + + # @syms x[1:5, 3:9] + elseif isa(expr, Expr) && expr.head == :ref + length(expr.args) > 1 || parseerror() + ranges = map(parserange, expr.args[2:end]) + return TensorDecl(ranges, parsedecl(expr.args[1])) + else + parseerror() + end +end + +function parserange(expr) + range = eval(expr) + isa(range, AbstractRange) || parseerror() + range +end + +sym(x::SymDecl) = x.sym +sym(x::NamedDecl) = sym(x.rest) +sym(x::FunctionDecl) = sym(x.rest) +sym(x::TensorDecl) = sym(x.rest) +sym(x::AssumptionsDecl) = sym(x.rest) + +ctor(::SymDecl) = :SymbolicVariable +ctor(x::NamedDecl) = ctor(x.rest) +ctor(::FunctionDecl) = :SymbolicFunction +ctor(x::TensorDecl) = ctor(x.rest) +ctor(x::AssumptionsDecl) = :GuardedSymbolicVariable + +assumptions(::SymDecl) = [] +assumptions(x::NamedDecl) = assumptions(x.rest) +assumptions(x::FunctionDecl) = assumptions(x.rest) +assumptions(x::TensorDecl) = assumptions(x.rest) +assumptions(x::AssumptionsDecl) = x.assumptions + +# Reshape is not used by most nodes, but TensorNodes require the output to be given +# the shape matching the specification. For instance if @syms x[1:3, 2:6], we should +# have size(x) = (3, 5) +genreshape(expr, ::SymDecl) = expr +genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest) +genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest) +genreshape(expr, x::TensorDecl) = let + shape = tuple(length.(x.ranges)...) + :(reshape(collect($(expr)), $(shape))) +end +genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest) + +# To find out the name, we need to traverse in both directions to make sure that each node can get +# information from parents and children about possible name. +# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl +# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or +# based on the SymDecl leaf. +name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym)) +name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name) +name(x::FunctionDecl, parentname) = name(x.rest, parentname) +name(x::AssumptionsDecl, parentname) = name(x.rest, parentname) +name(x::TensorDecl, parentname) = let + basename = name(x.rest, parentname) + # we need to double reverse the indices to make sure that we traverse them in the natural order + namestensor = map(Iterators.product(x.ranges...)) do ind + sub = join(map(map_subscripts, ind), "_") + string(basename, sub) + end + return tuple(namestensor[:]...) + join(namestensor[:], ", ") +end + +function parseerror() + error("Incorrect @syms syntax. Try `@syms x::(real,positive)=>\"x₀\" y() z::complex n::integer` for instance.") +end diff --git a/src/matchpy.jl b/src/matchpy.jl new file mode 100644 index 0000000..2aa8364 --- /dev/null +++ b/src/matchpy.jl @@ -0,0 +1,514 @@ +# implement algorithm of matchpy paper through Ch. 3 +# Non-linear Associative-Commutative Many-to-One Pattern Matching with Sequence Variables by Manuel Krebber + +# 𝑋 variables: regular, [star, plus] +# 𝐹 function heads + +# split symbolic objects into +# 𝐹₀ 0-arity expressions +# 𝐿 all symbolic variables +# 𝑋 wildcard expressions which split into +# Xʳᵉᵍᵘˡᵃʳ regular -- `_is_Wild` +# 𝑋Xᵖˡᵘˢ plus variables -- `_is_Plus` +# Xˢᵗᵃʳ star variables -- `_is_Star` +_is_𝐹₀(::Any) = false # 𝐹ₙ is arity of function; this is no function +_is_𝐿(x::Any) = false # +_is_Wild(x::Any) = false +_is_Plus(x::Any) = false # atleast one +_is_Star(x::Any) = false +_is_𝑋(x) = _is_Wild(x) || _is_Plus(x) || _is_Star(x) # + +# predicates +isassociative(::Any) = false +iscommutative(::Any) = false + +isassociative(::typeof(+)) = true +isassociative(::typeof(*)) = true + +iscommutative(::typeof(+)) = true +iscommutative(::typeof(*)) = true + +# ExpressionType = SymbolicExpression + +## --------------------------------- +## only TermInterface below this line + +## matchpy + +# Δ could use Dict for this + +# σ△σ′ +function iscompatible(σ, σ′) + isnothing(σ) && return true + isnothing(σ′) && return false + for (s,p) ∈ σ + for (s′, p′) ∈ σ′ + s == s′ && p != p′ && return false + end + end + true +end + +# σ⊔σ′ +function union_match(σ, σ′) + isnothing(σ) && return σ′ + for (s′, p′) ∈ σ′ + any(s′ == s for (s,p) ∈ σ) && continue + if _is_𝑋(s′) + σ = TupleTools.vcat(σ, (s′ => p′,)) + end + end + σ +end + +# {σ⊔σ′ |σ∈Θ∧σ△σ′} +function union_matches(Θ, σ′) + isnothing(Θ) && return (σ′, ) + in(σ′, Θ) && return Θ + out = tuple((union_match(σ, σ′) for σ ∈ Θ + if iscompatible(σ, σ′))...) + out +end + +## return iterator -- doesn't seem more performant +function _union_matches(Θ, σ′) + isnothing(Θ) && return Iterators.rest((σ′,), 1) + Iterators.map(Iterators.filter(Θ) do σ + iscompatible(σ, σ′) + end ) do σ + union_match(σ, σ′) + end +end + +# Θ ∪ Θ′ +function union_match_sets(Θ, Θ′) + Θ == ∅ && return Θ′ + Θ′ == ∅ && return Θ + Θ′′ = filter(!in(Θ), Θ′) + TupleTools.vcat(Θ, Θ′′) +end + +# return substitution tuple (p1 => s1, p2 => s2, ...) possibly empty () +# or return nothing if no match +function SyntacticMatch(s, p, σ=nothing) + _is_𝑋(p) && return (p => s,) + _is_𝐿(p) && return s == p ? () : nothing + s == p && return () + _is_𝐹₀(p) && return nothing + + opₛ, opₚ = operation(s), operation(p) + opₛ != opₚ && return nothing + + argsₛ, argsₚ = arguments(s), arguments(p) + length(argsₛ) == length(argsₚ) || return nothing + + for (si,pi) ∈ zip(argsₛ, argsₚ) + σ′ = SyntacticMatch(si, pi, σ) + (isnothing(σ′) || !iscompatible(σ, σ′)) && return nothing + σ = union_match(σ, σ′) + end + + return σ +end + + +# σ is nothing or a substitution tuple, possibly () +# Θ is empty, (), or +∅ = () # is not ((),) + +# fₐ is +,*, or nothing +function MatchOneToOne(ss::Tuple, p, fₐ=nothing, Θ=((),)) + n = length(ss) + if _is_𝐿(p) && !_is_𝑋(p) # 𝐹₀ -- not a SymbolicExpression + n == 1 && p == only(ss) && return Θ + elseif _is_Wild(p) && isnothing(fₐ) + σ′ = (p => first(ss),) + n == 1 && return union_matches(Θ, σ′) + elseif _is_𝑋(p) + if _is_𝑋(p) && !isnothing(fₐ) + σ′ = (p => maketerm(ExpressionType, fₐ, ss, nothing),) + else + σ′ = (p => ss,) + end + if _is_Star(p) || n ≥ 1 + return union_matches(Θ, σ′) + end + elseif n == 1 + s = only(ss) + hₚ, hₛ = operation(p), operation(s) + if hₚ == hₛ + ss = arguments(s) + ps = arguments(p) + fₐ′ = isassociative(hₚ) ? hₚ : nothing + if iscommutative(fₐ′) + return MatchCommutativeSequence(ss, ps, fₐ′, Θ) + else + return MatchSequence(ss, ps, fₐ′, Θ) + end + end + end + return ∅ +end + + +function MatchSequence(ss, ps, fₐ=nothing, Θ=((),)) + n,m = length(ss), length(ps) + nstar = sum(_is_Star(p) for p in ps) + m - nstar > n && return ∅ + nplus = sum(_is_Plus(p) for p in ps) + if isassociative(fₐ) + nplus = nplus + sum(_is_Wild(p) for p in ps) + end + nfree = n - m + nstar + nseq = nstar + nplus + Θᵣ = ∅ + + for ks ∈ Base.Iterators.product((0:nfree for _ in 1:nseq)...) + (!isempty(ks) && sum(ks) != nfree) && continue + i, j = 1, 1 # 0,0?? + Θ′ = Θ + for (l,pl) ∈ enumerate(ps) + lsub = 1 + if (_is_Plus(pl) || _is_Star(pl)) || + (_is_Wild(pl) && !isnothing(fₐ)) + kj = isempty(ks) ? 1 : ks[j] + lsub = lsub + kj + if _is_Star(pl) + lsub = lsub - 1 + end + j = j + 1 + end + ss′ = ss[i:(i+lsub-1)] # note -1 here + Θ′ = MatchOneToOne(ss′, pl, fₐ, Θ′) + Θ′ == ∅ && break + i = i + lsub + end + Θᵣ = union_match_sets(Θᵣ, Θ′) + end + return Θᵣ +end + +# XXX still shaky +function MatchCommutativeSequence(ss, ps, fₐ=nothing, Θ=((),)) + debug = false + debug && @show :matchcomm, ss, ps, fₐ, Θ + + # constant patterns + out = _match_constant_patterns(ss, ps) + isnothing(out) && return ∅ + ss, ps = out + + debug && @show :constant, ss, ps, Θ + + # matched variables first + # for each σ we might get a different set of ss, ps after + # this needs to branch out + Θc = ∅ + + for σ ∈ Θ + out = _match_matched_variables(ss, ps, σ) + out == ∅ && return ∅ + ss, ps = out + + debug && @show :matched, ss, ps, σ + + out = _match_non_variable_patterns(ss, ps, fₐ, σ) + out == ∅ && return ∅ + ss, ps, Θ′ = out + + debug && @show :non_variable, ss, ps, Θ′ + + for σ′ ∈ Θ′ + ## then repeat matched variable ... + out = _match_matched_variables(ss, ps, σ′) + out == ∅ && return out + ss, ps = out + + debug && @show :matched2, ss, ps, σ′ + + Θ′ = (σ′,) + # regular variables p ∈ 𝑋₀ and then sequence variables + if isempty(ps) + σ′ != () && (Θc = union_match_sets(Θc, Θ′)) + else + for out in _match_regular_variables(ss, ps, fₐ, σ′) + debug && @show :regular, out + ss, ps, σ = out # XX \sigma or Θ + Θ′ = (σ, ) + if length(ps) > 0 + Θ′ = _match_sequence_variables(ss, ps, fₐ, σ) + end + Θc = union_match_sets(Θc, Θ′) + end + end + end + end + + return Θc + +end + +function _check_matched_variables(σ, ss, ps) + # check for each match in σ + # there are as many subjects as needed for the match + for (p,s) ∈ σ + # how many times does s appear in pattern + inds = findall(==(s), ss) + n = length(inds) + inds = findall(==(p), ps) + length(inds) >= n || return false + end + return true +end + +function _match_constant_patterns(ss, ps) + pred(a) = any(any(_is_𝑋(u) for u in s) for s in free_symbols(a)) + Pconst = filter(!pred, ps) + for p ∈ Pconst + p in ss || return nothing + ss = filter(!=(p), ss) + end + ps = filter(p -> p ∉ Pconst, ps) + (ss, ps) +end + +# trims down ss, ps +function _match_matched_variables(ss, ps, σ) + # subtract from, ps, ss previously matched variables + (isnothing(σ) || isempty(σ)) && return (ss, ps) + for (p,s) ∈ σ + for _ in 1:count(==(p), ps) + # delete s from ss or return nothhing + itr = isa(s, Tuple) ? s : (s,) + for si ∈ itr + i = findfirst(==(si), ss) + isnothing(i) && return nothing + ss = tuple((v for (j,v) ∈ enumerate(ss) if j != i)...) + end + end + end + ps = tuple((v for v in ps if v ∉ first.(σ))...) + ss, ps +end + +# return () or (ss, ps, Θ) +function _match_non_variable_patterns(ss, ps, fc=nothing, σ=()) + ps′′, ps′ = tuplesplit(!iscall, ps) + length(ps′) == 0 && return (ss, ps, (σ,)) + + ss′′, ss′ = tuplesplit(!iscall, ss) + length(ps′) == length(ss′) || return ∅ + + Θᵣ = ∅ + for inds ∈ Combinatorics.permutations(1:length(ss′)) + ss′′′ = ss′[inds] + Θ′ = (σ,) + for (s,p) ∈ zip(ss′′′, ps′) + operation(s) == operation(p) || continue + Θ′ = MatchSequence(arguments(s), arguments(p), fc, Θ′) + Θ′ == ∅ && continue + end + Θ′ == ∅ && continue + Θᵣ = union_match_sets(Θᵣ, Θ′) + end + Θᵣ == ∅ && return ∅ + ss′′, ps′′, Θᵣ +end + +# return container of ss, ps, sigma +function _match_regular_variables(ss, ps, fc=nothing, σ = ()) + # fₐ is commutative, maybe associative + isassociative(fc) && return ((ss, ps, σ),) + + ps_reg, ps′′ = tuplesplit(_is_Wild, ps) + isempty(ps_reg) && return ((ss, ps, σ),) + + if length(ps_reg) < length(ss) + if ps_reg == ps + # can't match, not enough + return () + end + end + + dp = _countmap(ps_reg) + ds = _countmap(ss) + + out = _split_take(ds, dp) + out = filter(ab -> iscompatible(first(ab), σ), out) + out = [(union_match(σ, σ′), ds) for (σ′,ds) ∈ out] + # return ss, ps, σ for each in out + tuple( + ((_uncountmap(ds), ps′′, σ) for (σ, ds) ∈ out)... + ) + +end + +# different ways to grab the pie +function _split_take(ds, dp) + out = [] + n = length(ds) + k = length(dp) + for inds in Iterators.product((1:n for _ in 1:k)...) + ds′ = copy(ds) + σ = () + for (i, (p, np)) ∈ zip(inds, (dp)) + s, ns = ds′[i] + np > ns && (σ = (); break) # won't fit + ds′[i] = s => (ns - np) + σ = union_match(σ, ((p => s),)) + end + σ == () && continue + push!(out, (σ, ds′)) + end + out +end + + +function _match_sequence_variables(ss, ps, fc, σ = ()) + λ = isassociative(fc) ? (x -> _is_Wild(x) || _is_Plus(x)) : + _is_Plus + vs = tuplesplit(λ, ps) + length(first(vs)) > length(ss) && return () # too many plus variables + + ds = _countmap(ss) + dplus, dstar = _countmap(first(vs)), _countmap(last(vs)) + + Θ = brute_force_enumeration(ds, dplus, dstar, fc, σ) + + return Θ +end + + + +# bruteforce enumeration of possible values (defn 3.1) +# working with tuples likely an issue +function brute_force_enumeration(ds, dplus, dstar, fₐ, σ′=()) + pluses = tuple((v for (k,v) in dplus)...) + stars = tuple((v for (k,v) in dstar)...) + ss = tuple((v for (k,v) in ds)...) + + vars = TupleTools.vcat(tuple(first.(dplus)...), tuple(first.(dstar)...)) + svars = tuple(first.(ds)...) + + n1, n2 = length(pluses), length(stars) + n = n1 + n2 + ks = TupleTools.vcat(pluses, stars) + i = ntuple((a) -> 0, Val(n)) + + Θ = () + h = isnothing(fₐ) ? identity : + ((as) -> _maketerm(fₐ, as)) + for u ∈ Iterators.product( + (Iterators.product((0:s for _ in 1:n)...) for s in ss)...) + all(sum(ui .* ks) == si for (ui,si) in zip(u, ss)) || continue + all(sum(ui[i] for ui in u) > 0 for i in 1:n1) || continue + σ = () + for (j, v) ∈ enumerate(vars) + vv = () + for (i,s) in enumerate(svars) + vi = ntuple((_) -> s, Val(u[i][j])) + vv = TupleTools.vcat(vv, vi) + end + if vv != () + σ = TupleTools.vcat(σ, (v => h(vv),)) + end + end + if iscompatible(σ′, σ) + σ = union_match(σ′, σ) + Θ = TupleTools.vcat(Θ, (σ,)) + end + end + Θ +end + +# need unit here +function _maketerm(fa, xs) + isempty(xs) && return + fa == (*) ? one(ExpressionType) : + fa == (+) ? zero(ExpressionType) : + () + maketerm(ExpressionType, fa, xs, nothing) +end + +## ----- + +""" + map_matched(ex, is_match, f) + +Traverse expression. If `is_match` is true, apply `f` to that part of expression tree and reassemble. + +Basically `CallableExpressions.expression_map_matched`. + +Not exported. +""" +map_matched(ex, is_match, f) = map_matched(Val(iscall(ex)), ex, is_match, f) +map_matched(::Val{false}, x, is_match, f) = is_match(x) ? f(x) : x +function map_matched(::Val{true}, x, is_match, f) + # copy of CallableExpressions.expression_map_matched(pred, mapping, u) + # but in SimpleExpressions domain + is_match(x) && return f(x) + iscall(x) || return x + children = map_matched.(arguments(x), is_match, f) + maketerm(ExpressionType, operation(x), children, metadata(x)) +end + + + +## ----- Replace ----- +## exact replacement +function _replace_exact(ex, p, q) + map_matched(ex, ==(p), _ -> q) +end + +# replace expression head u with v +function _replace_expression_head(ex, u, v) + !iscall(ex) && return ex + args′ = (_replace_expression_head(a, u, v) for a ∈ arguments(ex)) + op = operation(ex) + λ = op == u ? v : op + ex = maketerm(ExpressionType, λ, args′, nothing) +end + +## Replacement of arguments +function _replace_arguments(ex, u, v) + iscall(ex) || return (ex == u ? v : ex) + + m = match(u, ex) + if !isnothing(m) + m == () && return v + + σ = first(m) + σ == () && return v + return v(σ...) + end + + # peel off + op, args = operation(ex), arguments(ex) + args′ = _replace_arguments.(args, (u,), (v,)) + + return maketerm(ExpressionType, op, args′, nothing) +end + + +## ----- + +## utils +function _countmap(x) + d = IdDict() + [(d[xi] = get(d, xi, 0) + 1) for xi in x] + return [k => v for (k,v) ∈ d] +end +function _uncountmap(dx) + TupleTools.vcat((tuple((k for _ in 1:v)...) for (k,v) in dx)...) +end + +tuplesplit(pred, t) = (t = filter(pred,t), f=filter(!pred, t)) + +# take b out of a, error if b has elements not in a or too many +function tuplediff(as, bs) + for b in bs + i = findfirst(==(b), as) + as = tuple((as[j] for j in eachindex(as) if j != i)...) + end + as +end diff --git a/src/metatheory.jl b/src/metatheory.jl index a7808d3..036804d 100644 --- a/src/metatheory.jl +++ b/src/metatheory.jl @@ -1,3 +1,33 @@ +## ----- Interface +""" + simplify(ex) + +Simplify expression using `Metatheory.jl` and rules on loan from `SymbolicUtils.jl`. +""" +function simplify() +end + +""" + expand(ex) + +Expand terms in an expression using `Metatheory.jl` +""" +function expand() +end + +# some default definitions +# we extend to SymbolicExpression in the Metatheroy extension +for fn ∈ (:simplify, :expand, + :canonicalize, :powsimp, :trigsimp, :logcombine, + :expand_trig, :expand_power_exp, :expand_log) + @eval begin + $fn(ex::AbstractSymbolic) = ex + $fn(eq::SymbolicEquation) = SymbolicEquation($fn.(eq)...) + end +end + +## --------------- + import Combinatorics: combinations, permutations using Metatheory diff --git a/src/ops.jl b/src/ops.jl index 84ca898..6a5766e 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -29,22 +29,26 @@ for op ∈ (:/, ) import Base: $op Base.$op(x::AbstractSymbolic, y::AbstractSymbolic) = SymbolicExpression(StaticExpression((↓(x), ↓(y)), $op)) - Base.$op(x::AbstractSymbolic, y::Number) = $op(promote(x,y)...) + Base.$op(x::AbstractSymbolic, y::Number) = _isunit(*,y) ? x : $op(promote(x,y)...) Base.$op(x::Number, y::AbstractSymbolic) = $op(promote(x,y)...) end end ## arrange for *, + to be n-ary +_isunit(::typeof(+), y::Number) = iszero(y) +_isunit(::typeof(*), y::Number) = isone(y) + for op ∈ (:*, :+) @eval begin import Base: $op Base.$op(x::AbstractSymbolic, y::AbstractSymbolic) = SymbolicExpression(StaticExpression(TupleTools.vcat(_arguments($op,x), _arguments($op,y)), $op)) - Base.$op(x::AbstractSymbolic, y::Number) = $op(promote(x,y)...) - Base.$op(x::Number, y::AbstractSymbolic) = $op(promote(x,y)...) + Base.$op(x::AbstractSymbolic, y::Number) = _isunit($op, y) ? x : $op(promote(x,y)...) + Base.$op(x::Number, y::AbstractSymbolic) = _isunit($op, x) ? y : $op(promote(x,y)...) end end + Base.:-(x::AbstractSymbolic, y::AbstractSymbolic) = x + (-1)*y Base.:-(x::AbstractSymbolic, y::Number) = x + (-1)*y Base.:-(x::Number, y::AbstractSymbolic) = x + (-1)*y @@ -264,8 +268,8 @@ function Base.broadcasted(::typeof(Base.literal_pow), u, a::AbstractSymbolic, end - # simplifying operations +# XXX These are really in need of removal ## plus ⊕(x::SymbolicNumber,y::SymbolicNumber) = SymbolicNumber(x() + y()) function ⊕(x,y) @@ -282,6 +286,7 @@ function ⊖(x,y) return x - y end + ## times ⊗(x::SymbolicNumber,y::SymbolicNumber) = SymbolicNumber(x() * y()) function ⊗(x,y) diff --git a/src/replace.jl b/src/replace.jl index 271270d..ce2628d 100644 --- a/src/replace.jl +++ b/src/replace.jl @@ -1,3 +1,38 @@ +# implementation specific definitions needed for matching in matchpy +const ExpressionType = SymbolicExpression + +_is_𝐿(x::AbstractSymbolic) = isa(x, 𝐿) +_is_𝐹₀(x::AbstractSymbolic) = all(isempty(u) for u in free_symbols(x)) + +function _is_Wild(x::𝑉) # 1 + 𝑥 = string(Symbol(x)) + endswith(𝑥, "__") && return false + endswith(𝑥, "_") +end + +function _is_Plus(x::𝑉) # 1 or more + 𝑥 = string(Symbol(x)) + endswith(𝑥, "___") && return false + endswith(𝑥, "__") +end + +function _is_Star(x::SymbolicVariable) # 0, 1, or more + 𝑥 = string(Symbol(x)) + endswith(𝑥, "___") +end + +function _is_𝑋(x::SymbolicVariable) + 𝑥 = string(Symbol(x)) + endswith(𝑥, "_") +end + +# keep ⋯ as match so as not breaking +_is_Wild(x::SymbolicVariable{:⋯}) = true +_is_𝑋(x::SymbolicVariable{:⋯}) = true + +## ---- + + """ replace(ex::SymbolicExpression, args::Pair...) @@ -9,7 +44,7 @@ The replacement is specified using `variable => value`; these are processed left There are different methods depending on the type of key in the the `key => value` pairs specified: -* A symbolic variable is replaced by the right-hand side, like `ex(val,:)` +* A symbolic variable is replaced by the right-hand side, like `ex(val,:)`, though the latter is more performant * A symbolic parameter is replaced by the right-hand side, like `ex(:,val)` * A function is replaced by the corresponding specified function, as the head of the sub-expression * A sub-expression is replaced by the new expression. @@ -18,9 +53,14 @@ There are different methods depending on the type of key in the the `key => valu The first two are straightforward. -```julia +```@repl replace +julia> using SimpleExpressions + +julia> @symbolic x p +(x, p) + julia> ex = cos(x) - x*p -cos(x) - (x * p) +cos(x) + (-1 * x * p) julia> replace(ex, x => 2) == ex(2, :) true @@ -31,15 +71,14 @@ true The third, is illustrated by: -```julia -julia> replace(x + sin(x), sin => cos) -x + cos(x) - +```@repl replace +julia> replace(sin(x + sin(x + sin(x))), sin => cos) +cos(x + cos(x + cos(x))) ``` The fourth is similar to the third, only an entire expression (not just its head) is replaced -```{julia} +```@repl replace julia> ex = cos(x)^2 + cos(x) + 1 (cos(x) ^ 2) + cos(x) + 1 @@ -52,12 +91,12 @@ julia> replace(ex, cos(x) => u) Replacements occur only if an entire node in the expression tree is matched: -```julia +```@repl replace julia> u = 1 + x 1 + x -julia> replace(u + exp(-u), u => x) -1 + x + exp(-1 * x) +julia> replace(u + exp(-u), u => x^2) +1 + x + exp(-1 * (x ^ 2)) ``` (As this addition has three terms, `1+x` is not a subtree in the expression tree.) @@ -65,33 +104,31 @@ julia> replace(u + exp(-u), u => x) The fifth needs more explanation, as there can be wildcards in the expression. -The symbolic variable `⋯` (created with `@symbolic ⋯`, where `⋯` is formed by `\\cdots[tab]`) can be used as a wild card that matches the remainder of an expression tree. The replacement value can have `⋯` as a variable, in which case the identified values will be substituted. +Wildcards have a naming convention using trailing underscores. One matches one value; two matches one or more values; three match 0, 1, or more values. In addition, the **special** symbol `⋯` (entered with `\\cdots[tab]` is wild. -```julia -julia> @symbolic x p; @symbolic ⋯ -(⋯,) +```@repl replace +julia> @symbolic x p; @symbolic x_ +(x_,) + +julia> replace(cos(pi + x^2), cos(pi + x_) => -cos(x_)) +-1 * cos(x ^ 2) -julia> replace(cos(pi + x^2), cos(pi + ⋯) => -cos(⋯)) --1 * cos(x^2) ``` -```julia +```@repl replace julia> ex = log(sin(x)) + tan(sin(x^2)) log(sin(x)) + tan(sin(x ^ 2)) -julia> replace(ex, sin(⋯) => tan((⋯) / 2)) -log(tan(x / 2)) + tan(tan(x ^ 2 / 2)) +julia> replace(ex, sin(x_) => tan((x_) / 2)) +log(tan(x / 2)) + tan(tan((x ^ 2) / 2)) -julia> replace(ex, sin(⋯) => ⋯) +julia> replace(ex, sin(x_) => x_) log(x) + tan(x ^ 2) -julia> replace(x*p, (⋯) * x => ⋯) +julia> replace(x*p, (x_) * x => x_) p - ``` -(The wrapping of `(⋯)` in the last example is needed as the symbol parses as an infix operator.) - ## Picture The `AbstractTrees` package can print this tree-representation of the expression `ex = sin(x + x*log(x) + cos(x + p + x^2))`: @@ -121,7 +158,7 @@ The command wildcard expression `cos(x + ...)` looks at the part of the tree tha function Base.replace(ex::AbstractSymbolic, args::Pair...) for pr in args k,v = pr - ex = _replace(ex, k, v) + ex = _replace(ex, k, ↑(v)) end ex end @@ -138,19 +175,22 @@ end # _replace: basic dispatch in on `u` with (too) many methods # for shortcuts based on typeof `ex` -## u::SymbolicVariable +## u::SymbolicVariable **including** a wild card function _replace(ex::SymbolicExpression, u::SymbolicVariable, v) - pred = ==(↓(u)) - mapping = _ -> ↓(v) - ex = SymbolicExpression(expression_map_matched(pred, mapping, ↓(ex))) + ## intercept wildcards!!! + ex′, u′, v′ = map(↓, (ex, u, v)) + pred = ==(u′) + mapping = _ -> v′ + SymbolicExpression(expression_map_matched(pred, mapping, ex′)) end ## u::SymbolicParameter function _replace(ex::SymbolicExpression, u::SymbolicParameter, v) - pred = ==(↓(u)) - mapping = _ -> ↓(v) - ex = SymbolicExpression(expression_map_matched(pred, mapping, ↓(ex))) + ex′, u′, v′ = map(↓, (ex, u, v)) + pred = ==(u′) + mapping = _ -> v′ + SymbolicExpression(expression_map_matched(pred, mapping, ex′)) end @@ -159,175 +199,74 @@ _replace(ex::SymbolicParameter, u::SymbolicParameter, v) = ex == u ? ↑(v) : ex ## u::Function (for a head, keeping in mind this is not for SymbolicExpression) - # replace old head with new head in expression -_replace(ex::SymbolicNumber, u::Function, v) = ex -_replace(ex::SymbolicParameter, u::Function, v) = ex -_replace(ex::SymbolicVariable, u::Function, v) = ex - -function _replace(ex::SymbolicExpression, u::Function, v) - op, args = operation(ex), arguments(ex) - if op == u - op = v - end +function _replace(ex::AbstractSymbolic, u::𝐹, v) where + {𝐹 <: Union{Function, SymbolicFunction}} + _replace_expression_head(ex, u, v) +end - args′ = (_replace(a, u, v) for a ∈ args) +## u::SymbolicExpression, quite possibly having a wildcard - ex = maketerm(SymbolicExpression,op, args′, nothing) -end +# +# u is symbolic expression possibly wild card +_replace(ex::AbstractSymbolic, u::SymbolicExpression, v) = + _replace_arguments(ex, u, v) +""" + match(pattern, expression) -## u::SymbolicExpression, quite possibly having a wildcard +Match expression using a pattern with possible wildcards. Uses a partial implementation of *Non-linear Associative-Commutative Many-to-One Pattern Matching with Sequence Variables* by Manuel Krebber. -## We use ⋯ (`\\cdots[tab]`) for a single wildcard that should -## * take up remaining terms in `+` or `*` expressions -## * represent branches of an expression tree. -const WILD = SymbolicVariable(:(⋯)) - -has_WILD(ex::SymbolicNumber) = false -has_WILD(ex::SymbolicParameter) = false -has_WILD(ex::SymbolicVariable) = ex == WILD -function has_WILD(ex::SymbolicExpression) - for a ∈ arguments(ex) - has_WILD(a) && return true - end - return false -end +If there is no match: returns `nothing`. +If there is a match: returns a collection of substitutions (σ₁, σ₂, …) -- possibly empty -- with the property `pattern(σ...) == expression` is true. -# u is symbolic expression possibly wild card -_replace(ex::SymbolicNumber, u::SymbolicExpression, v) = ex -_replace(ex::SymbolicParameter, u::SymbolicExpression, v) = ex -_replace(ex::SymbolicVariable, u::SymbolicExpression, v) = ex - -function _replace(ex::SymbolicExpression, u::SymbolicExpression, v) - if !has_WILD(u) - # no wildcard so we must match expression tree completely - return _exact_replace(ex, u, v) - end - ## ⋯ There is a *wild* card for an expression match - m = match(u, ex) - !isnothing(m) && return has_WILD(v) ? _replace(v, WILD, m) : ↑(v) +Wildcards are just symbolic variables with a naming convention: use one trailing underscore to indicate a single match, two trailing underscores for a match of one or more, and three trailing underscores for a match on 0, 1, or more. - # peel off - op, args = operation(ex), arguments(ex) - args′ = _replace.(args, (u,), (v,)) +## Examples - return maketerm(AbstractSymbolic, op, args′, nothing) +```@repl +julia> using SimpleExpressions -end +julia> SimpleExpressions.@symbolic_variables a b x_ x__ x___ +(a, b, x_, x__, x___) -# return arguments fill out ⋯ or nothing if not a -# match in the expression tree -# this seems like the correct use of the generic -function Base.match(pat::AbstractSymbolic, ex::AbstractSymbolic) - has_WILD(pat) || return (pat == ex ? ex : nothing) - m = _ismatch(ex, pat) - return m -end +julia> p, s= x_*cos(x__), a*cos(2 + b) +(x_ * cos(x__), a * cos(2 + b)) -# ismatch wildcard -# return hasmatch: this matches or contains a match -# and expression/missing expression if a match, nothing if not -_ismatch(ex::AbstractSymbolic, u::SymbolicVariable) = ex == u ? u : nothing -_ismatch(ex::AbstractSymbolic, u::typeof(WILD)) = ex - -_ismatch(ex::SymbolicNumber, u::SymbolicExpression) = nothing -_ismatch(ex::SymbolicVariable, u::SymbolicExpression) = nothing -_ismatch(ex::SymbolicParameter, u::SymbolicExpression) = nothing - -function _ismatch(ex::SymbolicExpression, u::SymbolicExpression) - opₓ, opᵤ = operation(ex), operation(u) - opₓ == opᵤ || return nothing - argsₓ, argsᵤ = arguments(ex), arguments(u) - if opₓ == (+) || opₓ == (*) - asₓ, asᵤ = sort(collect(argsₓ)), sort(collect(argsᵤ)) - if WILD ∈ asᵤ - for a ∈ asᵤ - a == WILD && continue - a ∈ asₓ || return nothing - end - ex′ = maketerm(AbstractSymbolic, opₓ, _diff!(asₓ, asᵤ), nothing) - return ex′ - else - length(asₓ) == length(asᵤ) || return nothing - for (a,b) ∈ zip(asₓ, asᵤ) - a == b && continue - (!has_WILD(b) && a != b) && return nothing - matched, m = _ismatch(a, b) - matched && !isnothing(m) && return m - matched || return nothing - end - end - else - for (a,b) ∈ zip(argsₓ, argsᵤ) - if !(has_WILD(b)) - a == b || return nothing - end - end - for (a,b) ∈ zip(argsₓ, argsᵤ) - m = _ismatch(a, b) - return m - end - end - @show :shouldnt_be_here, ex, u - return missing -end +julia> Θ = match(p, s) +((x__ => 2 + b, x_ => a),) -# remove elements in xs′ that appear in xs but only once! -function _diff!(xs, xs′) - for i in eachindex(xs′) - i = only(indexin(xs′[i:i], xs)) - !isnothing(i) && deleteat!(xs, i) - end - xs -end +julia> σ = only(Θ) +(x__ => 2 + b, x_ => a) +julia> p(σ...) == s +true -""" - map_matched(ex, is_match, f) +julia> p, s = p = x_ + x__ + x___, a + b + a + b + a +(x_ + x__ + x___, a + b + a + b + a) -Traverse expression. If `is_match` is true, apply `f` to that part of expression tree and reassemble. +julia> Θ = match(p, s); -(Basically `CallableExpressions.expression_map_matched` brought forward to variables in `SimpleExpressions`.) +julia> length(Θ) # 37 matches +37 -## Example -``` -julia> u = x*tanh(exp(x)) -x * tanh(exp(x)) +julia> σ = last(Θ) +(x_ => b, x__ => b, x___ => a + a + a) -julia> SimpleExpressions.map_matched(u, ==(exp(x)), x -> x^2) -x * tanh(exp(x) ^ 2) +julia> p(σ...) # a + a + (a + b + b) +b + b + (a + a + a) ``` """ -function map_matched(x::𝐿, is_match::P, f::F) where {P,F} - is_match(x) ? f(x) : x -end -function map_matched(x::SymbolicExpression, is_match::P, f::F) where {P,F} - # copy of CallableExpressions.expression_map_matched(pred, mapping, u) - # but in SimpleExpressions domain - if is_match(x) - return f(x) +function Base.match(pat::AbstractSymbolic, ex::AbstractSymbolic) + pred(a) = any(any(_is_𝑋(u) for u in s) for s in free_symbols(a)) + if pred(pat) + out = MatchOneToOne((ex,), pat) + out == () && return nothing + return out + else + out = SyntacticMatch(ex, pat) end - isa(x, 𝐿) && return x - children = map_matched.(arguments(x), is_match, f) - maketerm(typeof(x), operation(x), children, metadata(x)) -end - -function _exact_replace(ex, p, q) - map_matched(ex, ==(p), _ -> q) -end - -#= -## replace exact piece of tree with something else -_exact_replace(ex::SymbolicNumber, p, q) = ex == p ? ↑(q) : ex -_exact_replace(ex::SymbolicVariable, p, q) = ex == p ? ↑(q) : ex -_exact_replace(ex::SymbolicParameter, p, q) = ex == p ? ↑(q) : ex -function _exact_replace(ex::SymbolicExpression, p, q) - ex == p && return ↑(q) - op, args = operation(ex), arguments(ex) - args′ = ((a == p ? q : _exact_replace(a, p, q)) for a in args) - maketerm(SymbolicExpression, op, args′, nothing) + out end -=# diff --git a/src/scalar-derivative.jl b/src/scalar-derivative.jl index 4dd29fb..74b8a5e 100644 --- a/src/scalar-derivative.jl +++ b/src/scalar-derivative.jl @@ -21,6 +21,7 @@ julia> D(D(sin(x))) + sin(x) # no simplification! (-(sin(x))) + sin(x) ``` +Not exported. """ D(𝑥::SymbolicNumber, x) = 0 D(𝑥::SymbolicVariable, x) = 𝑥 == x ? 1 : 0 diff --git a/src/simplify.jl b/src/simplify.jl index 61e58d5..e69de29 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -1,27 +0,0 @@ -## ----- Interface -""" - simplify(ex) - -Simplify expression using `Metatheory.jl` and rules on loan from `SymbolicUtils.jl`. -""" -function simplify() -end - -""" - expand(ex) - -Expand terms in an expression using `Metatheory.jl` -""" -function expand() -end - -# some default definitions -# we extend to SymbolicExpression in the Metatheroy extension -for fn ∈ (:simplify, :expand, - :canonicalize, :powsimp, :trigsimp, :logcombine, - :expand_trig, :expand_power_exp, :expand_log) - @eval begin - $fn(ex::AbstractSymbolic) = ex - $fn(eq::SymbolicEquation) = SymbolicEquation($fn.(eq)...) - end -end diff --git a/src/solve.jl b/src/solve.jl index d0daa4b..1305c9e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -41,7 +41,7 @@ A = w * h u = solve(constraint, h) A = A(u) # use equation in replacement -v = solve(D(A, w) ~ 0, w) +v = solve(D(A, w) ~ 0, w) ``` """ CommonSolve.solve(eq::SymbolicEquation, x::𝑉) = _solve(eq.lhs, eq.rhs, x) @@ -136,6 +136,8 @@ julia> a0, as... = cs = SimpleExpressions.coefficients(eq, x) julia> a0 + sum(aᵢ*x^i for (i,aᵢ) ∈ enumerate(Iterators.rest(cs,2)) if !iszero(aᵢ)) -2 + (-2 * p * (x ^ 1)) + ((2 + (-1 * p)) * (x ^ 2)) + (1 * (x ^ 3)) ``` + +Not exported. """ coefficients(ex::SymbolicEquation, x) = coefficients(ex.lhs - ex.rhs, x) function coefficients(ex, x) @@ -154,7 +156,7 @@ function coefficients(ex, x) nms = tuple((SimpleExpressions._aᵢ(i) for i in 0:n)...) NamedTuple{nms}(coeffs) - + end function _aᵢ(i) @@ -435,7 +437,7 @@ function isolate_x(::Val{:→}, ::typeof(*), l, r, x) l′ = one(l) for c ∈ arguments(l) if contains(c, x) - l′ = l′ ⊗ c + l′ = l′ ⊗ c else r = r ⨸ c end @@ -471,7 +473,7 @@ end function isolate_x(::Val{:←}, ::Any, l, r, x) !contains(r, x) && return l, r # leave as is if no x - + op = operation(r) op⁻¹ = get(inverse_functions, op, nothing) @@ -482,7 +484,7 @@ function isolate_x(::Val{:←}, ::Any, l, r, x) return l, r end - + ## l to r: leave x terms, move others @@ -495,7 +497,3 @@ end # apply inverse? - - - - diff --git a/src/terminterface.jl b/src/terminterface.jl index a376f70..70662f1 100644 --- a/src/terminterface.jl +++ b/src/terminterface.jl @@ -22,10 +22,14 @@ TermInterface.isexpr(ex::SymbolicExpression) = true TermInterface.isexpr(ex::AbstractSymbolic) = false -function TermInterface.maketerm(T::Type{<:AbstractSymbolic}, head, children, metadata) +function TermInterface.maketerm(T::Type{<:SymbolicExpression}, head, children, metadata) head(assymbolic.(children)...) end +function TermInterface.maketerm(T::Type{<:SymbolicExpression}, head::SymbolicVariable, children, metadata) + SymbolicExpression(head, children) +end + function TermInterface.maketerm(T::Type{<:SymbolicNumber}, ::Nothing, children, metadata) SymbolicNumber(DynamicConstant(only(children))) end diff --git a/src/types.jl b/src/types.jl index 0ff09b1..b37cb83 100644 --- a/src/types.jl +++ b/src/types.jl @@ -12,6 +12,7 @@ end SymbolicVariable(x::SymbolicVariable) = x SymbolicVariable(x::Symbol) = SymbolicVariable(StaticVariable{x}()) SymbolicVariable(x::AbstractString) = SymbolicVariable(Symbol(x)) +SymbolicVariable(x::Tuple) = SymbolicVariable.(x) struct SymbolicParameter{T <: DynamicVariable} <: AbstractSymbolic u::T @@ -33,7 +34,10 @@ function SymbolicNumber(c::S) where {S <: Number} end Base.zero(::AbstractSymbolic) = SymbolicNumber(0) +Base.zero(::Type{<:AbstractSymbolic}) = SymbolicNumber(0) Base.one(::AbstractSymbolic) = SymbolicNumber(1) +Base.one(::Type{<:AbstractSymbolic}) = SymbolicNumber(1) + # Expressions @@ -46,21 +50,32 @@ function SymbolicExpression(op, children) SymbolicExpression(u) end +## ---------- +# @symbolic_variables has room for functions and guarded functions +# for symbolic functions, different call +struct SymbolicFunction{X, T <: StaticVariable{X}} <: AbstractSymbolic + u::T + SymbolicFunction(u::T) where {X, T <: StaticVariable{X}} = new{X,T}(u) +end +SymbolicFunction(x::Symbol) = SymbolicFunction(StaticVariable{x}()) +SymbolicFunction(x::AbstractString) = SymbolicFunction(Symbol(x)) +(f::SymbolicFunction)(xs...) = SymbolicExpression(f, xs) + +## we *could* do more here adding a guard for matching purposes? +GuardedSymbolicVariable(u;kwargs...) = SymbolicVariable(u) + + +## ---------- # conveniences 𝑉 = Union{SymbolicVariable, SymbolicParameter} -𝐿 = Union{𝑉, SymbolicNumber} +𝐿 = Union{𝑉, SymbolicNumber} # not constant expressions though ## ----- CallableExpressions _Variable = CallableExpressions.ExpressionTypeAliases.Variable - - - ## ----- promotion/conversion - - Base.promote_rule(::Type{<:AbstractSymbolic}, x::Type{T}) where {T <: Number} = AbstractSymbolic Base.convert(::Type{<:AbstractSymbolic}, x::Number) = SymbolicNumber(DynamicConstant(x)) @@ -72,6 +87,7 @@ Base.convert(::Type{<:AbstractSymbolic}, x::SymbolicParameter) = x ## --- CallableExpressions --> SimpleExpression # convert to symbolic; ↑ is an alias +assymbolic(x) = x assymbolic(x::AbstractSymbolic) = x assymbolic(x::Symbol) = SymbolicVariable(x) assymbolic(x::Number) = SymbolicNumber(x) diff --git a/test/basic_tests.jl b/test/basic_tests.jl index 1d7c556..dc8e8c1 100644 --- a/test/basic_tests.jl +++ b/test/basic_tests.jl @@ -140,59 +140,10 @@ end @test u(p=>p₀) isa SimpleExpressions.AbstractSymbolic end -@testset "replace" begin - @symbolic x p - @symbolic ⋯ - - ≈ₑ(u,v) = (x₀ = rand(); u(x₀) ≈ v(x₀)) - - ex = log(1 + x^2) + log(1 + x^3) - @test replace(ex, log=>sin) == sin(1 + (x ^ 2)) + sin(1 + (x ^ 3)) - @test replace(ex, log(1+⋯) => log1p(⋯)) == log1p(x ^ 2) + log1p(x ^ 3) - - ex = log(sin(x)) + tan(sin(x^2)) - @test replace(ex, sin => cos) == log(cos(x)) + tan(cos(x^2)) - @test replace(ex, sin(⋯) => tan(⋯)) == log(tan(x)) + tan(tan(x^2)) - @test replace(ex, sin(⋯) => tan((⋯)/2)) == log(tan(x/2)) + tan(tan(x^2/2)) - @test replace(ex, sin(⋯) => ⋯) == log(x) + tan(x^2) - - ex = (1 + x^2)^2 # outer one - pr = (⋯)^2 => (⋯)^4 - @test replace(ex, pr) == (1 + (x ^ 2)) ^ 4 - @test replace(ex, pr, pr) == (1 + (x ^ 4)) ^ 4 - - ex = sin(x + x*log(x) + cos(p + x + p + x^2)) - @test replace(ex, cos(x + ⋯) => ⋯) == sin(x + (x * log(x)) + p + p + (x ^ 2)) - - @test replace(x, p=>2) == x - @test replace(1 + x^2, x^2 => 2)() == 3 # 1 + 2 evaluates to 3 - - # exact replacement; a bit speedier than `replace(ex, expr=>replacement)` - ex = x^2 + x^4 - @test replace(ex, x^2 => x) == x + x^4 - - ex = x * sin(x) - @test replace(ex, x*sin(x) => x) == x - @test replace(ex*cos(x), x*sin(x) => x) == ex * cos(x) - - u = x + 2 - ex = u + exp(u) # (x + 2 + exp(x+2)) but + is vararg so no match - @test replace(ex, u=>x) == (x + 2 + exp(x)) - - @symbolic y; @symbolic z - @test replace(x*y + z, x*y => pi) == pi + z - @test replace(x*y*z, x*y => pi) == x*y*z - @test replace(2x, 2x => y, x => z) == y - @test replace(2 * (2x), 2x => y, x => z) == 4 * z # y isn't replaced, just x - - # match - @test match(log(1 + ⋯), log(1 + x^2/2 - x^4/24)) ≈ₑ x^2/2 - x^4/24 - @test match((⋯)^(⋯), (x+p)^(x+p)) == x + p - @test isnothing(match(sin(⋯), sin(x)^2)) - +@testset "map_matched" begin # map_matched @symbolic x p - @test map_matched(x*tanh(exp(x)), ==(exp(x)), x -> x^2) == x * tanh(exp(x)^2) + @test SimpleExpressions.map_matched(x*tanh(exp(x)), ==(exp(x)), x -> x^2) == x * tanh(exp(x)^2) end diff --git a/test/runtests.jl b/test/runtests.jl index 47d9fa2..86da6dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,5 +4,6 @@ using Test import SimpleExpressions: @symbolic_expression include("basic_tests.jl") +include("test_match.jl") ## too slow right now #include("extension_tests.jl") diff --git a/test/test_match.jl b/test/test_match.jl new file mode 100644 index 0000000..533f574 --- /dev/null +++ b/test/test_match.jl @@ -0,0 +1,213 @@ +using Test +using SimpleExpressions +S = SimpleExpressions + +import SimpleExpressions: SyntacticMatch, MatchOneToOne, + MatchSequence,MatchCommutativeSequence +import SimpleExpressions: SymbolicVariable, SymbolicExpression +import SimpleExpressions: @symbolic_variables + +@symbolic x p +@symbolic ⋯ +@symbolic_variables y z a b c +@symbolic_variables x_ x__ x___ y_ y__ y___ z_ z__ z___ + +@symbolic_variables g() f() fₐ() fₘ() fₐₘ() +f ⨝ as = f(as...) + +function S.isassociative(x::S.SymbolicFunction) + nm = string(Symbol(x)) + endswith(nm, "ₐ") && return true + endswith(nm, "ₐₘ") && return true + false +end + +function S.iscommutative(x::S.SymbolicFunction) + nm = string(Symbol(x)) + endswith(nm, "ₘ") && return true + false +end + +## ---- + +@testset "match" begin + + # match 1 + @test match((⋯)^(⋯), (x+p)^(x+p)) == (((⋯) => x + p,),) + + # match 2 wildcards + Θ = match(x_*sin(y_), x*sin(x)) + σ = first(Θ) + @test (y_ => x) ∈ σ && (x_ =>x) ∈ σ && length(σ) == 2 + + # match can have more than 1 substitution + Θ = match(f(x__,y__), f(a,b,c)) + @test length(Θ) == 2 + @test f(x__, y__)(first(Θ)...) ∈ (f((a,b), (c,)), f((a,), (b,c))) + + # empty match returns `nothing` + @test isnothing(match(sin(⋯), sin(x)^2)) +end + +@testset "exact" begin + 𝑝, 𝑠 = cos(sin(a)), cos(sin(a)) + m = SyntacticMatch(𝑠, 𝑝) + @test m == () + + 𝑝, 𝑠 = cos(sin(a)), cos(sin(b)) + m = SyntacticMatch(𝑠, 𝑝) + @test isnothing(m) + + m = SyntacticMatch(sin(cos(a)), cos(a)) + @test isnothing(m) +end + +@testset "associative" begin + 𝑠 = 1 + a + b + 𝑝 = 1 + x_ + Θ = MatchOneToOne((𝑠,), 1 + x_) + @test length(Θ) == 1 + σ = only(Θ) + @test S.sorted_arguments(last(σ[1])) == (a,b) + + Θ = MatchOneToOne((a + b + c,), x__ + y__) + @test length(Θ) == 6 # (c, a+b),(a,c+b),(b,c+a),(c+a,b),(c+b,a), (a+b,c) + + # match + # should not match + 𝑠 = log(1 + x^2/2 - x^4/24) + @test !isnothing(match(log(1 + ⋯), 𝑠)) + @test !isnothing(match(log(1 + x__), 𝑠)) # again x_ like x__ + +end + +@testset "constant patterns" begin + @test MatchSequence((a,b,c), (a,b,b)) == () # no substitutions + @test MatchSequence((a,b,c), (a,b,c)) == ((),) # one trivial substitution +end + +@testset "matched variables" begin + + ss, ps = (a,b,c), (x_,y_,z_) + σ = (x_ => a,) + + ss′, ps′ = S._match_matched_variables(ss, ps, σ) + @test ss′ == (b,c) && ps′ == (y_,z_) + + Θ = MatchCommutativeSequence(ss, ps, nothing, ((),)) + @test length(Θ) == 6 + Θ = MatchCommutativeSequence(ss, ps, nothing, (σ,)) + @test length(Θ) == 2 + +end + + +@testset "non-variable" begin + 𝑝 = fₘ(g(a,x_), g(x_,y_), g(z__)) + 𝑠 = fₘ(g(a,b), g(b,a), g(a,c)) + Θ = MatchOneToOne((𝑠,), 𝑝) + σ = only(Θ) + @test length(σ) == 3 + @test (x_ => b) ∈ σ && (y_ => a) ∈ σ && (z__ => (a, c)) ∈ σ + +end + +@testset "regular variables" begin + 𝑠 = fₘ(a,a,a,b,b,c) + 𝑝 = fₘ(x_,x_,y___) + Θ = MatchOneToOne((𝑠,), 𝑝) + @test length(Θ) == 1 # σ = (x_ => a, y___ => (a, b, b, c)) + @test (x_ => a, y___ => (a, b, b, c)) ∈ Θ # ordering is ok + + 𝑠 = fₐₘ(a,a,a,b,b,c) + 𝑝 = fₐₘ(x_,x_,y___) # associative has x_ like x__ + Θ = MatchOneToOne((𝑠,), 𝑝) + @test length(Θ) == 3 # (x_ => fₐₘ(a, b), y___ => fₐₘ(a, c)) + + +end + +@testset "sequence variables" begin + @symbolic_variables u() uₐ() uₘ() uₐₘ() + + Θ = MatchSequence((a,b,c), (x__, y__), u) + @test length(Θ) == 2 # u(a,b), u(c); u(a), u(b,c) + + Θ = MatchSequence((a,b,c), (x__, y___), u) + @test length(Θ) == 3 # add u(a,b,c),u() + + Θ = MatchSequence((a,b,c), (x___, y___), u) + @test length(Θ) == 4 + + + Θ = MatchSequence((a,b,c), (x__, y__), uₘ) # are these right + @test length(Θ) == 2 # + + Θ = MatchSequence((a,b,c), (x__, y___), uₘ) + @test length(Θ) == 3 + + + Θ = MatchSequence((a,b,c), (x___, y___), uₐₘ) + @test length(Θ) == 4 + + +end + +@testset "replace head" begin + # replace operation + ex = log(1 + x^2) + log(1 + x^3) + @test replace(ex, log=>sin) == sin(1 + (x ^ 2)) + sin(1 + (x ^ 3)) + + @symbolic_variables f() g() + @test replace(f(a,a,b), f(x__) => g(x__)) == g((a,a,b)) # not g(a,a,b) +end + +@testset "replace" begin + # with wildcards + ≈ₑ(u,v) = (x₀ = rand(); u(x₀) ≈ v(x₀)) + ≈ₚ(u,v) = (x₀ = rand(); p₀ = rand(); u(x₀, p₀) ≈ v(x₀, p₀)) + + + # replace parts + ex = log(1 + x^2) + log(1 + x^3) + @test replace(ex, log(1+x__) => log1p(x__)) == log1p(x ^ 2) + log1p(x ^ 3) + + ex = log(sin(x)) + tan(sin(x^2)) + @test replace(ex, sin => cos) == log(cos(x)) + tan(cos(x^2)) + @test replace(ex, sin(⋯) => tan(⋯)) == log(tan(x)) + tan(tan(x^2)) + @test replace(ex, sin(⋯) => tan((⋯)/2)) == log(tan(x/2)) + tan(tan(x^2/2)) + @test replace(ex, sin(⋯) => ⋯) == log(x) + tan(x^2) + + ex = (1 + x^2)^2 # outer one is peeled off first by replace + pr = (⋯)^2 => (⋯)^4 + @test replace(ex, pr) == (1 + (x ^ 2)) ^ 4 + @test replace(ex, pr, pr) == (1 + (x ^ 4)) ^ 4 + + + ex = sin(x + x*log(x) + cos(p + x + p + x^2)) + @test replace(ex, cos(x + x__) => x__) ≈ₚ sin(x + (x * log(x)) + p + p + (x ^ 2)) + + @test replace(x, p=>2) == x + @test replace(1 + x^2, x^2 => 2)() == 3 # 1 + 2 evaluates to 3 + + + # x_ matches different parts of expression tree in replace + ex = sin(cos(a))*cos(b) + @test replace(ex, cos(x_) => tan(x_)) == sin(tan(a)) * tan(b) + + # no variable in substitution + @test replace(sin(a), sin(x_) => x) == x + @test replace(sin(a), sin(x_) => x_) == a + @test replace(sin(a), sin(x_) => 2) == 2 +end + +@testset "replace exact" begin + # no wild card + ex = x^2 + x^4 + @test replace(ex, x^2 => x) == x + x^4 + + ex = x * sin(x) + @test replace(ex, x*sin(x) => x) == x + @test replace(ex*cos(x), x*sin(x) => x) == ex * cos(x) + +end