diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 3bf9211e6..0e476c55a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -77,8 +77,9 @@ function compute_ydual_onearg( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,T,C} - (; xdual_tmp) = prep - make_dual!(T, xdual_tmp, x, tx) + # (; xdual_tmp) = prep + # make_dual!(T, xdual_tmp, x, tx) + xdual_tmp = make_dual(T, x, tx) # TODO: discuss reuse of mutable dual array ydual = f(xdual_tmp, map(unwrap, contexts)...) return ydual end diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f84666f44..a5db72551 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -3,6 +3,7 @@ Pkg.add("ForwardDiff") using ComponentArrays: ComponentArrays using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT using ForwardDiff: ForwardDiff using StaticArrays: StaticArrays using Test @@ -57,3 +58,19 @@ test_differentiation( ## Static test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) + +@testset verbose = true "No allocations on StaticArrays" begin + filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen + DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out + end + data = benchmark_differentiation( + AutoForwardDiff(), + filtered_static_scenarios; + benchmark=:prepared, + excluded=[:hessian, :pullback], # TODO: figure this out + logging=LOGGING, + ) + @testset "$(row[:scenario])" for row in eachrow(data) + @test row[:allocs] == 0 + end +end; diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index 0e7f02d4b..0b8d0d9d2 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -8,19 +8,10 @@ using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector mySArray(f::Function) = f -myMArray(f::Function) = f - mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T}) -myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(MVector{6,T}) - mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6}) -myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(MMatrix{2,3,T,6}) - mySArray(f::DIT.MultiplyByConstant) = f -myMArray(f::DIT.MultiplyByConstant) = f - mySArray(f::DIT.WritableClosure) = f -myMArray(f::DIT.WritableClosure) = f mySArray(x::Number) = x myMArray(x::Number) = x @@ -36,13 +27,8 @@ function myMArray(x::AbstractMatrix{T}) where {T} end mySArray(x::Tuple) = map(mySArray, x) -myMArray(x::Tuple) = map(myMArray, x) - mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x))) -myMArray(x::DI.Constant) = DI.Constant(myMArray(DI.unwrap(x))) - mySArray(::Nothing) = nothing -myMArray(::Nothing) = nothing function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f, x, y, tang, contexts, res1, res2) = scen @@ -57,22 +43,9 @@ function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} ) end -function myMArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, tang, contexts, res1, res2) = scen - return Scenario{op,pl_op,pl_fun}( - myMArray(f); - x=myMArray(x), - y=pl_fun == :in ? myMArray(y) : myMArray(y), - tang=myMArray(tang), - contexts=myMArray(contexts), - res1=myMArray(res1), - res2=myMArray(res2), - ) -end - function DIT.static_scenarios(args...; kwargs...) scens = DIT.default_scenarios(args...; kwargs...) - return vcat(mySArray.(scens), myMArray.(scens)) + return mySArray.(scens) end end diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index 74506155e..1803df72e 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -71,8 +71,8 @@ end ## Number to array -multiplicator(::Type{A}) where {A<:AbstractVector} = convert(A, float.(1:6)) -multiplicator(::Type{A}) where {A<:AbstractMatrix} = convert(A, reshape(float.(1:6), 2, 3)) +multiplicator(::Type{A}) where {A<:AbstractVector} = convert(A, 1:6) +multiplicator(::Type{A}) where {A<:AbstractMatrix} = convert(A, reshape(1:6, 2, 3)) struct NumToArr{A} end NumToArr(::Type{A}) where {A} = NumToArr{A}()