Next steps for differentiating #3
This is a performance consideration, similar to how backwardsdiff should be more performant for N->1 maps. Since we have partial derivatives right now, 1->1 maps are fastest with forwarddiff, but ideally you would somehow use backwarddiff if you knew that you want to simulate the entire gradient. I believe that both cases will eventually be solvable by specialization on broadcasting, cf. https://github.com/FelixBenning/DifferentiableKernelFunctions.jl/blob/main/src/partial.jl#L57-L65

```julia
gradient(dim::Integer) = mappedarray(partial, Base.OneTo(dim))
hessian(dim::Integer) = mappedarray(partial, lazy_product(Base.OneTo(dim), Base.OneTo(dim)))
fullderivative(order::Integer, dim::Integer) =
    mappedarray(partial, lazy_product(ntuple(_ -> Base.OneTo(dim), order)...))

# idea: lazy mappings can be undone (extract the original range -> towards a specialization
# speedup of broadcasting over multiple derivatives using backwardsdiff)
const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Int},1,T,typeof(partial)}
function extract_range(p_map::MappedPartialVec{T}) where {T<:AbstractUnitRange{Int}}
    return p_map.data::T
end
```

You would extract the original range from the lazy partial array and therefore know that you could use backwarddiff. If you similarly had a structure which says "calculate all points and derivatives", you could also use the "value + derivative" functionality. Tomorrow I'll think about "all derivatives up to a given order".
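For concreteness, a minimal standalone sketch of how that extraction could drive the choice of AD mode. Everything below (the stand-in `Partial` struct and `pick_ad_mode`) is hypothetical illustration, not package code:

```julia
using MappedArrays: mappedarray, ReadonlyMappedArray

# stand-in for the package's Partial type, just for this sketch
struct Partial{Order,T}
    index::T
end
partial(i::Int) = Partial{1,Int}(i)

const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Int},1,T,typeof(partial)}

# a lazily mapped unit range means "all partials 1:dim", i.e. a full gradient -> reverse mode
pick_ad_mode(::MappedPartialVec{<:AbstractUnitRange{Int}}) = :reverse
# anything else (a single partial, an arbitrary collection of them) -> forward mode
pick_ad_mode(::AbstractVector{<:Partial}) = :forward

pick_ad_mode(mappedarray(partial, Base.OneTo(10)))  # :reverse
pick_ad_mode([partial(3)])                          # :forward
```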
```julia
julia> DKF
DifferentiableKernelFunctions

julia> DKF.gradient(10)
10-element mappedarray(DifferentiableKernelFunctions.partial, ::Base.OneTo{Int64}) with eltype DifferentiableKernelFunctions.Partial{1, Int64}:
 ∂₁: Partial(1)
 ∂₂: Partial(2)
 ∂₃: Partial(3)
 ∂₄: Partial(4)
 ∂₅: Partial(5)
 ∂₆: Partial(6)
 ∂₇: Partial(7)
 ∂₈: Partial(8)
 ∂₉: Partial(9)
 ∂₁₀: Partial(10)

julia> dump(ans)
MappedArrays.ReadonlyMappedArray{DifferentiableKernelFunctions.Partial{1, Int64}, 1, Base.OneTo{Int64}, typeof(partial)}
  f: partial (function of type typeof(partial))
  data: Base.OneTo{Int64}
    stop: Int64 10
```

This is also memory efficient.
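A quick standalone way to see the memory claim, using MappedArrays directly rather than the package:

```julia
using MappedArrays

lazy = mappedarray(string, Base.OneTo(10_000))   # stores only the range and the function
eager = collect(lazy)                            # stores 10_000 strings

Base.summarysize(lazy) < Base.summarysize(eager)  # true
```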
Honestly, I would suggest we test which autodiff package works best and then just prescribe one until AbstractDifferentiation is ready. I mean, if we already did the "what works best" work, why should a user have to do it again? Sane defaults are a feature in my opinion. But an escape hatch is a nice feature too, and it would be good to keep that in mind for the architecture.
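A minimal sketch of what "prescribed default plus escape hatch" could look like; all names here (`ADBackend`, `diffkernel`, the backend structs) are hypothetical and not existing package API:

```julia
abstract type ADBackend end
struct ForwardDiffBackend <: ADBackend end
struct EnzymeBackend <: ADBackend end

# the package prescribes one tested default ...
const DEFAULT_BACKEND = ForwardDiffBackend()

# ... but a user who needs to can swap the backend per call
diffkernel(k, x, y; backend::ADBackend = DEFAULT_BACKEND) = diffkernel(backend, k, x, y)
diffkernel(::ForwardDiffBackend, k, x, y) = error("ForwardDiff-based implementation goes here")
diffkernel(::EnzymeBackend, k, x, y) = error("Enzyme-based implementation goes here")
```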
This also makes sense to me, at the very least to start with; no point in trying to cover too much too soon.
@Crown421 do you think it is worth thinking about input shapes other than vectors? Up until now I tried to do that, but for gradients, hessians and tuples of these it is becoming a headache.
It is difficult. I personally only work with vectors, but I know some people use a variety of inputs (including strings, interestingly). |
I mean, here we only care about the indices, and strings are actually linearly indexed. Now `Iterators.flatten(id, gradient, hessian)` would pretty much do what we want, but it isn't an `AbstractArray`. So I'll probably do the same as I did with `ProductArrays` and write a `flatten` function which behaves just like it but returns an `AbstractArray`. You could probably write a specialization of `flatten` which permutes a flattening of `MappedArray`s with identical map functions out into a single `MappedArray` over the flattened underlying data. The points would then be built as

```julia
x = productArray(v, derivativeRequests)
```

so we would want to specialize on

```julia
const DiffPoints{T} = ProductArray{Tuple{T,MappedArray{partial,?}}}
```
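A standalone sketch of that flatten specialization (the `lazyflatten` name and the implementation are hypothetical, not package code):

```julia
using MappedArrays: mappedarray, ReadonlyMappedArray

# generic fallback: just materialize the concatenation
lazyflatten(arrays::AbstractVector...) = reduce(vcat, arrays)

# if every piece is a lazily mapped array with the same map, pull the map out:
# flatten(map(f, A), map(f, B)) == map(f, flatten(A, B)), so the structural
# information ("this is f applied to a known index set") survives the flattening
function lazyflatten(arrays::ReadonlyMappedArray{<:Any,1}...)
    f = first(arrays).f
    all(a -> a.f === f, arrays) || return reduce(vcat, map(collect, arrays))
    return mappedarray(f, reduce(vcat, map(a -> a.data, arrays)))
end

lazyflatten(mappedarray(abs2, 1:3), mappedarray(abs2, 4:6))  # still a single lazily mapped array
```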
```julia
julia> Iterators.flatten((Ref(partial()), partial.(1:10), DKF.hessian(10))) |> collect
111-element Vector{DifferentiableKernelFunctions.Partial}:
 id: Partial()      # a zero order derivative
 ∂₁: Partial(1)
 ∂₂: Partial(2)
 ∂₃: Partial(3)
 ∂₄: Partial(4)
 ∂₅: Partial(5)
 ∂₆: Partial(6)
 ∂₇: Partial(7)
 ∂₈: Partial(8)
 ∂₉: Partial(9)
 ∂₁₀: Partial(10)
 ∂₁∂₁: Partial(1,1)
 ∂₂∂₁: Partial(2,1)
 ∂₃∂₁: Partial(3,1)
 ∂₄∂₁: Partial(4,1)
 ∂₅∂₁: Partial(5,1)
 ∂₆∂₁: Partial(6,1)
 ∂₇∂₁: Partial(7,1)
 ∂₈∂₁: Partial(8,1)
 ∂₉∂₁: Partial(9,1)
 ∂₁₀∂₁: Partial(10,1)
 ∂₁∂₂: Partial(1,2)
 ∂₂∂₂: Partial(2,2)
 ⋮
 ∂₈∂₈: Partial(8,8)
 ∂₉∂₈: Partial(9,8)
 ∂₁₀∂₈: Partial(10,8)
 ∂₁∂₉: Partial(1,9)
 ∂₂∂₉: Partial(2,9)
 ∂₃∂₉: Partial(3,9)
 ∂₄∂₉: Partial(4,9)
 ∂₅∂₉: Partial(5,9)
 ∂₆∂₉: Partial(6,9)
 ∂₇∂₉: Partial(7,9)
 ∂₈∂₉: Partial(8,9)
 ∂₉∂₉: Partial(9,9)
 ∂₁₀∂₉: Partial(10,9)
 ∂₁∂₁₀: Partial(1,10)
 ∂₂∂₁₀: Partial(2,10)
 ∂₃∂₁₀: Partial(3,10)
 ∂₄∂₁₀: Partial(4,10)
 ∂₅∂₁₀: Partial(5,10)
 ∂₆∂₁₀: Partial(6,10)
 ∂₇∂₁₀: Partial(7,10)
 ∂₈∂₁₀: Partial(8,10)
 ∂₉∂₁₀: Partial(9,10)
 ∂₁₀∂₁₀: Partial(10,10)
```

Aaaand I just realized that Schwarz's theorem should probably be considered, ugh.
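Since mixed partials commute for the smooth kernels we care about, the Hessian request would only need the upper triangle of index pairs. A quick standalone illustration of the saving (hypothetical helper, not package code):

```julia
# 55 index pairs instead of 100 for dim = 10
upper_triangle_pairs(dim::Integer) = [(i, j) for j in 1:dim for i in 1:j]

length(upper_triangle_pairs(10))  # 55 == binomial(10 + 1, 2)
```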
Aaaaand we have just ruled out ForwardDiff.jl for the foreseeable future... (JuliaGaussianProcesses/KernelFunctions.jl#517). You said you already tried Enzyme. Do you still have that code lying around, to test whether Enzyme has the same problem? (I.e. replace the part where FD is used and run the tests - main has tests included which check whether this problem is fixed.)

EDIT (from slack Autodiff, 17:29): does Enzyme forward AD skip measure-zero branches? (I heard that ForwardDiff tried that, broke some people's code and reverted.) JuliaGaussianProcesses/KernelFunctions.jl#517
I saw the discussion on the slack. Personally, I like that Enzyme at least gets the first derivative of the Matern kernel right. I think the slack discussion indicated that it would be possible in principle to implement this in Enzyme, so especially once there is a strong use case (like a specific application where the second derivative of the Matern kernel is needed), it could happen reasonably quickly.

Also, I apologize for the slow response; I am on work travel this week and at a wedding/on vacation next week, which also means very limited time to test or prototype.
I mean, if you want to simulate the first derivative of a GP, I think the only reliable solution is to use a Taylor expansion around zero in the zero branch. I considered specializing the generic

```julia
function diffKernelCall(k::T, (x, px)::DiffPt, (y, py)::DiffPt) where {T<:Kernel}
    return apply_partial(k, px.indices, py.indices)(x, y)
end
```

for the Matern kernel as

```julia
function diffKernelCall(k::Matern, (x, px)::DiffPt, (y, py)::DiffPt)
    if x == y
        return # calculate the derivative at zero manually
    end
    return apply_partial(k, px.indices, py.indices)(x, y)
end
```

but that is going to break for transformations of the Matern kernel. And since essentially everything is a transformation (lengthscales, variance), it is going to break immediately. Now a Taylor expansion is computationally more expensive than simply returning one and only works up to a fixed number of orders (which is extremely annoying). So KernelFunctions is probably going to be reluctant about it due to the first reason.

It might be possible to work around the performance problem by specializing on duals (but that doesn't solve the problem for backwardsdiff), and it will be forwarddiff-package dependent. And given that the RationalQuadratic also breaks, I foresee a ton of specialized code for all the kernels, which I am not looking forward to at all. So big ugh.
Enjoy your travel! :) It looks like this is not going to go anywhere anyway ...
In my experiments taking the derivative of …
The matern32 and matern52 kernels fortunately do not have a special branching, but they appear to return NaN in some cases:

```
Sanity Checks with Matern 5/2 Kernel (metric = Distances.Euclidean(0.0)): Test Failed at /home/runner/work/DifferentiableKernelFunctions.jl/DifferentiableKernelFunctions.jl/test/diffKernel.jl:24
  Expression: k((partial(1), x), (partial(1), x)) > 0
   Evaluated: NaN > 0
```

https://github.com/FelixBenning/DifferentiableKernelFunctions.jl/actions/runs/5188477721
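A standalone guess at where the NaN comes from (a sketch of the suspected mechanism, not a verified diagnosis): the Euclidean distance itself has a 0/0 derivative at coinciding points, which forward-mode AD turns into NaN.

```julia
using ForwardDiff

d(x) = sqrt(sum(abs2, x))          # the Euclidean distance used inside the kernel
ForwardDiff.gradient(d, zeros(2))  # [NaN, NaN]: derivative of sqrt at 0 is Inf, times 0
```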
Indeed, I just checked this as well. This definitely needs to be fixed. |
btw - I swapped the argument order:

```julia
julia> productArray(partial.(1:2), rand(3)) |> vec
6-element reshape(::ProductArray{Tuple{Vector{Partial{1, Tuple{Int64}}}, Vector{Float64}}, Tuple{Partial{1, Tuple{Int64}}, Float64}, 2}, 6) with eltype Tuple{Partial{1, Tuple{Int64}}, Float64}:
 (∂₁, 0.2198546152745091)
 (∂₂, 0.2198546152745091)
 (∂₁, 0.07434921369632463)
 (∂₂, 0.07434921369632463)
 (∂₁, 0.43710532325688556)
 (∂₂, 0.43710532325688556)
```

We want one point with many derivatives grouped, not all the points for each derivative (at least usually). `PermutedDimsArray` could achieve the same result, but this is the better default I think.
I am happy now with the main interface. I can't decide what the best API for full derivatives is yet. There are not going to be many orders of differentiation, so I thought it makes sense to allow the compiler to specialize, which is why I am using `Val{order}`:

```julia
function fullderivative(::Val{order}, input_indices::AbstractVector{Int}) where {order}
    return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
end
fullderivative(::Val{order}, dim::Integer) where {order} = fullderivative(Val{order}(), Base.OneTo(dim))
function fullderivative(::Val{order}, input_indices::AbstractArray{T,N}) where {order,N,T<:Base.AbstractCartesianIndex{N}}
    return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
end

gradient(input_indices::AbstractArray) = fullderivative(Val(1), input_indices)
gradient(dim::Integer) = fullderivative(Val(1), dim)
hessian(input_indices::AbstractArray) = fullderivative(Val(2), input_indices)
hessian(dim::Integer) = fullderivative(Val(2), dim)

diffAt(::Val{order}, x) where {order} = productArray(_diffAt(Base.IteratorSize(x), Val(order), x), (x,))
_diffAt(::Base.HasLength, ::Val{order}, x) where {order} = fullderivative(Val(order), Base.OneTo(length(x)))
_diffAt(::Base.HasShape{1}, ::Val{order}, x) where {order} = fullderivative(Val(order), Base.OneTo(length(x)))
_diffAt(::Base.HasShape, ::Val{order}, x) where {order} = fullderivative(Val(order), CartesianIndices(axes(x)))

gradAt(x) = diffAt(Val(1), x)
grad(f) = x -> f.(gradAt(x)) # for f = rand(::GP), grad(f)(x) should work.
```
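Standalone illustration of why `Val(order)` helps, independent of the package: it turns the derivative order into a compile-time constant, so `ntuple` returns a concretely typed `NTuple` and the shape of the product is inferable (`index_axes` below is a hypothetical helper, not package code).

```julia
# with Val, the tuple length is known to the compiler
index_axes(::Val{order}, dim::Integer) where {order} = ntuple(_ -> Base.OneTo(dim), Val(order))

index_axes(Val(2), 3)  # (Base.OneTo(3), Base.OneTo(3)), a concrete NTuple{2, Base.OneTo{Int}}
```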
I really like the idea of making differentiation part of the Multi-Output interface. With this in mind, some thoughts I had on adding differentiation to KernelFunctions: until AbstractDifferentiation is ready, I think this functionality should be implemented via a package extension. Here one could either pick one package (Enzyme or ForwardDiff), or write a common core for both and then a specific implementation of the differentiation for each package. I currently do not know what would happen if one were to load both.

Again, very happy to make this a reality.
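A hedged sketch of how such a package extension could be laid out (module and file names are hypothetical, and the Project.toml would need matching weakdeps/extensions entries):

```julia
# ext/DifferentiableKernelFunctionsForwardDiffExt.jl
module DifferentiableKernelFunctionsForwardDiffExt

using DifferentiableKernelFunctions, ForwardDiff

# the ForwardDiff-specific part of the differentiation would live here,
# e.g. the method that actually evaluates a requested partial derivative;
# an Enzyme extension would mirror this file with Enzyme-specific methods

end
```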