
Next steps for differentiating #3

Open
Crown421 opened this issue Jun 1, 2023 · 16 comments

Comments

@Crown421
Collaborator

Crown421 commented Jun 1, 2023

I really like the idea of making differentiation part of the Multi-Output interface. With this in mind, here are some thoughts I had on adding differentiation to KernelFunctions:

  • I think that until AbstractDifferentiation is ready, this functionality should be implemented via a package extension. Here one could either pick one package (Enzyme or ForwardDiff), or write a common core for both and then a specific implementation of the differentiation for each package. I currently do not know what would happen if one were to load both.
  • It is often the case that autodiff functions can return both the function value and the derivative. I think it would be very cool to have a way to do this, so that something like `DiffKernel(kernel)([x1, (x1, dx1)], x2)` only results in a single call. Not quite sure how to do that yet, though (see the sketch below).
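A minimal sketch of the "value + derivative in one call" idea, using ForwardDiff together with DiffResults purely as an illustration (this is not the proposed DiffKernel API):

using ForwardDiff, DiffResults

f(x) = exp(-abs2(x))                       # stand-in for k(·, x2) at a fixed x2
result = DiffResults.DiffResult(0.0, 0.0)  # container for the value and the first derivative
result = ForwardDiff.derivative!(result, f, 1.5)  # one AD call fills both slots
DiffResults.value(result)       # f(1.5)
DiffResults.derivative(result)  # f'(1.5)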

Again, very happy to help make this a reality.

@FelixBenning
Owner

It is often the case that autodiff functions can return both the function value and the derivative

This is a performance consideration, similar to how backward (reverse-mode) diff should be more performant for N->1 maps. Since we work with individual partial derivatives right now, the resulting 1->1 maps are fastest with forward diff, but ideally you would somehow use backward diff if you knew that you wanted to simulate the entire gradient.
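A hedged illustration of this trade-off with a toy kernel, not code from this package:

using ForwardDiff, Zygote

k(x, y) = exp(-sum(abs2, x - y))   # toy stand-in kernel
x, y = rand(10), rand(10)

# a single partial derivative ∂k/∂x₃: one cheap forward-mode pass along e₃
e3 = [i == 3 ? 1.0 : 0.0 for i in 1:10]
ForwardDiff.derivative(t -> k(x .+ t .* e3, y), 0.0)

# the full gradient ∇ₓk: one reverse-mode pass instead of 10 forward-mode passes
Zygote.gradient(x_ -> k(x_, y), x)[1]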

I believe that both cases will eventually be solvable by specialization on broadcasting. cf. https://github.com/FelixBenning/DifferentiableKernelFunctions.jl/blob/main/src/partial.jl#L57-L65

gradient(dim::Integer) = mappedarray(partial, Base.OneTo(dim))
hessian(dim::Integer) = mappedarray(partial, lazy_product(Base.OneTo(dim), Base.OneTo(dim)))
fullderivative(order::Integer, dim::Integer) = mappedarray(partial, lazy_product(ntuple(_ -> Base.OneTo(dim), order)...))

# idea: lazy mappings can be undone (extract original range -> towards a specialization speedup of broadcasting over multiple derivatives using backwardsdiff)
const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Int},1,T,typeof(partial)}
function extract_range(p_map::MappedPartialVec{T}) where {T<:AbstractUnitRange{Int}}
    return p_map.data::T
end

You would extract the original range from the lazy partial array and therefore know that you could use backward diff. If you similarly had a structure which says "calculate all points and all derivatives", you could also use the "value + derivative" functionality.

Tomorrow I'll think about the "all derivatives up to order n" range. Specializing on this structure when broadcasting should allow for efficient computation later on.

@FelixBenning
Owner

FelixBenning commented Jun 1, 2023

julia> DKF
DifferentiableKernelFunctions

julia> DKF.gradient(10)
10-element mappedarray(DifferentiableKernelFunctions.partial, ::Base.OneTo{Int64}) with eltype DifferentiableKernelFunctions.Partial{1, Int64}:
 ∂₁: Partial(1)
 ∂₂: Partial(2)
 ∂₃: Partial(3)
 ∂₄: Partial(4)
 ∂₅: Partial(5)
 ∂₆: Partial(6)
 ∂₇: Partial(7)
 ∂₈: Partial(8)
 ∂₉: Partial(9)
 ∂₁₀: Partial(10)

julia> dump(ans)
MappedArrays.ReadonlyMappedArray{DifferentiableKernelFunctions.Partial{1, Int64}, 1, Base.OneTo{Int64}, typeof(partial)}
  f: partial (function of type typeof(partial))
  data: Base.OneTo{Int64}
    stop: Int64 10

this is also memory efficient

@FelixBenning
Owner

I think that until AbstractDifferentiation is ready, this functionality should be implemented via a package extension. Here one could either pick one package (Enzyme or ForwardDiff), or write a common core for both and then a specific implementation of the differentiation for each package. I currently do not know what would happen if one were to load both.

Honestly, I would suggest we test which autodiff package works best and then just prescribe one until AbstractDifferentiation is ready. I mean, if we already did the "what works best" work, why should a user have to do it again? Sane defaults are a feature in my opinion. But an escape hatch is also a nice feature, and it would be good to keep that in mind for the architecture.
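A hedged sketch of what "sane default plus escape hatch" could look like; all names here are hypothetical and not part of the package:

abstract type ADBackend end
struct ForwardDiffBackend <: ADBackend end
struct EnzymeBackend <: ADBackend end

struct DiffKernel{K,B<:ADBackend}
    kernel::K
    backend::B
end

# sane default: whichever backend our tests show to work best
DiffKernel(k) = DiffKernel(k, ForwardDiffBackend())
# escape hatch: users can still opt into another backend explicitly
# DiffKernel(k, EnzymeBackend())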

@Crown421
Collaborator Author

Crown421 commented Jun 1, 2023

Honestly, I would suggest we test which autodiff package works best and then just prescribe one until AbstractDifferentiation is ready.

This also makes sense to me, at the very least to start with. No point in trying to cover too much too soon.

@FelixBenning
Owner

@Crown421 do you think it is worth thinking about input shapes other than vectors? Up until now I tried to do that, but for gradients, Hessians, and their tuples this is becoming a headache.

@Crown421
Collaborator Author

Crown421 commented Jun 2, 2023

It is difficult. I personally only work with vectors, but I know some people use a variety of inputs (including strings, interestingly).
I think working out the interface is difficult enough just for vectors, so it makes sense to me to stick to that for now.

@FelixBenning
Owner

I know some people use a variety of inputs (including strings, interestingly).

I mean, here we only care about the indices, and strings are actually linearly indexed. But partial itself should now be robust to Cartesian indices (#8), so we should be future-proof there. Now to the combination of (id, gradient, hessian): it is not an array-like structure, since every entry has a different shape. The only way this makes sense is if it is flattened.

Now

Iterators.flatten((id, gradient, hessian))

would pretty much do what we want, but it isn't an AbstractArray. So I'll probably do the same as I did with ProductArrays and write a flatten function which works just like Iterators.flatten but behaves like an AbstractVector (see the quick check at the end of this comment). Then the type hierarchy would be

FlattenedArray
-> MappedArray(partial, productArray(Tuple{}))
-> MappedArray(partial, productArray(input_indices))
-> MappedArray(partial, productArray(input_indices, input_indices))
... 

you could probably write a specialization of flatten which permutes a flattening of MappedArrays with identical map functions out into

MappedArray
- partial
- FlattenedArray
---> productArray(Tuple{})
---> productArray(input_indices)
...

These derivativeRequests abstract arrays would then typically be producted with an actual point vector v, i.e.

x = productArray(v, derivativeRequests)

So we would want to specialize on

const DiffPoints{T} = ProductArray{Tuple{T, MappedArray{partial, ?}}}
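A quick check of the limitation mentioned above, using plain ranges as stand-ins for the id/gradient/hessian index collections:

it = Iterators.flatten((1:1, 1:10, 1:100))
length(collect(it))   # 111 elements, as wanted
it isa AbstractArray  # false, hence the idea of a FlattenedArray wrapper type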

@FelixBenning
Owner

julia> Iterators.flatten((Ref(partial()), partial.(1:10), DKF.hessian(10))) |> collect
111-element Vector{DifferentiableKernelFunctions.Partial}:
 id: Partial() a zero order derivative
 ∂₁: Partial(1)
 ∂₂: Partial(2)
 ∂₃: Partial(3)
 ∂₄: Partial(4)
 ∂₅: Partial(5)
 ∂₆: Partial(6)
 ∂₇: Partial(7)
 ∂₈: Partial(8)
 ∂₉: Partial(9)
 ∂₁₀: Partial(10)
 ∂₁∂₁: Partial(1,1)
 ∂₂∂₁: Partial(2,1)
 ∂₃∂₁: Partial(3,1)
 ∂₄∂₁: Partial(4,1)
 ∂₅∂₁: Partial(5,1)
 ∂₆∂₁: Partial(6,1)
 ∂₇∂₁: Partial(7,1)
 ∂₈∂₁: Partial(8,1)
 ∂₉∂₁: Partial(9,1)
 ∂₁₀∂₁: Partial(10,1)
 ∂₁∂₂: Partial(1,2)
 ∂₂∂₂: Partial(2,2)
 ⋮
 ∂₈∂₈: Partial(8,8)
 ∂₉∂₈: Partial(9,8)
 ∂₁₀∂₈: Partial(10,8)
 ∂₁∂₉: Partial(1,9)
 ∂₂∂₉: Partial(2,9)
 ∂₃∂₉: Partial(3,9)
 ∂₄∂₉: Partial(4,9)
 ∂₅∂₉: Partial(5,9)
 ∂₆∂₉: Partial(6,9)
 ∂₇∂₉: Partial(7,9)
 ∂₈∂₉: Partial(8,9)
 ∂₉∂₉: Partial(9,9)
 ∂₁₀∂₉: Partial(10,9)
 ∂₁∂₁₀: Partial(1,10)
 ∂₂∂₁₀: Partial(2,10)
 ∂₃∂₁₀: Partial(3,10)
 ∂₄∂₁₀: Partial(4,10)
 ∂₅∂₁₀: Partial(5,10)
 ∂₆∂₁₀: Partial(6,10)
 ∂₇∂₁₀: Partial(7,10)
 ∂₈∂₁₀: Partial(8,10)
 ∂₉∂₁₀: Partial(9,10)
 ∂₁₀∂₁₀: Partial(10,10)

Aaaand I just realized that Schwarz's theorem (symmetry of second-order partials) should probably be taken into account, ugh.
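For reference, the saving Schwarz's theorem gives: since ∂ᵢ∂ⱼk = ∂ⱼ∂ᵢk for a twice continuously differentiable kernel, only index pairs with i ≤ j are actually needed. A quick count:

dim = 10
upper_pairs = [(i, j) for j in 1:dim for i in 1:j]  # i ≤ j only
length(upper_pairs)  # 55 instead of the 100 second-order partials listed above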

@FelixBenning
Owner

FelixBenning commented Jun 5, 2023

Honestly, I would suggest we test which autodiff package works best and then just prescribe one until AbstractDifferentiation is ready.

This also makes sense to me, at the very least to start with. No point in trying to cover too much too soon. @Crown421

Aaaaand we have just ruled out ForwardDiff.jl for the foreseeable future... (JuliaGaussianProcesses/KernelFunctions.jl#517).

You said you already tried Enzyme. Do you still have that code lying around, to test whether Enzyme has the same problem? (I.e. replace the part where ForwardDiff is used and run the tests; main has tests included which check whether this problem is fixed.)

EDIT: (from the Autodiff channel on Slack)

17:29 does Enzyme Forward AD skip measure zero branches? (I heard that ForwardDiff tried that, broke some peoples code and reverted) JuliaGaussianProcesses/KernelFunctions.jl#517
17:31 No it does not

@Crown421
Collaborator Author

Crown421 commented Jun 6, 2023

I saw the discussion on the Slack. Personally, I like that Enzyme at least gets the first derivative of the Matern kernel right for x = y; Zygote does not do this (so that is definitely out).

I think the Slack discussion indicated that it would be possible in principle to implement this in Enzyme, so especially once there is a strong use case (like a specific application where the second derivative of the Matern kernel is needed), it could happen reasonably quickly.

Also, I apologize for slow response, I am on work travel this week, and on a wedding/ vacation next week, which also means very limited time to test or prototype.

@FelixBenning
Owner

I mean, if you want to simulate the first derivative of a GP $Z$, then $\text{Cov}(Z'(x), Z'(y)) = -k''(x-y)$, so the second derivative of the Matern kernel is already needed for the first derivative of the GP.

I think the only reliable solution is to use a Taylor expansion around zero in the zero branch. I considered specializing

function diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {T<:Kernel}
    return apply_partial(k, px.indices, py.indices)(x, y)
end

for the MaternKernel i.e.

function diffKernelCall(k::MaternKernel, (x, px)::DiffPt, (y, py)::DiffPt)
    if x == y
        return # calculate the derivative at zero manually
    end
    return apply_partial(k, px.indices, py.indices)(x, y)
end

But that is going to break for transformations of the Matern kernel. And since essentially everything is a transformation (lengthscales, variance), it is going to break immediately.
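A small illustration of why the dispatch breaks (assuming KernelFunctions as the kernel source): wrapping a Matern kernel in a lengthscale transform changes its type, so a method specialized on MaternKernel no longer applies.

using KernelFunctions

k = MaternKernel()
kt = with_lengthscale(k, 2.0)
k isa MaternKernel    # true
kt isa MaternKernel   # false, kt is a TransformedKernel, so the specialization is skipped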

Now, a Taylor expansion is computationally more expensive than simply returning one, and it only works up to some fixed order n (which is extremely annoying). So KernelFunctions is probably going to be reluctant about it for the first reason, and adding Taylor polynomials of order $2\nu$ to make sure we have everything covered is meh anyway...

It might be possible to work around the performance problem by specializing on dual numbers (but that doesn't solve the problem for backward diff), and it would be ForwardDiff-package dependent.

And given that the RationalQuadratic also breaks, I foresee a ton of specialized code for all the kernels, which I am not looking forward to at all.

So big ugh.

Also, I apologize for slow response, I am on work travel this week, and on a wedding/ vacation next week, which also means very limited time to test or prototype.

Enjoy your travels! :) It looks like this is not going to go anywhere quickly anyway ...

@Crown421
Collaborator Author

Crown421 commented Jun 6, 2023

In my experiments, taking the derivative of Matern32 with respect to both arguments worked fine. In fact, this worked fine in general (but I did not go through every possible kernel).

@FelixBenning
Owner

The Matern32 and Matern52 kernels fortunately do not have a special branch, but they appear to return NaN in some cases:

 Sanity Checks with Matern 5/2 Kernel (metric = Distances.Euclidean(0.0)): Test Failed at /home/runner/work/DifferentiableKernelFunctions.jl/DifferentiableKernelFunctions.jl/test/diffKernel.jl:24
  Expression: k((partial(1), x), (partial(1), x)) > 0
   Evaluated: NaN > 0

https://github.com/FelixBenning/DifferentiableKernelFunctions.jl/actions/runs/5188477721
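A plausible minimal reproduction of that NaN mechanism (assuming the kernel is evaluated through the Euclidean distance): forward-mode differentiation hits sqrt at zero distance, which produces 0/0 in the partials even though the true derivative exists.

using KernelFunctions, ForwardDiff

k = Matern52Kernel()
x = rand(3)
# derivative of t -> k(x .+ t, x) at t = 0; mathematically it is 0, but NaN is
# expected here because sqrt(0) appears inside the distance computation
ForwardDiff.derivative(t -> k(x .+ t, x), 0.0)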

@Crown421
Collaborator Author

Crown421 commented Jun 6, 2023

The Matern32 and Matern52 kernels fortunately do not have a special branch, but they appear to return NaN in some cases:

 Sanity Checks with Matern 5/2 Kernel (metric = Distances.Euclidean(0.0)): Test Failed at /home/runner/work/DifferentiableKernelFunctions.jl/DifferentiableKernelFunctions.jl/test/diffKernel.jl:24

Indeed, I just checked this as well. This definitely needs to be fixed.

@FelixBenning
Owner

FelixBenning commented Jun 6, 2023

Btw, I swapped (x, partial) to (partial, x), because this makes it easier to group with (partial, x, idx) for multi-output kernels and it makes more sense as the default grouping. I.e. due to Julia's column-majorness, the first entry changes fastest:

julia> productArray(partial.(1:2), rand(3)) |> vec
6-element reshape(::ProductArray{Tuple{Vector{Partial{1, Tuple{Int64}}}, Vector{Float64}}, Tuple{Partial{1, Tuple{Int64}}, Float64}, 2}, 6) with eltype Tuple{Partial{1, Tuple{Int64}}, Float64}:
 (∂₁, 0.2198546152745091)
 (∂₂, 0.2198546152745091)
 (∂₁, 0.07434921369632463)
 (∂₂, 0.07434921369632463)
 (∂₁, 0.43710532325688556)
 (∂₂, 0.43710532325688556)

and we want one point grouped with many derivatives, not all the points for each derivative (at least usually). PermutedDimsArray could achieve the same result, but I think this is the better default.

@FelixBenning
Owner

FelixBenning commented Jun 6, 2023

I am happy now with the main interface. I can't decide what the best API for full derivatives is yet. There are not going to be many orders of differentiation, so I thought it makes sense to let the compiler specialize, which is why I am using Val{order}. I liked the DerivativeAt kind of way to generate the vectors (because you don't have to manually provide the input dimensions, since they are included as axes(x) in x). But I don't like diffAt very much as a name, and I also cannot come up with a good name for "all derivatives up to order ord" (which would probably have to be a lazily flattened array of fullderivative arrays).

function fullderivative(::Val{order}, input_indices::AbstractVector{Int}) where {order}
    return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
end
fullderivative(::Val{order}, dim::Integer) where {order} = fullderivative(Val{order}(), Base.OneTo(dim))
function fullderivative(::Val{order}, input_indices::AbstractArray{T,N}) where {order,N,T<:Base.AbstractCartesianIndex{N}}
    return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
end

gradient(input_indices::AbstractArray) = fullderivative(Val(1), input_indices)
gradient(dim::Integer) = fullderivative(Val(1), dim)

hessian(input_indices::AbstractArray) = fullderivative(Val(2), input_indices)
hessian(dim::Integer) = fullderivative(Val(2), dim)

diffAt(::Val{order}, x) where {order} = productArray(_diffAt(Base.IteratorSize(x), Val(order), x), (x,))
_diffAt(::Base.HasLength, ::Val{order}, x) where {order} = fullderivative(Val(order), Base.OneTo(length(x)))
_diffAt(::Base.HasShape{1}, ::Val{order}, x) where {order} = fullderivative(Val(order), Base.OneTo(length(x)))
_diffAt(::Base.HasShape, ::Val{order}, x) where {order} = fullderivative(Val(order), CartesianIndices(axes(x)))

gradAt(x) = diffAt(Val(1), x)
grad(f) = x -> f.(gradAt(x)) # for f = rand(::GP),  grad(f)(x) should work.
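A hedged usage sketch of the above, assuming the package's partial and productArray are in scope:

x = rand(3)
gradient(3)   # lazy array of the 3 first-order partials ∂₁, ∂₂, ∂₃
hessian(3)    # lazy 3×3 array of the second-order partials ∂ᵢ∂ⱼ
gradAt(x)     # lazy pairs (∂ᵢ, x), ready to broadcast a kernel or GP sample over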
