From 2631fdfd3c6c39c3ac2e098ca58dbced79bd6a47 Mon Sep 17 00:00:00 2001 From: lassepe Date: Sat, 28 Dec 2024 19:04:05 +0100 Subject: [PATCH] Also handle in-place case --- src/CodeGeneration.jl | 23 ++++++++++++++++++----- test/FDTests.jl | 8 +++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index 73406a0..007bdb8 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -371,14 +371,27 @@ function make_function(func_array::AbstractArray{T}, input_variables::AbstractVe _single_arg_function = @RuntimeGeneratedFunction(make_Expr(func_array, all_input_vars, in_place, init_with_zeros)) expected_input_sizes = map(size, input_variables) - function _maybe_multi_arg_function(input_varibles::AbstractVector...) - @boundscheck if !all(size(input) == expected_size for (input, expected_size) in zip(input_varibles, expected_input_sizes)) - throw(ArgumentError("The input variables must have the same size as the input_variables argument to make_function.")) + expected_result_size = size(func_array) + if in_place + compiled_function = function (result, input_variables::AbstractVector...) + @boundscheck if any(size(input) != expected_size for (input, expected_size) in zip(input_variables, expected_input_sizes)) + throw(ArgumentError("The input variables must have the same size as the input_variables argument to make_function.")) + end + @boundscheck if size(result) != expected_result_size + throw(ArgumentError("The result array must have the same size as the result of the function.")) + end + return _single_arg_function(result, reduce(vcat, input_variables)) + end + else + compiled_function = function (input_variables::AbstractVector...) + @boundscheck if any(size(input) != expected_size for (input, expected_size) in zip(input_variables, expected_input_sizes)) + throw(ArgumentError("The input variables must have the same size as the input_variables argument to make_function.")) + end + return _single_arg_function(reduce(vcat, input_variables)) end - return _single_arg_function(reduce(vcat, input_varibles)) end - return _maybe_multi_arg_function + return compiled_function end export make_function diff --git a/test/FDTests.jl b/test/FDTests.jl index 26751a3..0694362 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -2168,13 +2168,19 @@ end @testitem "multiarg code generation" begin using FastDifferentiation: FastDifferentiation as FD + # out-of-place case x = FD.make_variables(:x, 3) y = FD.make_variables(:y, 3) f = x .* y f_callable = FD.make_function(f, x, y) - x_val = ones(3) y_val = ones(3) f_val = f_callable(x_val, y_val) @test f_val ≈ ones(3) + + # in-place case + result = zeros(3) + f_callable! = FD.make_function(f, x, y; in_place=true) + f_callable!(result, x_val, y_val) + @test result ≈ ones(3) end