Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: properly handle rational functions in HomotopyContinuation #3265

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ function MTK.HomotopyContinuationProblem(
return prob
end

function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...)
function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing;
fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...)
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
end
transformation = MTK.PolynomialTransformation(sys)
if transformation isa MTK.NotPolynomialError
return transformation
end
result = MTK.transform_system(sys, transformation)
result = MTK.transform_system(sys, transformation; fraction_cancel_fn)
if result isa MTK.NotPolynomialError
return result
end
Expand Down
63 changes: 44 additions & 19 deletions src/systems/nonlinear/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ Transform the system `sys` with `transformation` and return a
`PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot
be transformed.
"""
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation)
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation;
fraction_cancel_fn = simplify_fractions)
subrules = transformation.substitution_rules
dvs = unknowns(sys)
eqs = full_equations(sys)
Expand All @@ -463,7 +464,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf
return NotPolynomialError(
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata)
end
num, den = handle_rational_polynomials(t, new_dvs)
num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn)
# make factors different elements, otherwise the nonzero factors artificially
# inflate the error of the zero factor.
if iscall(den) && operation(den) == *
Expand Down Expand Up @@ -492,43 +493,67 @@ $(TYPEDSIGNATURES)
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
express `x` as a single rational function with polynomial `num` and denominator `den`.
Return `(num, den)`.

Keyword arguments:
- `fraction_cancel_fn`: A function which takes a fraction (`operation(expr) == /`) and returns
a simplified symbolic quantity with common factors in the numerator and denominator are
cancelled. Defaults to `SymbolicUtils.simplify_fractions`, but can be changed to
`nothing` to improve performance on large polynomials at the cost of avoiding non-trivial
cancellation.
"""
function handle_rational_polynomials(x, wrt)
function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fractions)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return x, 1
iscall(x) || return x, 1
contains_variable(x, wrt) || return x, 1
any(isequal(x), wrt) && return x, 1

# simplify_fractions cancels out some common factors
# and expands (a / b)^c to a^c / b^c, so we only need
# to handle these cases
x = simplify_fractions(x)
op = operation(x)
args = arguments(x)

if op == /
# numerator and denominator are trivial
num, den = args
# but also search for rational functions in numerator
n, d = handle_rational_polynomials(num, wrt)
num, den = n, den * d
elseif op == +
n1, d1 = handle_rational_polynomials(num, wrt; fraction_cancel_fn)
n2, d2 = handle_rational_polynomials(den, wrt; fraction_cancel_fn)
num, den = n1 * d2, d1 * n2
elseif (op == +) || (op == -)
num = 0
den = 1

# we don't need to do common denominator
# because we don't care about cases where denominator
# is zero. The expression is zero when all the numerators
# are zero.
if op == -
args[2] = -args[2]
end
for arg in args
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
num = num * d + n * den
den *= d
end
elseif op == ^
base, pow = args
num, den = handle_rational_polynomials(base, wrt; fraction_cancel_fn)
num ^= pow
den ^= pow
elseif op == *
num = 1
den = 1
for arg in args
n, d = handle_rational_polynomials(arg, wrt)
num += n
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
num *= n
den *= d
end
else
return x, 1
error("Unhandled operation in `handle_rational_polynomials`. This should never happen. Please open an issue in ModelingToolkit.jl with an MWE.")
end

if fraction_cancel_fn !== nothing
expr = fraction_cancel_fn(num / den)
if iscall(expr) && operation(expr) == /
num, den = arguments(expr)
else
num, den = expr, 1
end
end

# if the denominator isn't a polynomial in `wrt`, better to not include it
# to reduce the size of the gcd polynomial
if !contains_variable(den, wrt)
Expand Down
11 changes: 7 additions & 4 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,16 @@ end

function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
parammap = DiffEqBase.NullParameters();
check_length = true, use_homotopy_continuation = true, kwargs...) where {iip}
check_length = true, use_homotopy_continuation = false, kwargs...) where {iip}
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
end
prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...)
if prob isa HomotopyContinuationProblem
return prob
if use_homotopy_continuation
prob = safe_HomotopyContinuationProblem(
sys, u0map, parammap; check_length, kwargs...)
if prob isa HomotopyContinuationProblem
return prob
end
end
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
Expand Down
33 changes: 31 additions & 2 deletions test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface
using SymbolicUtils
import ModelingToolkit as MTK
using LinearAlgebra
using Test
Expand Down Expand Up @@ -29,11 +30,13 @@ import HomotopyContinuation
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid)≈0.0 atol=1e-10

prob2 = NonlinearProblem(sys, u0)
prob2 = NonlinearProblem(sys, u0; use_homotopy_continuation = true)
@test prob2 isa HomotopyContinuationProblem
sol = solve(prob2; threading = false)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid)≈0.0 atol=1e-10

@test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem
end

struct Wrapper
Expand Down Expand Up @@ -217,7 +220,17 @@ end
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
@test_nowarn solve(prob; threading = false)
@test SciMLBase.successful_retcode(solve(prob; threading = false))
end

@testset "Rational function forced to common denominators" begin
@variables x = 1
@mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([-1.0], parameter_values(prob)) .≈ 0.0)
sol = solve(prob; threading = false)
@test SciMLBase.successful_retcode(sol)
@test 1 / (1 + sol.u[1]) - sol.u[1]≈0.0 atol=1e-10
end
end

Expand All @@ -229,3 +242,19 @@ end
@test sol[x] ≈ √2.0
@test sol[y] ≈ sin(√2.0)
end

@testset "`fraction_cancel_fn`" begin
@variables x = 1
@named sys = NonlinearSystem([0 ~ ((x^2 - 5x + 6) / (x - 2) - 1) * (x^2 - 7x + 12) /
(x - 4)^3])
sys = complete(sys)

@testset "`simplify_fractions`" begin
prob = HomotopyContinuationProblem(sys, [])
@test prob.denominator([0.0], parameter_values(prob)) ≈ [4.0]
end
@testset "`nothing`" begin
prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing)
@test sort(prob.denominator([0.0], parameter_values(prob))) ≈ [2.0, 4.0^3]
end
end
Loading