Skip to content

Commit

Permalink
allow different number of out-of-place args for isinplace
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Aug 30, 2024
1 parent 12f3030 commit 8b10979
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,18 @@ function Base.showerror(io::IO, e::FunctionArgumentsError)
end

"""
isinplace(f, inplace_param_number[,fname="f"])
isinplace(f, inplace_param_number, fname = "f", iip_preferred = true;
has_two_dispatches = true,
outofplace_param_number = inplace_param_number - 1)
isinplace(f::AbstractSciMLFunction[, inplace_param_number])
Check whether a function operates in place by comparing its number of arguments
to the expected number. If `f` is an `AbstractSciMLFunction`, then the type
parameter is assumed to be correct and is used. Otherwise `inplace_param_number`
is checked against the methods table, where `inplace_param_number` is the number
of arguments for the in-place dispatch. The out-of-place dispatch is assumed
to have one less. If neither of these dispatches exist, an error is thrown.
to have `outofplace_param_number` parameters (one less than the inplace version
by default). If neither of these dispatches exist, an error is thrown.
If the error is thrown, `fname` is used to tell the user which function has the
incorrect dispatches.
Expand All @@ -238,22 +241,23 @@ form is disabled and the 2-argument signature is ensured to be matched.
# See also
- [`numargs`](@ref numargs)
- [`numargs`](@ref numargs)
"""
function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true;
has_two_dispatches = true, isoptimization = false)
has_two_dispatches = true, isoptimization = false,
outofplace_param_number = inplace_param_number - 1)
nargs = numargs(f)
iip_dispatch = any(x -> x == inplace_param_number, nargs)
oop_dispatch = any(x -> x == inplace_param_number - 1, nargs)
oop_dispatch = any(x -> x == outofplace_param_number, nargs)

if length(nargs) == 0
throw(NoMethodsError(fname))
end

if !iip_dispatch && !oop_dispatch && !isoptimization
if all(x -> x > inplace_param_number, nargs)
if all(>(inplace_param_number), nargs)
throw(TooManyArgumentsError(fname, f))
elseif all(x -> x < inplace_param_number - 1, nargs) && has_two_dispatches
elseif all(<(outofplace_param_number), nargs) && has_two_dispatches
# Possible extra safety?
# Find if there's a `f(args...)` dispatch
# If so, no error
Expand Down Expand Up @@ -283,7 +287,7 @@ function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true;
# Find if there's a `f(args...)` dispatch
# If so, no error
for i in 1:length(nargs)
if nargs[i] < inplace_param_number &&
if nargs[i] < outofplace_param_number &&
any(isequal(Vararg{Any}), methods(f).ms[1].sig.parameters)
# If varargs, assume iip
return iip_preferred
Expand Down
8 changes: 8 additions & 0 deletions test/function_building_error_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ f = Foo{1}()
(this::Foo{T})(args...) where {T} = 1
@test SciMLBase.isinplace(Foo{Int}(), 4)

@testset "isinplace accepts an out-of-place version with different numbers of parameters " begin
f1(u) = 2 * u
f2(du, u, x1, x2) = copyto!(du, u)
@test !isinplace(f1, 2)
@test_throws SciMLBase.TooFewArgumentsError SciMLBase.isinplace(f1, 4)
@test isinplace(f2, 4; outofplace_param_number = 1)
end

## Problem argument tests

ftoomany(u, p, t, x, y) = 2u
Expand Down

0 comments on commit 8b10979

Please sign in to comment.