Skip to content

Commit

Permalink
Merge pull request #105 from brianguenter/multiargcodegenerationPR103
Browse files Browse the repository at this point in the history
Fix multi-arg code generation
  • Loading branch information
brianguenter authored Jan 14, 2025
2 parents c089bd7 + fac65e3 commit 6cbe888
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FastDifferentiation"
uuid = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
authors = ["BrianGuenter"]
version = "0.4.2"
version = "0.4.3"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down
11 changes: 11 additions & 0 deletions docs/src/makefunction.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ julia> f_exe!([2.0,3.0])
9.0
```

Example: generated function accepts more than one vector of input arguments. This is useful if you want to segregate your input variables into groups.
```julia
x = make_variables(:x, 3)
y = make_variables(:y, 3)
f = x .* y
f_callable = make_function(f, x, y)
x_val = ones(3)
y_val = ones(3)
f_val = f_callable(x_val, y_val) #executable takes two 3-vectors as input arguments
```

Example: assume your result is a large dense array (> 100 elements) and that you are using an in place array with no initialization. For dense arrays this should generate the fastest code:
```
julia> @variables x y z
Expand Down
151 changes: 116 additions & 35 deletions src/CodeGeneration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ end
```
and the second return value will be the constant value.
"""
function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_to_var::Union{Nothing,IdDict{Node,Union{Symbol,Real,Expr}}}=nothing, body::Expr=Expr(:block))
function function_body!(dag::Node, variable_to_index::IdDict{Node,Union{Expr,Int64}}, node_to_var::Union{Nothing,IdDict{Node,Union{Symbol,Real,Expr}}}=nothing, body::Expr=Expr(:block))
if node_to_var === nothing
node_to_var = IdDict{Node,Union{Symbol,Real,Expr}}()
end
Expand All @@ -108,7 +108,7 @@ end

function undef_array_declaration(::StaticArray{S,<:Any,N}) where {S,N}
#need to initialize array to zero because this is no longer being done by simple assignment statements.
:(result = MArray{$(S),promote_type(result_element_type, eltype(input_variables)),$N}(undef))
:(result = MArray{$(S),result_element_type,$N}(undef))
end

"""
Expand Down Expand Up @@ -148,7 +148,68 @@ function to_number(func_array::SparseMatrixCSC{T}) where {T<:Node}
return tmp
end

function variable_names(input_variables::NTuple{N,AbstractVector}) where {N}
input_variable_names = Symbol[]
node_to_index = IdDict{Node,Union{Expr,Int64}}()

for (j, input_var_array) in pairs(input_variables)
var_name = Symbol("input_variables$j")
push!(input_variable_names, var_name)
for (i, node) in pairs(input_var_array)
node_to_index[node] = :($var_name[$i])
end
end

return input_variable_names, node_to_index
end

function return_array_type!(body::Expr, func_array, input_variable_names::AbstractVector, in_place::Bool)
# declare result element type, and result variable if not provided by the user
if in_place
return :(result_element_type = promote_type(eltype.(($(input_variable_names...),))...))
else
push!(body.args, :(result_element_type = promote_type($(_infer_numeric_eltype(func_array)), (eltype.(($(input_variable_names...),)))...)))
push!(body.args, undef_array_declaration(func_array))
end
end

function input_output_size_check(func_array, expected_input_lengths, input_variable_names, in_place)
input_check = :(
@boundscheck begin

lengths = $expected_input_lengths
for (input, expected_length) in zip(($((input_variable_names)...),), lengths)
if length(input) != expected_length
actual_lengths = map(length, ($(input_variable_names...),))
throw(ArgumentError("The input variables must have the same length as the input_variables argument to make_function. Expected lengths: $lengths. Actual lengths: $(actual_lengths)."))
end
end
end)


# wrap in function body
if in_place
expected_result_size = size(func_array)

return :(
begin
@boundscheck begin
expected_res_size = $expected_result_size
if size(result) != expected_res_size
throw(ArgumentError("The in place result array does not have the expected size. Expected size: $expected_res_size. Actual size: $(size(result))."))
end
end

$input_check
end)
else
return :(
begin
#here
$input_check
end)
end
end
"""
make_Expr(
func_array::AbstractArray{<:Node},
Expand All @@ -157,14 +218,15 @@ end
init_with_zeros::Bool
)
"""
function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector{S}, in_place::Bool, init_with_zeros::Bool) where {T<:Node,S<:Node}
function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector...; in_place::Bool=false, init_with_zeros::Bool=true) where {T<:Node}
node_to_var = IdDict{Node,Union{Symbol,Real,Expr}}()
body = Expr(:block)

input_variable_names, node_to_index = variable_names(input_variables)

num_zeros = count(is_identically_zero, (func_array))
num_const = count((x) -> is_constant(x) && !is_identically_zero(x), func_array)


zero_threshold = 0.5
const_threshold = 0.5

Expand All @@ -182,9 +244,9 @@ function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector

# declare result element type, and result variable if not provided by the user
if in_place
push!(body.args, :(result_element_type = eltype(input_variables)))
push!(body.args, :(result_element_type = promote_type(eltype.(($(input_variable_names...),))...)))
else
push!(body.args, :(result_element_type = promote_type($(_infer_numeric_eltype(func_array)), eltype(input_variables))))
push!(body.args, :(result_element_type = promote_type($(_infer_numeric_eltype(func_array)), (eltype.(($(input_variable_names...),)))...)))
push!(body.args, undef_array_declaration(func_array))
end

Expand All @@ -195,11 +257,6 @@ function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector
push!(body.args, :(result .= $(to_number(func_array))))
end

node_to_index = IdDict{Node,Int64}()
for (i, node) in pairs(input_variables)
node_to_index[node] = i
end

for (i, node) in pairs(func_array)
# skip all terms that we have computed above during construction
if is_constant(node) && initialization_strategy === :const || # already initialized as constant above
Expand All @@ -212,23 +269,28 @@ function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector
end
push!(body.args, :(result[$i] = $variable))
end

expected_input_lengths = map(length, input_variables)
in_out_checks = input_output_size_check(func_array, expected_input_lengths, input_variable_names, in_place)
# return result or nothing if in_place
if in_place
push!(body.args, :(return nothing))
return :((result, $(input_variable_names...),) ->
begin
$in_out_checks
@inbounds begin
$body
end
end
)
else
push!(body.args, return_expression(func_array))
end

# wrap in function body
if in_place
return :((result, input_variables::AbstractArray) -> @inbounds begin
$body
end)
else
return :((input_variables::AbstractArray) -> @inbounds begin
$body
end)
return :(($(input_variable_names...),) ->
begin
$in_out_checks
@inbounds begin
$body
end
end)
end
end
export make_Expr
Expand All @@ -241,35 +303,37 @@ export make_Expr
)
`init_with_zeros` argument is not used for sparse matrices."""
function make_Expr(A::SparseMatrixCSC{T,Ti}, input_variables::AbstractVector{S}, in_place::Bool, init_with_zeros::Bool) where {T<:Node,S<:Node,Ti}
function make_Expr(A::SparseMatrixCSC{T,Ti}, input_variables::AbstractVector...; in_place::Bool=false, init_with_zeros::Bool=true) where {T<:Node,Ti}
rows = rowvals(A)
vals = nonzeros(A)
_, n = size(A)
body = Expr(:block)
node_to_var = IdDict{Node,Union{Symbol,Real,Expr}}()

#need to understand how node_to_index is being used for sparse case


input_variable_names, node_to_index = variable_names(input_variables)

if !in_place #have to store the sparse vector indices in the generated code to know how to create sparsity pattern
push!(body.args, :(element_type = promote_type(Float64, eltype(input_variables))))
push!(body.args, :(element_type = promote_type(Float64, eltype.($(input_variable_names...))...)))
push!(body.args, :(result = SparseMatrixCSC($(A.m), $(A.n), $(A.colptr), $(A.rowval), zeros(element_type, $(length(A.nzval))))))
end

push!(body.args, :(vals = nonzeros(result)))



num_consts = count(x -> is_constant(x), vals)
if num_consts == nnz(A) #all elements are constant
push!(body.args, :(vals .= $(to_number(A))))
if in_place
return :((result, input_variables) -> $body)
return :((result, $(input_variable_names...)) -> $body)
else
push!(body.args, :(return result))
return :((input_variables) -> $body)
return :(($(input_variable_names...)) -> $body)
end
else
node_to_index = IdDict{Node,Int64}()
for (i, node) in pairs(input_variables)
node_to_index[node] = i
end

for j = 1:n
for i in nzrange(A, j)
node_body, variable = function_body!(vals[i], node_to_index, node_to_var)
Expand All @@ -284,10 +348,26 @@ function make_Expr(A::SparseMatrixCSC{T,Ti}, input_variables::AbstractVector{S},

push!(body.args, :(return result))

expected_input_lengths = map(length, input_variables)
in_out_checks = input_output_size_check(A, expected_input_lengths, input_variable_names, in_place)

if in_place
return :((result, input_variables::AbstractArray) -> $body)
return :((result, $(input_variable_names...),) ->
begin
$in_out_checks
@inbounds begin
$body
end
end
)
else
return :((input_variables::AbstractArray) -> $body)
return :(($(input_variable_names...),) ->
begin
$in_out_checks
@inbounds begin
$body
end
end)
end
end
end
Expand Down Expand Up @@ -368,7 +448,8 @@ function make_function(func_array::AbstractArray{T}, input_variables::AbstractVe

@assert length(temp) == 0 "The variables $temp were not in the input_variables argument to make_function. Every variable that is used in your function must have a corresponding entry in the input_variables argument."

@RuntimeGeneratedFunction(make_Expr(func_array, all_input_vars, in_place, init_with_zeros))
temp = make_Expr(func_array, input_variables..., in_place=in_place, init_with_zeros=init_with_zeros)
@RuntimeGeneratedFunction(temp)
end
export make_function

4 changes: 2 additions & 2 deletions src/ExpressionGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,11 @@ function to_string(a::Node)
end
end

function node_symbol(a::Node, variable_to_index::IdDict{Node,Int64})
function node_symbol(a::Node, variable_to_index::IdDict{Node,Union{Expr,Int64}})
if is_tree(a)
result = gensym() #create a symbol to represent the node
elseif is_variable(a)
result = :(input_variables[$(variable_to_index[a])])
result = :($(variable_to_index[a]))
else
result = value(a) #not a tree not a variable so is some kind of constant.
end
Expand Down
Loading

2 comments on commit 6cbe888

@brianguenter
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Patch release to make multiple input arguments work for code generation. Now this will work correctly:

  x = make_variables(:x, 3)
    y = make_variables(:y, 3)
    f = x .* y
    f_callable = make_function(f, x, y)
    x_val = ones(3)
    y_val = ones(3)
    f_val = f_callable(x_val, y_val)

Before the generated code would not address the input args correctly resulting in crashes.

See #105 (comment) for details.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/123006

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.3 -m "<description of version>" 6cbe88881e3ac2f9146aca7c963be9e2b662dbc4
git push origin v0.4.3

Please sign in to comment.