Skip to content

Commit

Permalink
Also handle in-place case
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Dec 28, 2024
1 parent dbe40e0 commit 2631fdf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
23 changes: 18 additions & 5 deletions src/CodeGeneration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,27 @@ function make_function(func_array::AbstractArray{T}, input_variables::AbstractVe

_single_arg_function = @RuntimeGeneratedFunction(make_Expr(func_array, all_input_vars, in_place, init_with_zeros))
expected_input_sizes = map(size, input_variables)
function _maybe_multi_arg_function(input_varibles::AbstractVector...)
@boundscheck if !all(size(input) == expected_size for (input, expected_size) in zip(input_varibles, expected_input_sizes))
throw(ArgumentError("The input variables must have the same size as the input_variables argument to make_function."))
expected_result_size = size(func_array)
if in_place
compiled_function = function (result, input_variables::AbstractVector...)
@boundscheck if any(size(input) != expected_size for (input, expected_size) in zip(input_variables, expected_input_sizes))
throw(ArgumentError("The input variables must have the same size as the input_variables argument to make_function."))
end
@boundscheck if size(result) != expected_result_size
throw(ArgumentError("The result array must have the same size as the result of the function."))
end
return _single_arg_function(result, reduce(vcat, input_variables))
end
else
compiled_function = function (input_variables::AbstractVector...)
@boundscheck if any(size(input) != expected_size for (input, expected_size) in zip(input_variables, expected_input_sizes))
throw(ArgumentError("The input variables must have the same size as the input_variables argument to make_function."))
end
return _single_arg_function(reduce(vcat, input_variables))
end
return _single_arg_function(reduce(vcat, input_varibles))
end

return _maybe_multi_arg_function
return compiled_function
end
export make_function

8 changes: 7 additions & 1 deletion test/FDTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2168,13 +2168,19 @@ end
@testitem "multiarg code generation" begin
using FastDifferentiation: FastDifferentiation as FD

# out-of-place case
x = FD.make_variables(:x, 3)
y = FD.make_variables(:y, 3)
f = x .* y
f_callable = FD.make_function(f, x, y)

x_val = ones(3)
y_val = ones(3)
f_val = f_callable(x_val, y_val)
@test f_val ones(3)

# in-place case
result = zeros(3)
f_callable! = FD.make_function(f, x, y; in_place=true)
f_callable!(result, x_val, y_val)
@test result ones(3)
end

0 comments on commit 2631fdf

Please sign in to comment.