From 3f1fd58108a71e7e1ffae97cc067aa1c8a8aca49 Mon Sep 17 00:00:00 2001 From: jverzani Date: Fri, 17 Jan 2025 17:28:40 -0500 Subject: [PATCH] light simplification of terms --- Project.toml | 3 +- docs/src/index.md | 2 +- ext/SimpleExpressionsMetatheoryExt.jl | 414 ++++++++++++++++++++++++++ src/SimpleExpressions.jl | 4 +- src/combine.jl | 194 ++++++++++++ src/metatheory.jl | 403 ------------------------- src/ops.jl | 147 ++++----- src/scalar-derivative.jl | 98 +++--- src/simplify.jl | 27 ++ src/solve.jl | 70 ++--- test/basic_tests.jl | 32 +- 11 files changed, 820 insertions(+), 574 deletions(-) create mode 100644 ext/SimpleExpressionsMetatheoryExt.jl create mode 100644 src/combine.jl diff --git a/Project.toml b/Project.toml index 8874b91..832f122 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleExpressions" uuid = "deba94f7-f32a-40ad-b45e-be020a5ded2f" authors = ["jverzani and contributors"] -version = "1.1.9" +version = "1.1.10" [deps] CallableExpressions = "391672e0-bbe4-4ab4-8bc9-b89a79cbc2f0" @@ -22,6 +22,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] SimpleExpressionsAbstractTreesExt = "AbstractTrees" SimpleExpressionsLatexifyExt = "Latexify" +#SimpleExpressionsMetatheoryExt = "Metatheory" SimpleExpressionsRecipesBaseExt = "RecipesBase" SimpleExpressionsRootsExt = "Roots" SimpleExpressionsSpecialFunctionsExt = "SpecialFunctions" diff --git a/docs/src/index.md b/docs/src/index.md index d759c5a..a8dfb5b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -271,4 +271,4 @@ u = D(exp(x) * (sin(3x) + sin(101x)), x) #### Simplification -No simplification is done so the expressions can quickly become unwieldy. There is `TermInterface` support, so--in theory--rewriting of expressions, as is possible with the `Metatheory.jl` package, is supported. The scaffolding is in place, but waits for the development version to be tagged. +No simplification is done so the expressions can quickly become unwieldy. The unexported `combine` does light simplfication. There is `TermInterface` support, so--in theory--rewriting of expressions, as is possible with the `Metatheory.jl` package, is supported. The scaffolding is in place, but waits for the development version to be tagged. diff --git a/ext/SimpleExpressionsMetatheoryExt.jl b/ext/SimpleExpressionsMetatheoryExt.jl new file mode 100644 index 0000000..9656632 --- /dev/null +++ b/ext/SimpleExpressionsMetatheoryExt.jl @@ -0,0 +1,414 @@ +module SimpleExtensionsMetatheoryExt + +## --------------- + +import Combinatorics: combinations, permutations +using Metatheory +using SimpleExpressions +import SimpleExpressions: SymbolicNumber, SymbolicParameter, + SymbolicVariable, SymbolicExpression, AbstractSymbolic +import SimpleExpressions: + simplify, powsimp,trigsimp,logcombine, + expand,expand_trig,expand_power_exp, expand_log, + canonoicalize + +## ----- predicates ----- +is_literal_number(::SymbolicNumber) = true +is_literal_number(::Number) = true +is_literal_number(::Any) = false + +isminusone(x::Number) = x == -1 +isminusone(x::SymbolicNumber) = isminusone(x()) +isminusone(::AbstractSymbolic) = false + +istwo(x::Number) = x == 2 +istwo(x::SymbolicNumber) = istwo(x()) +istwo(::AbstractSymbolic) = false + +## ---- SymbolicUtils ---- + +#= +Lifted and modified from MIT licensed [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/LICENSE.md) + +[Rules](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rule.jl) are mostly based on rules in that package. +=# + +# recreate @acrule and @ordered_acrule +struct ACRule{F,R} + sets::F + rule::R + arity::Int +end + +macro acrule(expr) + arity = length(expr.args[2].args[2:end]) + quote + ACRule(permutations, $(esc(:(@rule($(expr))))), $arity) + end +end + +macro ordered_acrule(expr) + arity = length(expr.args[2].args[2:end]) + quote + ACRule(combinations, $(esc(:(@rule($(expr))))), $arity) + end +end + +function (acr::ACRule)(term) + r = acr.rule + if !iscall(term) + r(term) + else + f = operation(term) + # # Assume that the matcher was formed by closing over a term + # if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. + # return nothing + # end + + args = arguments(term) + + itr = acr.sets(eachindex(args), acr.arity) + + for inds in itr + result = r(f(args[inds]...)) #Term{T}(f, @views args[inds])) + if result !== nothing + # Assumption: inds are unique + length(args) == length(inds) && return result + return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], nothing) # metadata(term)) + end + end + end +end + +## ---- some predicates used in SymbolicUtils rules +function isnotflat(⋆) + function (x) + args = arguments(x) + for t in args + if is_operation(⋆)(t) + return true + end + end + return false + end +end + +function flatten_term(⋆, x) + args = arguments(x) + # flatten nested ⋆ + flattened_args = [] + for t in args + if is_operation(⋆)(t) + append!(flattened_args, arguments(t)) + else + push!(flattened_args, t) + end + end + maketerm(SymbolicExpression, ⋆, flattened_args, metadata(x)) +end + +hasrepeats(::SimpleExpressions.AbstractSymbolic)= false +function hasrepeats(x′::SimpleExpressions.SymbolicExpression) + x = TermInterface.arguments(x′) + length(x) <= 1 && return false + for i=1:length(x)-1 + if isequal(x[i], x[i+1]) + return true + end + end + return false +end + +_merge_op(::typeof(*), a, b) = a^b +_merge_op(::typeof(+), a, b) = b*a +function merge_repeats(op, xs) + + length(xs) <= 1 && return xs + merged = () + d = Dict{Any, Int}() + for k in xs + cnt = get(d, k, 0) + d[k] = cnt + 1 + end + + return tuple((v==1 ? k : _merge_op(op, k,v) for (k,v) in pairs(d))...) + +end + +function has_trig_exp(term) + !iscall(term) && return false + fns = (sin, cos, tan, cot, sec, csc, exp, cosh, sinh) + op = operation(term) + + if Base.@nany 9 i->fns[i] === op + return true + else + return any(has_trig_exp, arguments(term)) + end +end + + +needs_sorting(f) = x -> is_operation(f)(x) && !issorted(arguments(x)) +needs_sorting₊ = needs_sorting(+) # issue with using rhs as predicate? +needs_sortingₓ = needs_sorting(*) + +function sort_args(f, t) + args = arguments(t) + args = merge_repeats(f, args) # had issue with hasrepeats, slipped in here + if length(args) < 2 + return maketerm(typeof(t), f, args, metadata(t)) + elseif length(args) == 2 + x, y = args + return maketerm(typeof(t), f, x < y ? [x,y] : [y,x], metadata(t)) + end + args = args isa Tuple ? [args...] : args + maketerm(typeof(t), f, TupleTools.sort(args), metadata(t)) +end + +# issue is ~~x +Base.view(t::NTuple{N, AbstractSymbolic}, ind::UnitRange) where {N} = t[ind] + +## rules for simplification +CANONICALIZE_PLUS = [ + +# @rule(~x::isnotflat(+) => flatten_term(+, ~x)) + + @rule(~x::needs_sorting₊ => sort_args(+, ~x)) # also merge + @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b) + #XXX @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) ## JUST WRONG! + @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)) + @acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x)) + # @rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...)) # XXX p_var issue + + @ordered_acrule((~z::iszero + ~x) => ~x) + @rule(+(~x) => ~x) + + #@rule(-(~x) => (-1) * ~x) + @rule(-(~x, ~y) => +(~x, (-1) * ~y)) + + @acrule(~x + ~c::isminusone * ~x => zero(~x)) +] + +CANONICALIZE_TIMES = [ +# @rule(~x::isnotflat(*) => flatten_term(*, ~x)) + @rule(~x::needs_sortingₓ => sort_args(*, ~x)) + + @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b) + # @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)) + + @acrule((~y)^(~n) * ~y => (~y)^(~n+1)) + + @ordered_acrule((~z::isone * ~x) => ~x) + @ordered_acrule((~z::iszero * ~x) => ~z) + + @rule(~x / ~x => one(~x)) + @rule(*(~x,~xs...) / (~x,~ys...) => *(~xs...) / *(~ys...)) + @rule(~x / (~x, ~ys...) => one(~x) / *(~ys...)) + @rule(*(~x,~xs...) / ~x => *(~xs...)) + + @acrule(~x * (~x)^(~c::isone) => one(~x)) + @rule(*(~x) => ~x) +] + +CANONICALIZE_POW = [ + @rule(^(*(~~x), ~y::isinteger) => *(map(a->a^~y, ~~x)...)) + @rule((((~x)^(~p::isinteger))^(~q::isinteger)) => (~x)^((~p)*(~q))) + @rule(^(~x, ~z::iszero) => 1) + @rule(^(~x, ~z::isone) => ~x) + # @rule(inv(~x) => 1/(~x)) +] + +PLUS_DISTRIBUTE = [ + @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)) + @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...)) +] + + +POW_RULES = [ + @rule(^(~x::isone, ~z) => 1) + @ordered_acrule((~x)^(~a) * (~x)^(~b) => (~x)^(~a + ~b)) # always + @ordered_acrule((~x)^(~a) * (~y)^(~a) => (~x*~y)^(~a)) # x,y ≥ 0; a real + # (x^a)^b => x^(a*b) is in canonicalize; b \in Z +] + +ASSORTED_RULES = [ + @rule(identity(~x) => ~x) + @rule((~x::isone) \ ~y => ~y) + @rule(~x \ ~y => ~y / (~x)) + @rule(one(~x) => 1) + @rule(zero(~x) => 0) + @rule(conj(~x::isreal) => ~x) + @rule(real(~x::isreal) => ~x) + @rule(imag(~x::isreal) => 0) + #@rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z) + @rule(ifelse(~x, ~y, ~y) => ~y) +] + +TRIG_RULES = [ + # @acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y)) + # @acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y)) + @acrule(sin(~x)^(~c::istwo) + cos(~x)^(~c::istwo) => one(~x)) + @acrule(sin(~x)^(~c::istwo) - (~c::isone) => -1*cos(~x)^2) + @acrule(cos(~x)^(~c::istwo) - (~c::isone) => -1*sin(~x)^2) + + @acrule(cos(~x)^(~c::istwo) + (~d::isminusone)*sin(~x)^(~c::istwo) => cos(2 * ~x)) + @acrule(sin(~x)^(~c::istwo) + (~d::isminusone)*cos(~x)^(~c::istwo) => -cos(2 * ~x)) + @acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2) + + @acrule(tan(~x)^(~c::istwo) + (~d::isminusone)*sec(~x)^(~c::istwo) => one(~x)) + @acrule((~d::isminusone)*tan(~x)^(~c::istwo) + sec(~x)^(~c::istwo) => one(~x)) + @acrule(tan(~x)^(~c::istwo) + (~d::isone) => sec(~x)^2) + @acrule(sec(~x)^(~c::istwo) + (~d::isminusone) => tan(~x)^2) + + @acrule(cot(~x)^(~c::istwo) + (~d::isminusone)*csc(~x)^(~c::istwo) => one(~x)) + @acrule(cot(~x)^(~c::istwo) + (~d::isone) => csc(~x)^2) + @acrule(csc(~x)^(~c::istwo) + (~d::isminusone) => cot(~x)^2) + + @acrule(cosh(~x)^(~c::istwo) + (~d::isminusone)*sinh(~x)^(~c::istwo) => one(~x)) + @acrule(cosh(~x)^(~c::istwo) + (~d::isminusone) => sinh(~x)^2) + @acrule(sinh(~x)^(~c::istwo) + (~d::isone) => cosh(~x)^2) + + @acrule(cosh(~x)^(~c::istwo) + sinh(~x)^(~c::istwo) => cosh(2 * ~x)) + @acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2) +] + +EXP_RULES = [ + @acrule(exp(~x) * exp(~y) => iszero(~x + ~y) ? 1 : exp(~x + ~y)) + @rule(exp(~x)^(~y) => exp(~x * ~y)) +] + +LOG_RULES = [ + @acrule(log(~x) + log(~y) => log(~x * ~y)), + @acrule(~n * log(~x) => log((~x)^(~n))) +] + +CANONICALIZE = CANONICALIZE_PLUS ∪ CANONICALIZE_TIMES ∪ CANONICALIZE_POW + +## rules for expansion + +_expand_minus = [ + @rule(-(~a + ~b) => -~a + -~b) + @rule((~c::isone) * (+(~~xs...)) => +(~~xs...)) + @rule((~c::isminusone) * (+(~~xs...)) => +((-).(~~xs)...)) + @rule(~a - ~a => zero(~a)) + @rule(-~a + ~a => zero(~a)) +] + + + +_expand_distributive = [ + @rule(~z*(~x + ~y) => ~z*~x + ~z*~y) + @rule((~x + ~y) * ~z => ~z*~x + ~z*~y) + @rule(~z * (+(~~xs...)) => sum(~z*x for x in ~~xs)) + @rule(+(~~xs...) * ~z => sum(~z*x for x in ~~xs)) + + @rule(~z*(~x - ~y) => ~z*~x - ~z*~y) + @rule((~x - ~y) * ~z => ~z*~x - ~z*~y) +] + +_expand_binom = [ + @rule((~x + ~y)^(~c::isone) => ~x + ~y) + @rule((~x + ~y)^(~c::istwo) => (~x)^2 + 2*~x*~y + (~y)^2) + @rule((~x + ~y)^(~n::isinteger) => sum(binomial(Int(~n), k) * (~x)^k * (~y)^((~n)-k) for k in 0:Int(~n))) +] + + +_expand_trig = [ + @rule(sin(~c::istwo * ~a) => 2sin(~a)*cos(~a)) + @rule sin(~a + ~b) => sin(~a)*cos(~b) + cos(~a)*sin(~b) + @rule cos(~c::istwo * ~a) => cos(~a)^2 - sin(~a)^2 + @rule cos(~a + ~b) => cos(~a)*cos(~b) - sin(~a)*sin(~b) + @rule sec(~a) => 1 / cos(~a) + @rule csc(~a) => 1 / sin(~a) + @rule tan(~a) => sin(~a)/cos(~a) + @rule cot(~a) => cos(~a)/sin(~a) +] + + +_expand_power = [ + @rule (~x)^(~a+~b) => (~x)^(~a) * (~x)^(~b) + @rule((~x*~y)^(~a) => (~x)^(~a) * (~y)^(~a)) +] + +_expand_log = [ + @rule(log(~x*~y) => log(~x) + log(~y)) + @rule log((~x)^(~n)) => ~n * log(~x) +] + +_expand_misc = [ + @rule( -(~a) => (-1)*~a) + @rule(((~c)::isone/~a) * ~a => one(~a)) + @rule(~a * ((~c)::isone/~a) => one(~a)) + #@rule(/(~a,~b) => *(~a, (~b)^(-1))) +] + + +# make methods for expressions +function simplify(ex::SymbolicExpression) + ex = rewrite(ex, CANONICALIZE) + theories = (PLUS_DISTRIBUTE, + POW_RULES, + ASSORTED_RULES, + TRIG_RULES, + EXP_RULES, + LOG_RULES + ) + ex = rewrite(ex, reduce(∪, theories)) + ex = rewrite(ex, CANONICALIZE) +end + +function expand(ex::SymbolicExpression) + theories = ( + _expand_minus, + _expand_distributive, _expand_binom, _expand_trig, + _expand_power, _expand_log, + _expand_misc + ) + ex = rewrite(ex, reduce(∪,theories)) +end + + +# +canonicalize(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE) +powsimp(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE ∪ POW_RULES ∪ EXP_RULES) +trigsimp(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE ∪ TRIG_RULES) +logcombine(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE ∪ LOG_RULES) + +expand_trig(ex::SymbolicExpression) = rewrite(ex, _expand_trig) +expand_power_exp(ex::SymbolicExpression) = rewrite(ex, _expand_power) +expand_log(ex::SymbolicExpression) = rewrite(ex, _expand_log) + + +# function to run quickly and make terms more nice +# used by show +function _canon(x) + rules = [ + # sort + @rule(~x::needs_sorting₊ => sort_args(+, ~x)) # also merge + @rule(~x::needs_sortingₓ => sort_args(*, ~x)) + + # combine terms + @rule(~x + ~x => 2x) + @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)) + @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) + @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)) + + # additive identity + @acrule(~z::iszero + ~x => ~x) + @acrule(+(~z::iszero, ~~xs...) => +(~~xs...)) + + # multiplicative zero + @ordered_acrule(~z::iszero * ~x => zero(~x)) + @ordered_acrule(*(~z::iszero, ~~xs...) => zero(~z)) + + # multiplicative identity + @acrule(~z::isone * ~x => ~x) + @acrule(*(~z::isone, ~~xs...) => *(~~xs...)) + + + ] + + rewrite(x, rules) +end + +end diff --git a/src/SimpleExpressions.jl b/src/SimpleExpressions.jl index b26c61a..aa4f1cd 100644 --- a/src/SimpleExpressions.jl +++ b/src/SimpleExpressions.jl @@ -27,6 +27,7 @@ include("decl.jl") include("equations.jl") include("terminterface.jl") include("ops.jl") +include("combine.jl") include("show.jl") include("introspection.jl") include("call.jl") @@ -37,8 +38,7 @@ include("generators.jl") include("scalar-derivative.jl") include("solve.jl") -#include("simplify.jl") -#include("metatheory.jl") +#include("simplify.jl") # wait for Metatheory v3.0 tag end diff --git a/src/combine.jl b/src/combine.jl new file mode 100644 index 0000000..9daf141 --- /dev/null +++ b/src/combine.jl @@ -0,0 +1,194 @@ +# quickish method to combine terms in an expression + +# ax + bx -> (a+b)x +# x^n*x^m -> x^(n+m) +""" + combine(ex) + +Lightly simplify symbolic expressions. + + +## Example + +```@repl combine +julia> using SimpleExpressions + +julia> import SimpleExpressions: combine + +julia> @symbolic x +(x,) + +julia> ex = 1 + x + 2x + 3x +1 + x + (2 * x) + (3 * x) + +julia> combine(ex) +1 + (6 * x) + +julia> ex = 1 + x^2 + 2x^2 + 3x*x + x^4/x +1 + (x ^ 2) + (2 * (x ^ 2)) + (3 * x * x) + ((x ^ 4) / x) + +julia> combine(ex) +1 + (6 * (x ^ 2)) + (x ^ 3) + +``` + +Not exported. + +""" +function combine(ex) + c, d = ATERM(ex) + c + sum(isone(k) ? v : k*v for (v,k) ∈ d if !iszero(k); init=0) +end + +## ---- experimental +## SymEngine uses this structure to add +## c + (c₁,T₁) + (c₂,T₂) + ⋯ +## uses a dict to store Tᵢ => cᵢ +## TERM should have + or * types (powers or coefficients?) +struct Term + constant + terms +end + +function Base.iterate(t::Term, state=nothing) + isnothing(state) && return t.constant, 1 + state == 1 && return t.terms, 2 + nothing +end + +# A term c + k*v +ATERM(ex::SymbolicNumber, d=IdDict()) = Term(ex, d) +function ATERM(x::𝑉, d=IdDict()) + d[x] = get(d, x, 0) + 1 + Term(0, d) +end + +ATERM(x::SymbolicExpression, d=IdDict()) = ATERM(operation(x), x, d) + +function ATERM(::typeof(*), x, d=IdDict()) + c,dx = MTERM(x, IdDict()) + x′ = prod(isone(k) ? v : v^k for (v,k) ∈ dx if !iszero(k); init=1) + d[x′] = get(d, x′, 0) + c + Term(0, d) +end + + +function ATERM(::typeof(/), x, d=IdDict()) + c, dx = MTERM(x) + e = c * prod(isone(k) ? v : v^k for (v,k) ∈ dx if !iszero(k); init=1) + d[e] = get(d, e, 0) + 1 + return Term(0,d) + + + a, b = arguments(x) + ac, ad = ATERM(a) + bc, bd = ATERM(b) + c = iszero(bc) ? ac : ac / bc + denom = prod(v*k for (v,k) ∈ bd; init=1) + if isone(denom) + Term(c, copy(av)) + else + d = IdDict() + for (v,k) ∈ ad + vv = v/denom + d[vv] = get(d, vv, 0) + k + end + Term(c, d) + end +end + +# (cxyz)^n -> c^n, x^n y^n z^n => (0, (x^n y^n z^n,c^n) +function ATERM(::typeof(^), x, d) + xc, xd = MTERM(x, IdDict()) + v = prod(isone(k) ? v : v^k for (v,k) ∈ xd; init=1) + d[v] = get(d, v, 0) + xc + Term(0,d) +end + +function ATERM(::typeof(-), x,d) + a, b = arguments(x) + TERM(a + (-b)) +end + +function ATERM(::Any, x, d) + d[x] = get(d,x,0) + 1 + Term(0, d) +end + +function ATERM(::typeof(+), x, d) + c = 0 + for a in arguments(x) + ca, d = ATERM(a,d) + c = c + ca + end + Term(c, d) +end + +## --- multiplicative terms simplified + +MTERM(x::SymbolicNumber, d= IdDict()) = Term(x, d) +function MTERM(x::SymbolicVariable, d = IdDict()) + d[x] = get(d, x, 0) + 1 + Term(1, d) +end +function MTERM(x::SymbolicParameter, d=IdDict()) + d[x] = get(d, x, 0) + 1 + Term(1, d) +end +MTERM(x::SymbolicExpression, d=IdDict()) = MTERM(operation(x), x, d) + +function MTERM(::Any, x, d) + d[x] = get(d, x, 0) + 1 + Term(1, d) +end + +function MTERM(::typeof(*), x, d) + cs,ts = tuplesplit(Base.Fix2(isa, SymbolicNumber), sorted_arguments(x)) + c = prod(cs, init=1) + for t ∈ ts + ct,d = MTERM(t, d) + c = c * ct + end + Term(c, d) +end + +function MTERM(::typeof(^), x, d) + a, b = arguments(x) + if iscall(a) + cs,ts = tuplesplit(Base.Fix2(isa, SymbolicNumber), sorted_arguments(a)) + if isnumeric(b) && b() < 0 + c = prod((1/cᵢ)^b for cᵢ in cs; init=1) + else + c = prod(cᵢ^b for cᵢ in cs; init=1) + end + for t ∈ ts + d[t] = get(d,t,0) + b + end + elseif isconstant(a) && isconstant(b) + return Term(a^b, d) + else + c = 1 + d[a] = get(d, a, 0) + b + end + Term(c, d) +end + +function MTERM(::typeof(/), x, d) + a, b = arguments(x) + ac, ad = MTERM(a, d) + if is_operation(*)(b) + bs′ = tuple((SymbolicExpression(^, (b, -1)) for b in arguments(b))...) + b′ = maketerm(SymbolicExpression, *, bs′, nothing) + bc, bd = MTERM(b′,ad) + else + bc, bd′ = MTERM(b, IdDict()) + bd = copy(ad) + for (v,k) ∈ bd′ + bd[v] = get(d,v,0) - k + end + end + c = ac / bc + Term(c, bd) +end + + diff --git a/src/metatheory.jl b/src/metatheory.jl index 036804d..b3e3268 100644 --- a/src/metatheory.jl +++ b/src/metatheory.jl @@ -26,406 +26,3 @@ for fn ∈ (:simplify, :expand, end end -## --------------- - -import Combinatorics: combinations, permutations -using Metatheory - -## ----- predicates ----- -is_literal_number(::SymbolicNumber) = true -is_literal_number(::Number) = true -is_literal_number(::Any) = false - -isminusone(x::Number) = x == -1 -isminusone(x::SymbolicNumber) = isminusone(x()) -isminusone(::AbstractSymbolic) = false - -istwo(x::Number) = x == 2 -istwo(x::SymbolicNumber) = istwo(x()) -istwo(::AbstractSymbolic) = false - -## ---- SymbolicUtils ---- - -#= -Lifted and modified from MIT licensed [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/LICENSE.md) - -[Rules](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rule.jl) are mostly based on rules in that package. -=# - -# recreate @acrule and @ordered_acrule -struct ACRule{F,R} - sets::F - rule::R - arity::Int -end - -macro acrule(expr) - arity = length(expr.args[2].args[2:end]) - quote - ACRule(permutations, $(esc(:(@rule($(expr))))), $arity) - end -end - -macro ordered_acrule(expr) - arity = length(expr.args[2].args[2:end]) - quote - ACRule(combinations, $(esc(:(@rule($(expr))))), $arity) - end -end - -function (acr::ACRule)(term) - r = acr.rule - if !iscall(term) - r(term) - else - f = operation(term) - # # Assume that the matcher was formed by closing over a term - # if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. - # return nothing - # end - - args = arguments(term) - - itr = acr.sets(eachindex(args), acr.arity) - - for inds in itr - result = r(f(args[inds]...)) #Term{T}(f, @views args[inds])) - if result !== nothing - # Assumption: inds are unique - length(args) == length(inds) && return result - return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], nothing) # metadata(term)) - end - end - end -end - -## ---- some predicates used in SymbolicUtils rules -function isnotflat(⋆) - function (x) - args = arguments(x) - for t in args - if is_operation(⋆)(t) - return true - end - end - return false - end -end - -function flatten_term(⋆, x) - args = arguments(x) - # flatten nested ⋆ - flattened_args = [] - for t in args - if is_operation(⋆)(t) - append!(flattened_args, arguments(t)) - else - push!(flattened_args, t) - end - end - maketerm(SymbolicExpression, ⋆, flattened_args, metadata(x)) -end - -hasrepeats(::SimpleExpressions.AbstractSymbolic)= false -function hasrepeats(x′::SimpleExpressions.SymbolicExpression) - x = TermInterface.arguments(x′) - length(x) <= 1 && return false - for i=1:length(x)-1 - if isequal(x[i], x[i+1]) - return true - end - end - return false -end - -_merge_op(::typeof(*), a, b) = a^b -_merge_op(::typeof(+), a, b) = b*a -function merge_repeats(op, xs) - - length(xs) <= 1 && return xs - merged = () - d = Dict{Any, Int}() - for k in xs - cnt = get(d, k, 0) - d[k] = cnt + 1 - end - - return tuple((v==1 ? k : _merge_op(op, k,v) for (k,v) in pairs(d))...) - -end - -function has_trig_exp(term) - !iscall(term) && return false - fns = (sin, cos, tan, cot, sec, csc, exp, cosh, sinh) - op = operation(term) - - if Base.@nany 9 i->fns[i] === op - return true - else - return any(has_trig_exp, arguments(term)) - end -end - - -needs_sorting(f) = x -> is_operation(f)(x) && !issorted(arguments(x)) -needs_sorting₊ = needs_sorting(+) # issue with using rhs as predicate? -needs_sortingₓ = needs_sorting(*) - -function sort_args(f, t) - args = arguments(t) - args = merge_repeats(f, args) # had issue with hasrepeats, slipped in here - if length(args) < 2 - return maketerm(typeof(t), f, args, metadata(t)) - elseif length(args) == 2 - x, y = args - return maketerm(typeof(t), f, x < y ? [x,y] : [y,x], metadata(t)) - end - args = args isa Tuple ? [args...] : args - maketerm(typeof(t), f, TupleTools.sort(args), metadata(t)) -end - -# issue is ~~x -Base.view(t::NTuple{N, AbstractSymbolic}, ind::UnitRange) where {N} = t[ind] - -## rules for simplification -CANONICALIZE_PLUS = [ - -# @rule(~x::isnotflat(+) => flatten_term(+, ~x)) - - @rule(~x::needs_sorting₊ => sort_args(+, ~x)) # also merge - @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b) - #XXX @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) ## JUST WRONG! - @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)) - @acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x)) - # @rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...)) # XXX p_var issue - - @ordered_acrule((~z::iszero + ~x) => ~x) - @rule(+(~x) => ~x) - - #@rule(-(~x) => (-1) * ~x) - @rule(-(~x, ~y) => +(~x, (-1) * ~y)) - - @acrule(~x + ~c::isminusone * ~x => zero(~x)) -] - -CANONICALIZE_TIMES = [ -# @rule(~x::isnotflat(*) => flatten_term(*, ~x)) - @rule(~x::needs_sortingₓ => sort_args(*, ~x)) - - @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b) - # @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)) - - @acrule((~y)^(~n) * ~y => (~y)^(~n+1)) - - @ordered_acrule((~z::isone * ~x) => ~x) - @ordered_acrule((~z::iszero * ~x) => ~z) - - @rule(~x / ~x => one(~x)) - @rule(*(~x,~xs...) / (~x,~ys...) => *(~xs...) / *(~ys...)) - @rule(~x / (~x, ~ys...) => one(~x) / *(~ys...)) - @rule(*(~x,~xs...) / ~x => *(~xs...)) - - @acrule(~x * (~x)^(~c::isone) => one(~x)) - @rule(*(~x) => ~x) -] - -CANONICALIZE_POW = [ - @rule(^(*(~~x), ~y::isinteger) => *(map(a->a^~y, ~~x)...)) - @rule((((~x)^(~p::isinteger))^(~q::isinteger)) => (~x)^((~p)*(~q))) - @rule(^(~x, ~z::iszero) => 1) - @rule(^(~x, ~z::isone) => ~x) - # @rule(inv(~x) => 1/(~x)) -] - -PLUS_DISTRIBUTE = [ - @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)) - @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...)) -] - - -POW_RULES = [ - @rule(^(~x::isone, ~z) => 1) - @ordered_acrule((~x)^(~a) * (~x)^(~b) => (~x)^(~a + ~b)) # always - @ordered_acrule((~x)^(~a) * (~y)^(~a) => (~x*~y)^(~a)) # x,y ≥ 0; a real - # (x^a)^b => x^(a*b) is in canonicalize; b \in Z -] - -ASSORTED_RULES = [ - @rule(identity(~x) => ~x) - @rule((~x::isone) \ ~y => ~y) - @rule(~x \ ~y => ~y / (~x)) - @rule(one(~x) => 1) - @rule(zero(~x) => 0) - @rule(conj(~x::isreal) => ~x) - @rule(real(~x::isreal) => ~x) - @rule(imag(~x::isreal) => 0) - #@rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z) - @rule(ifelse(~x, ~y, ~y) => ~y) -] - -TRIG_RULES = [ - # @acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y)) - # @acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y)) - @acrule(sin(~x)^(~c::istwo) + cos(~x)^(~c::istwo) => one(~x)) - @acrule(sin(~x)^(~c::istwo) - (~c::isone) => -1*cos(~x)^2) - @acrule(cos(~x)^(~c::istwo) - (~c::isone) => -1*sin(~x)^2) - - @acrule(cos(~x)^(~c::istwo) + (~d::isminusone)*sin(~x)^(~c::istwo) => cos(2 * ~x)) - @acrule(sin(~x)^(~c::istwo) + (~d::isminusone)*cos(~x)^(~c::istwo) => -cos(2 * ~x)) - @acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2) - - @acrule(tan(~x)^(~c::istwo) + (~d::isminusone)*sec(~x)^(~c::istwo) => one(~x)) - @acrule((~d::isminusone)*tan(~x)^(~c::istwo) + sec(~x)^(~c::istwo) => one(~x)) - @acrule(tan(~x)^(~c::istwo) + (~d::isone) => sec(~x)^2) - @acrule(sec(~x)^(~c::istwo) + (~d::isminusone) => tan(~x)^2) - - @acrule(cot(~x)^(~c::istwo) + (~d::isminusone)*csc(~x)^(~c::istwo) => one(~x)) - @acrule(cot(~x)^(~c::istwo) + (~d::isone) => csc(~x)^2) - @acrule(csc(~x)^(~c::istwo) + (~d::isminusone) => cot(~x)^2) - - @acrule(cosh(~x)^(~c::istwo) + (~d::isminusone)*sinh(~x)^(~c::istwo) => one(~x)) - @acrule(cosh(~x)^(~c::istwo) + (~d::isminusone) => sinh(~x)^2) - @acrule(sinh(~x)^(~c::istwo) + (~d::isone) => cosh(~x)^2) - - @acrule(cosh(~x)^(~c::istwo) + sinh(~x)^(~c::istwo) => cosh(2 * ~x)) - @acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2) -] - -EXP_RULES = [ - @acrule(exp(~x) * exp(~y) => iszero(~x + ~y) ? 1 : exp(~x + ~y)) - @rule(exp(~x)^(~y) => exp(~x * ~y)) -] - -LOG_RULES = [ - @acrule(log(~x) + log(~y) => log(~x * ~y)), - @acrule(~n * log(~x) => log((~x)^(~n))) -] - -CANONICALIZE = CANONICALIZE_PLUS ∪ CANONICALIZE_TIMES ∪ CANONICALIZE_POW - -## rules for expansion - -_expand_minus = [ - @rule(-(~a + ~b) => -~a + -~b) - @rule((~c::isone) * (+(~~xs...)) => +(~~xs...)) - @rule((~c::isminusone) * (+(~~xs...)) => +((-).(~~xs)...)) - @rule(~a - ~a => zero(~a)) - @rule(-~a + ~a => zero(~a)) -] - - - -_expand_distributive = [ - @rule(~z*(~x + ~y) => ~z*~x + ~z*~y) - @rule((~x + ~y) * ~z => ~z*~x + ~z*~y) - @rule(~z * (+(~~xs...)) => sum(~z*x for x in ~~xs)) - @rule(+(~~xs...) * ~z => sum(~z*x for x in ~~xs)) - - @rule(~z*(~x - ~y) => ~z*~x - ~z*~y) - @rule((~x - ~y) * ~z => ~z*~x - ~z*~y) -] - -_expand_binom = [ - @rule((~x + ~y)^(~c::isone) => ~x + ~y) - @rule((~x + ~y)^(~c::istwo) => (~x)^2 + 2*~x*~y + (~y)^2) - @rule((~x + ~y)^(~n::isinteger) => sum(binomial(Int(~n), k) * (~x)^k * (~y)^((~n)-k) for k in 0:Int(~n))) -] - - -_expand_trig = [ - @rule(sin(~c::istwo * ~a) => 2sin(~a)*cos(~a)) - @rule sin(~a + ~b) => sin(~a)*cos(~b) + cos(~a)*sin(~b) - @rule cos(~c::istwo * ~a) => cos(~a)^2 - sin(~a)^2 - @rule cos(~a + ~b) => cos(~a)*cos(~b) - sin(~a)*sin(~b) - @rule sec(~a) => 1 / cos(~a) - @rule csc(~a) => 1 / sin(~a) - @rule tan(~a) => sin(~a)/cos(~a) - @rule cot(~a) => cos(~a)/sin(~a) -] - - -_expand_power = [ - @rule (~x)^(~a+~b) => (~x)^(~a) * (~x)^(~b) - @rule((~x*~y)^(~a) => (~x)^(~a) * (~y)^(~a)) -] - -_expand_log = [ - @rule(log(~x*~y) => log(~x) + log(~y)) - @rule log((~x)^(~n)) => ~n * log(~x) -] - -_expand_misc = [ - @rule( -(~a) => (-1)*~a) - @rule(((~c)::isone/~a) * ~a => one(~a)) - @rule(~a * ((~c)::isone/~a) => one(~a)) - #@rule(/(~a,~b) => *(~a, (~b)^(-1))) -] - - -# make methods for expressions -function simplify(ex::SymbolicExpression) - ex = rewrite(ex, CANONICALIZE) - theories = (PLUS_DISTRIBUTE, - POW_RULES, - ASSORTED_RULES, - TRIG_RULES, - EXP_RULES, - LOG_RULES - ) - ex = rewrite(ex, reduce(∪, theories)) - ex = rewrite(ex, CANONICALIZE) -end - -function expand(ex::SymbolicExpression) - theories = ( - _expand_minus, - _expand_distributive, _expand_binom, _expand_trig, - _expand_power, _expand_log, - _expand_misc - ) - ex = rewrite(ex, reduce(∪,theories)) -end - - -# -canonicalize(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE) -powsimp(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE ∪ POW_RULES ∪ EXP_RULES) -trigsimp(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE ∪ TRIG_RULES) -logcombine(ex::SymbolicExpression) = rewrite(ex, CANONICALIZE ∪ LOG_RULES) - -expand_trig(ex::SymbolicExpression) = rewrite(ex, _expand_trig) -expand_power_exp(ex::SymbolicExpression) = rewrite(ex, _expand_power) -expand_log(ex::SymbolicExpression) = rewrite(ex, _expand_log) - - -# function to run quickly and make terms more nice -# used by show -function _canon(x) - rules = [ - # sort - @rule(~x::needs_sorting₊ => sort_args(+, ~x)) # also merge - @rule(~x::needs_sortingₓ => sort_args(*, ~x)) - - # combine terms - @rule(~x + ~x => 2x) - @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)) - @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) - @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)) - - # additive identity - @acrule(~z::iszero + ~x => ~x) - @acrule(+(~z::iszero, ~~xs...) => +(~~xs...)) - - # multiplicative zero - @ordered_acrule(~z::iszero * ~x => zero(~x)) - @ordered_acrule(*(~z::iszero, ~~xs...) => zero(~z)) - - # multiplicative identity - @acrule(~z::isone * ~x => ~x) - @acrule(*(~z::isone, ~~xs...) => *(~~xs...)) - - - ] - - rewrite(x, rules) -end diff --git a/src/ops.jl b/src/ops.jl index 6a5766e..a32cbe2 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -6,7 +6,7 @@ import SimpleExpressions: AbstractSymbolic, SymbolicExpression, D import SimpleExpressions.CallableExpressions: StaticExpression function Base.fourthroot(x::AbstractSymbolic) u = StaticExpression((x,), fourthroot) - SymbolicExpression(u) + SymbolicExpression(u) end D(::typeof(fourthroot), args,x) = (𝑥 = only(args); D(𝑥,x) * fourthroot(x)^3 / 4) @@ -14,7 +14,7 @@ fourthroot(x^2 + 2) ``` =# ## ---- operations -for op ∈ (://, :^, :≈) +for op ∈ (://, :≈) @eval begin import Base: $op Base.$op(x::AbstractSymbolic, y::AbstractSymbolic) = @@ -24,35 +24,68 @@ for op ∈ (://, :^, :≈) end end -for op ∈ (:/, ) +# commutative ops +for op ∈ (:*, :+) @eval begin import Base: $op - Base.$op(x::AbstractSymbolic, y::AbstractSymbolic) = - SymbolicExpression(StaticExpression((↓(x), ↓(y)), $op)) - Base.$op(x::AbstractSymbolic, y::Number) = _isunit(*,y) ? x : $op(promote(x,y)...) + Base.$op(x::SymbolicNumber, y::SymbolicNumber) =$op(x(), y()) + Base.$op(x::AbstractSymbolic, y::Number) = $op(promote(x,y)...) Base.$op(x::Number, y::AbstractSymbolic) = $op(promote(x,y)...) end end -## arrange for *, + to be n-ary -_isunit(::typeof(+), y::Number) = iszero(y) -_isunit(::typeof(*), y::Number) = isone(y) - -for op ∈ (:*, :+) +for op ∈ (:/, :^) @eval begin import Base: $op - Base.$op(x::AbstractSymbolic, y::AbstractSymbolic) = - SymbolicExpression(StaticExpression(TupleTools.vcat(_arguments($op,x), _arguments($op,y)), $op)) - Base.$op(x::AbstractSymbolic, y::Number) = _isunit($op, y) ? x : $op(promote(x,y)...) - Base.$op(x::Number, y::AbstractSymbolic) = _isunit($op, x) ? y : $op(promote(x,y)...) + Base.$op(x::AbstractSymbolic, y::Number) = $op(promote(x,y)...) + Base.$op(x::Number, y::AbstractSymbolic) = $op(promote(x,y)...) + end end - Base.:-(x::AbstractSymbolic, y::AbstractSymbolic) = x + (-1)*y Base.:-(x::AbstractSymbolic, y::Number) = x + (-1)*y Base.:-(x::Number, y::AbstractSymbolic) = x + (-1)*y +# do some light simplification on construction +# ADD +function Base.:+(x::AbstractSymbolic, y::AbstractSymbolic) + iszero(x) && return y + iszero(y) && return x + as = TupleTools.vcat(_arguments(+,x), _arguments(+,y)) + SymbolicExpression(StaticExpression(as, +)) +end + +# MUL +function Base.:*(x::AbstractSymbolic, y::AbstractSymbolic) + isone(x) && return y + isone(y) && return x + iszero(x) && return zero(x) + iszero(y) && return zero(y) + as = TupleTools.vcat(_arguments(*,x), _arguments(*,y)) + SymbolicExpression(StaticExpression(as, *)) +end + +# DIV +function Base.:/(x::AbstractSymbolic, y::AbstractSymbolic) + x == y && return one(x) + isone(y) && return x + iszero(x) && return zero(x) + !isinf(x) && isinf(y) && return zero(x) + cs = (↓(x), ↓(y)) + SymbolicExpression(StaticExpression(cs, /)) +end + +## POW +function Base.:^(x::AbstractSymbolic, y::AbstractSymbolic) + iszero(y) && return one(x) # 0^0 is 1 + iszero(x) && return zero(y) + isone(x) && return x + + cs = (↓(x), ↓(y)) + SymbolicExpression(StaticExpression(cs, ^)) +end + 𝑄 = Union{Integer, Rational} for op ∈ (:+, :-, :*, :^) @@ -65,11 +98,14 @@ end for op ∈ (:/, ://) @eval begin - Base.$op(x::SymbolicNumber{DynamicConstant{T}}, + function Base.$op(x::SymbolicNumber{DynamicConstant{T}}, y::SymbolicNumber{DynamicConstant{S}}) where { - T<:𝑄, S<:𝑄} = - SymbolicNumber(x()//y()) + T<:𝑄, S<:𝑄} + u,v = x(), y() + isone(v) && return SymbolicNumber(u) + SymbolicNumber(u//v) end + end end for op ∈ (:+, :-, :*, :^, :/ ) @@ -232,13 +268,14 @@ for op ∈ (:zip, :getindex,) end end -## special cases +## ---- special cases +## log Base.log(a::Number, x::AbstractSymbolic) = log(x) / log(symbolicnumber(a)) function Base.broadcasted(::typeof(log), a, b::AbstractSymbolic) SymbolicExpression(Base.broadcasted, (log, a, b)) end - +## ---- powers Base.inv(a::AbstractSymbolic) = SymbolicExpression(inv, (a,)) Base.inv(a::SymbolicExpression) = _inv(operation(a), a) _inv(::typeof(inv), a) = only(arguments(a)) @@ -249,6 +286,7 @@ function _inv(::typeof(^), a) end _inv(::Any, a) = SymbolicExpression(inv, (a,)) +## ---- literal_pow ## handle integer powers Base.literal_pow(::typeof(^), x::AbstractSymbolic, ::Val{0}) = one(x) Base.literal_pow(::typeof(^), x::AbstractSymbolic, ::Val{1}) = x @@ -261,74 +299,9 @@ function Base.literal_pow(::typeof(^), x::AbstractSymbolic, ::Val{p}) where {p} u = SymbolicExpression(^, (x, p′)) p < 0 ? 1 / u : u end -# broadcast + function Base.broadcasted(::typeof(Base.literal_pow), u, a::AbstractSymbolic, p::Val{N}) where {N} SymbolicExpression(Base.broadcasted, (^, a,N)) end - -# simplifying operations -# XXX These are really in need of removal -## plus -⊕(x::SymbolicNumber,y::SymbolicNumber) = SymbolicNumber(x() + y()) -function ⊕(x,y) - iszero(x) && return y - iszero(y) && return x - return x + y -end - -## minus -⊖(x::SymbolicNumber,y::SymbolicNumber) = SymbolicNumber(x() - y()) -function ⊖(x,y) - iszero(x) && return -y - iszero(y) && return x - return x - y -end - - -## times -⊗(x::SymbolicNumber,y::SymbolicNumber) = SymbolicNumber(x() * y()) -function ⊗(x,y) - isone(x) && return y - isone(y) && return x - iszero(x) && return zero(x) - iszero(y) && return zero(y) - return x * y -end - -## div -function ⨸(x::SymbolicNumber,y::SymbolicNumber) - n, d = x(), y() - # keep as rational? - isa(n, Integer) && isa(d, Integer) && return SymbolicNumber(n // d) - SymbolicNumber(x() / y()) -end - -function ⨸(x,y) - x == y && return one(x) - isone(y) && return x - iszero(x) && return zero(x) - !isinf(x) && isinf(y) && return zero(x) - - - - # can cancel? - if is_operation(/)(y) - a, b = arguments(y) - return (x ⊗ b) ⨸ a - end - - if is_operation(*)(x) - if contains(x, y) # cancel y in x; return - out = one(x) - for c ∈ sorted_arguments(x) - c == y && continue - out = out ⊗ c - end - return out - end - end - - return x / y -end diff --git a/src/scalar-derivative.jl b/src/scalar-derivative.jl index 74b8a5e..a4f3896 100644 --- a/src/scalar-derivative.jl +++ b/src/scalar-derivative.jl @@ -1,3 +1,13 @@ +## could name this diff(ex, x) +## Maple diff +## Matlab diff +## Mathematica derivative +## SymPy diff +## Sage derivative +## Polynomials derivative +## Symbolics Differential Differential(x) = Base.Fix2(SimpleExpressions.D,x) +## Symbolics derivative + """ D(::AbstractSymbolic, [x]) @@ -28,6 +38,8 @@ D(𝑥::SymbolicVariable, x) = 𝑥 == x ? 1 : 0 D(𝑥::SymbolicParameter, x) = 𝑥 == x ? 1 : 0 D(ex::SymbolicEquation, x) = D(ex.lhs, x) ~ D(ex.rhs, x) +# combine slows this down +#D(ex::SymbolicExpression, x) = combine(D(operation(ex), arguments(ex), x)) D(ex::SymbolicExpression, x) = D(operation(ex), arguments(ex), x) @@ -51,13 +63,13 @@ D(ex::SymbolicEquation) = D(ex.lhs) ~ D(ex.rhs) # cases ## sum rule function D(::typeof(+), args, x) - reduce(⊕, D.(args, x); init=zero(x)) + reduce(+, D.(args, x); init=zero(x)) end D(::typeof(sum), args, x) = SymbolicExpression(+, D.(args), x) ## difference rule function D(::typeof(-), args, x) - return reduce(⊖, D.(args, x); init=zero(x)) + return reduce(-, D.(args, x); init=zero(x)) end ## product rule @@ -70,7 +82,7 @@ function D(::typeof(*), args, x) aa[i] = a end aa[i] = ai′ - tot = tot ⊕ reduce(⊗, aa) + tot = tot + reduce(*, aa) end return tot end @@ -80,7 +92,7 @@ D(::typeof(prod), args, x) = D(SymbolicExpression(*, args), x) function D(::typeof(/), args, x) u,v = args u′, v′ = D(u,x), D(v,x) - ((u′ ⊗ v) ⊖ (u ⊗ v′)) ⨸ (v⊗v) + ((u′ * v) - (u * v′)) / (v*v) end ## powers @@ -89,9 +101,9 @@ function D(::typeof(^), args,x) if !contains(b, x) iszero(b) && return zero(x) - isone(b) && return D(a,x) ⊗ a - isone(b-1) && return D(a,x) ⊗ (2*a) - return D(a,x) ⊗ (b*a^(b()-1)) + isone(b) && return D(a,x) * a + isone(b-1) && return D(a,x) * (2*a) + return D(a,x) * (b*a^(b()-1)) else return D(exp(b * log(a)),x) end @@ -103,46 +115,46 @@ end D(::typeof(sqrt), args,x) = (𝑥 = only(args); D(𝑥,x) / sqrt(𝑥) * (1//2)) D(::typeof(cbrt), args,x) = (𝑥 = only(args); D(𝑥,x) / cbrt(𝑥)^2 * (1//3)) -D(::typeof(inv), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -1/𝑥^2 ⊗ 𝕀(Ne(𝑥,0))) -D(::typeof(abs), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ sign(𝑥) ⊗ 𝕀(Ne(𝑥, 0))) -D(::typeof(sign), args,x) = (𝑥 = only(args); 0 ⊗ 𝕀(𝑥 != 0)) -D(::typeof(abs2), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ 2𝑥) -D(::typeof(deg2rad), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ (pi / 180)) -D(::typeof(rad2deg), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ (180 / pi)) - -D(::typeof(exp), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ exp(𝑥)) -D(::typeof(exp2), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ exp2(𝑥) ⊗ log(2)) -D(::typeof(exp10), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ exp10(𝑥) ⊗ log(10)) -D(::typeof(expm1), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ exp(𝑥)) -D(::typeof(log), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ (1/𝑥) ⊗ 𝕀(Ge(𝑥,0))) -D(::typeof(log2), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ (1/𝑥/log(2)) ⊗ 𝕀(Ge(𝑥, 0))) -D(::typeof(log10), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ (1/𝑥/log(10)) ⊗ 𝕀(Ge(𝑥, 0))) -D(::typeof(log1p), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ 1/(1 + 𝑥)) - - -D(::typeof(sin), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ cos(𝑥)) -D(::typeof(cos), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -sin(𝑥)) -D(::typeof(tan), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ sec(𝑥)^2) -D(::typeof(sec), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ sec(𝑥) ⊗ tan(𝑥)) -D(::typeof(csc), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -csc(𝑥) ⊗ cot(𝑥)) -D(::typeof(cot), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -csc(𝑥)^2) +D(::typeof(inv), args,x) = (𝑥 = only(args); D(𝑥,x) * -1/𝑥^2 * 𝕀(Ne(𝑥,0))) +D(::typeof(abs), args,x) = (𝑥 = only(args); D(𝑥,x) * sign(𝑥) * 𝕀(Ne(𝑥, 0))) +D(::typeof(sign), args,x) = (𝑥 = only(args); 0 * 𝕀(𝑥 != 0)) +D(::typeof(abs2), args,x) = (𝑥 = only(args); D(𝑥,x) * 2𝑥) +D(::typeof(deg2rad), args,x) = (𝑥 = only(args); D(𝑥,x) * (pi / 180)) +D(::typeof(rad2deg), args,x) = (𝑥 = only(args); D(𝑥,x) * (180 / pi)) + +D(::typeof(exp), args,x) = (𝑥 = only(args); D(𝑥,x) * exp(𝑥)) +D(::typeof(exp2), args,x) = (𝑥 = only(args); D(𝑥,x) * exp2(𝑥) * log(2)) +D(::typeof(exp10), args,x) = (𝑥 = only(args); D(𝑥,x) * exp10(𝑥) * log(10)) +D(::typeof(expm1), args,x) = (𝑥 = only(args); D(𝑥,x) * exp(𝑥)) +D(::typeof(log), args,x) = (𝑥 = only(args); D(𝑥,x) * (1/𝑥) * 𝕀(Ge(𝑥,0))) +D(::typeof(log2), args,x) = (𝑥 = only(args); D(𝑥,x) * (1/𝑥/log(2)) * 𝕀(Ge(𝑥, 0))) +D(::typeof(log10), args,x) = (𝑥 = only(args); D(𝑥,x) * (1/𝑥/log(10)) * 𝕀(Ge(𝑥, 0))) +D(::typeof(log1p), args,x) = (𝑥 = only(args); D(𝑥,x) * 1/(1 + 𝑥)) + + +D(::typeof(sin), args,x) = (𝑥 = only(args); D(𝑥,x) * cos(𝑥)) +D(::typeof(cos), args,x) = (𝑥 = only(args); D(𝑥,x) * -sin(𝑥)) +D(::typeof(tan), args,x) = (𝑥 = only(args); D(𝑥,x) * sec(𝑥)^2) +D(::typeof(sec), args,x) = (𝑥 = only(args); D(𝑥,x) * sec(𝑥) * tan(𝑥)) +D(::typeof(csc), args,x) = (𝑥 = only(args); D(𝑥,x) * -csc(𝑥) * cot(𝑥)) +D(::typeof(cot), args,x) = (𝑥 = only(args); D(𝑥,x) * -csc(𝑥)^2) D(::typeof(asin), args,x) = (𝑥 = only(args); D(𝑥,x) / sqrt(1 - 𝑥^2)) D(::typeof(acos), args,x) = (𝑥 = only(args); D(𝑥,x) / (-sqrt(1 - 𝑥^2))) D(::typeof(atan), args,x) = (𝑥 = only(args); D(𝑥,x) / (1 + 𝑥^2)) -D(::typeof(asec), args,x) = (𝑥 = only(args); D(𝑥,x) / (abs(𝑥) ⊗ sqrt(𝑥^2 - 1))) -D(::typeof(acsc), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ (abs(𝑥) ⊗ sqrt(𝑥^2 - 1)) ⊗ (-1)) -D(::typeof(acot), args,x) = (𝑥 = only(args); D(𝑥,x) / (1 + 𝑥^2) ⊗ (-1)) - -D(::typeof(sinh), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ cosh(𝑥)) -D(::typeof(cosh), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ sinh(𝑥)) -D(::typeof(tanh), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ sech(𝑥)^2) -D(::typeof(sech), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -sech(𝑥) ⊗ tanh(𝑥)) -D(::typeof(csch), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -csch(𝑥) ⊗ coth(𝑥)) -D(::typeof(coth), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -csch(𝑥)^2) - -D(::typeof(sinpi), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ π ⊗ cospi(𝑥)) -D(::typeof(cospi), args,x) = (𝑥 = only(args); D(𝑥,x) ⊗ -π ⊗ sinpi(𝑥)) +D(::typeof(asec), args,x) = (𝑥 = only(args); D(𝑥,x) / (abs(𝑥) * sqrt(𝑥^2 - 1))) +D(::typeof(acsc), args,x) = (𝑥 = only(args); D(𝑥,x) * (abs(𝑥) * sqrt(𝑥^2 - 1)) * (-1)) +D(::typeof(acot), args,x) = (𝑥 = only(args); D(𝑥,x) / (1 + 𝑥^2) * (-1)) + +D(::typeof(sinh), args,x) = (𝑥 = only(args); D(𝑥,x) * cosh(𝑥)) +D(::typeof(cosh), args,x) = (𝑥 = only(args); D(𝑥,x) * sinh(𝑥)) +D(::typeof(tanh), args,x) = (𝑥 = only(args); D(𝑥,x) * sech(𝑥)^2) +D(::typeof(sech), args,x) = (𝑥 = only(args); D(𝑥,x) * -sech(𝑥) * tanh(𝑥)) +D(::typeof(csch), args,x) = (𝑥 = only(args); D(𝑥,x) * -csch(𝑥) * coth(𝑥)) +D(::typeof(coth), args,x) = (𝑥 = only(args); D(𝑥,x) * -csch(𝑥)^2) + +D(::typeof(sinpi), args,x) = (𝑥 = only(args); D(𝑥,x) * π * cospi(𝑥)) +D(::typeof(cospi), args,x) = (𝑥 = only(args); D(𝑥,x) * -π * sinpi(𝑥)) ## more in SpecialFunctions.jl extension diff --git a/src/simplify.jl b/src/simplify.jl index e69de29..61e58d5 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -0,0 +1,27 @@ +## ----- Interface +""" + simplify(ex) + +Simplify expression using `Metatheory.jl` and rules on loan from `SymbolicUtils.jl`. +""" +function simplify() +end + +""" + expand(ex) + +Expand terms in an expression using `Metatheory.jl` +""" +function expand() +end + +# some default definitions +# we extend to SymbolicExpression in the Metatheroy extension +for fn ∈ (:simplify, :expand, + :canonicalize, :powsimp, :trigsimp, :logcombine, + :expand_trig, :expand_power_exp, :expand_log) + @eval begin + $fn(ex::AbstractSymbolic) = ex + $fn(eq::SymbolicEquation) = SymbolicEquation($fn.(eq)...) + end +end diff --git a/src/solve.jl b/src/solve.jl index 1305c9e..74093d5 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -63,7 +63,7 @@ function _solve(l, r, x::𝑉) if contains(l, x) l, r = isolate_x(Val(:→), l, r, x) else - l, r = zero(l), r ⊖ l + l, r = zero(l), r - l end l == l′ && return _final_solve(l, r, x) _solve(l, r, x) # recurse @@ -83,11 +83,11 @@ function _final_solve(l,r,x) return l ~ r elseif length(cs) == 2 a0,a1 = cs - return x ~ _combine_numbers((r ⊖ a0) ⨸ a1) + return x ~ _combine_numbers((r - a0) / a1) end p = sum(aᵢ * x^i for (i, aᵢ) ∈ enumerate(Iterators.rest(cs,2))) # could solve, but ... - return p ~ _combine_numbers(r ⊖ first(cs)) + return p ~ _combine_numbers(r - first(cs)) end l ~ _combine_numbers(r) end @@ -148,7 +148,7 @@ function coefficients(ex, x) d = Dict{Any, Any}() for c in cs (aᵢ, i) = _monomial(c, x) - d[i] = aᵢ ⊕ get(d, i, zero(x)) + d[i] = aᵢ + get(d, i, zero(x)) end n = maximum(collect(keys(d))) @@ -179,7 +179,7 @@ function _monomial(c, x) if is_operation(*)(c) ps = _monomial.(arguments(c), x) - aᵢ = reduce(⊗, first.(ps), init=one(x)) + aᵢ = reduce(*, first.(ps), init=one(x)) i = sum(last.(ps)) return (aᵢ, i) @@ -207,7 +207,7 @@ end # a*(b+c) --> a*b + a*c (flatten?) # work of distribute_over_plus is op by op function _distribute_over_plus(::typeof(+), ex, x) - reduce(⊕, _distribute_over_plus.(sorted_arguments(ex), x), init=zero(x)) + reduce(+, _distribute_over_plus.(sorted_arguments(ex), x), init=zero(x)) end function _distribute_over_plus(::typeof(*), ex, x) @@ -218,22 +218,22 @@ function _distribute_over_plus(::typeof(*), ex, x) b = c continue else - a = a ⊗ _distribute_over_plus(c, x) + a = a * _distribute_over_plus(c, x) end end isnothing(b) && return a - return mapreduce(Base.Fix1(⊗, a), ⊕, sorted_arguments(b), init=zero(x)) + return mapreduce(Base.Fix1(*, a), +, sorted_arguments(b), init=zero(x)) end function _distribute_over_plus(::typeof(-), ex, x) - reduce(⊖, _distribute_over_plus.(arguments(ex), x), init=zero(x)) + reduce(-, _distribute_over_plus.(arguments(ex), x), init=zero(x)) end function _distribute_over_plus(::typeof(/), ex, x) a, b = arguments(ex) contains(b, x) && return ex - a ⊗ (1 / b) + a * (1 / b) end function _distribute_over_plus(::typeof(^), ex, x) @@ -246,7 +246,7 @@ function _distribute_over_plus(::typeof(^), ex, x) n < 0 && return ex l = one(x) for i in 1:n - l = l ⊗ a + l = l * a end return l end @@ -261,12 +261,12 @@ _combine_numbers(ex) = _combine_numbers(operation(ex), ex) function _combine_numbers(::typeof(+), ex) args = _combine_numbers.(sorted_arguments(ex)) - foldl(⊕, args, init=zero(ex)) + foldl(+, args, init=zero(ex)) end function _combine_numbers(::typeof(*), ex) args = _combine_numbers.(sorted_arguments(ex)) - foldl(⊗, args, init=one(ex)) + foldl(*, args, init=one(ex)) end function _combine_numbers(::Any, ex) @@ -282,7 +282,7 @@ end function isolate_x(::Val{:←}, l, r::𝑉, x) if r == x - l = l ⨸ r + l = l / r r = one(x) end l, r @@ -299,17 +299,17 @@ function isolate_x(::Val{:→}, ::typeof(/), l, r, x) if contains(a, x) l′ = a else - r = r ⨸ a + r = r / a end if contains(b, x) if !contains(l′, x) # take reciprocal - l′ = b ⨸ l′ - r = one(x) ⨸ r + l′ = b / l′ + r = one(x) / r else - l′ = l′ ⨸ b + l′ = l′ / b end else - r = r ⊗ b + r = r * b end l′, r @@ -325,7 +325,7 @@ function isolate_x(::Val{:←}, ::typeof(/), l, r, x) end if contains(b, x) - l = l ⊗ b + l = l * b else r′ = r′ / b end @@ -339,15 +339,15 @@ function isolate_x(::Val{:→}, ::typeof(-), l, r, x) a, b, = arguments(l) l′ = zero(l) if !contains(a, x) - r = r ⊖ a + r = r - a else l′ = a end if !contains(b, x) - r = r ⊕ b + r = r + b else - l′ = l′ ⊖ b + l′ = l′ - b end l, r′ @@ -357,15 +357,15 @@ function isolate_x(::Val{:←}, ::typeof(-), l, r, x) a, b, = arguments(r) r′ = zero(r) if contains(a, x) - l = l ⊖ a + l = l - a else r′ = a end if contains(b, x) - l = l ⊕ b + l = l + b else - r′ = r′ ⊖ b + r′ = r′ - b end l, r′ @@ -383,7 +383,7 @@ function isolate_x(::Val{:→}, ::typeof(^), l, r, x) bb == 2 && return a, sqrt(r) bb == 3 && return a, cbrt(r) end - l,r = a, r^(one(x) ⨸ b) + l,r = a, r^(one(x) / b) end return l, r end @@ -410,9 +410,9 @@ function isolate_x(::Val{:→}, ::typeof(+), l, r, x) l′ = zero(l) for c ∈ arguments(l) if contains(c, x) - l′ = l′ ⊕ c + l′ = l′ + c else - r = r ⊖ c + r = r - c end end return l′, r @@ -423,9 +423,9 @@ function isolate_x(::Val{:←}, ::typeof(+), l, r, x) r′ = zero(r) for c ∈ arguments(r) if contains(c, x) - l = l ⊖ c + l = l - c else - r′ = r′ ⊕ c + r′ = r′ + c end end l, r′ @@ -437,9 +437,9 @@ function isolate_x(::Val{:→}, ::typeof(*), l, r, x) l′ = one(l) for c ∈ arguments(l) if contains(c, x) - l′ = l′ ⊗ c + l′ = l′ * c else - r = r ⨸ c + r = r / c end end l′, r @@ -449,9 +449,9 @@ function isolate_x(::Val{:←}, ::typeof(*), l, r, x) r′ = one(r) for c ∈ arguments(r) if contains(c, x) - l = l ⨸ c + l = l / c else - r′ = r′ ⊗ c + r′ = r′ * c end end l, r′ diff --git a/test/basic_tests.jl b/test/basic_tests.jl index dc8e8c1..93da67f 100644 --- a/test/basic_tests.jl +++ b/test/basic_tests.jl @@ -1,6 +1,6 @@ # basics import SimpleExpressions.TermInterface: arguments, sorted_arguments -import SimpleExpressions: D, solve, coefficients +import SimpleExpressions: D, solve, coefficients, combine import SimpleExpressions: map_matched @testset "SimpleExpressions.jl" begin @@ -149,7 +149,6 @@ end - @testset "show" begin # test show # note *,+ do **not** do light simplification and sort arguments @@ -168,6 +167,35 @@ end end +@testset "combine" begin + # simplish simplification + @symbolic x + + ex = 2x + x + @test combine(ex) == 3x + @test combine(ex + 2x) == 5x + + ex = x * x^2 * x^3 + @test combine(ex) == x^6 + + ex = x * cos(x) * x^2 * x^3 + u = combine(ex) + @test u ∈ (x^6 * cos(x), cos(x)*x^6) + + ex = x + 2x + x*x*x + u = combine(ex) + @test u ∈ (3x + x^3, x^3 + 3x) + + ex = sum(n + n*x + n^2*x^2 for n in 1:5) + u = combine(ex) + @test coefficients(u, x) == (a₀ = 15, a₁ = 15, a₂ = 55) + + ex = sum(n + n*x + (n*x)^2 for n in 1:5) + u = combine(ex) + @test coefficients(u, x) == (a₀ = 15, a₁ = 15, a₂ = 55) + +end + @testset "broadcast/generators" begin @symbolic x p