-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rename TensorProduct
and implement TensorCore.tensor
#232
Conversation
@@ -0,0 +1,145 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only changed TensorProduct
to KernelTensorProduct
and added a new docstring. I renamed the file in a separate PR but somehow, in contrast to the tests, Github does not display the changes nicely.
@@ -1,5 +1,3 @@ | |||
hadamard(x, y) = x .* y |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TensorCore contains an implementation of hadamard
.
x -> SqExponentialKernel() * LinearKernel(; c=x[1]), | ||
rand(1); | ||
ADs=[:ForwardDiff, :ReverseDiff, :Zygote], | ||
x -> KernelProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems AD errors are fixed by enforcing non-negativity of the constant offset. Same below for KernelSum
and KernelTensorProduct
.
@@ -0,0 +1,36 @@ | |||
@testset "overloads" begin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consistent with the current test setup, I created a separate file that tests only the implementations of +
, *
, and ⊗
.
@@ -1,7 +1,7 @@ | |||
using KernelFunctions | |||
using AxisArrays | |||
using Distances | |||
using Kronecker | |||
using Kronecker: Kronecker |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again this avoids conflicts between Kronecker.:⊗
and TensorCore.:⊗
while still loading kernelkronmat
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also opened an issue: MichielStock/Kronecker.jl#80
@eval begin | ||
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2) | ||
|
||
$M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default behavior is/was that if we have one kernel <:Tuple
and one kernel <:AbstractVector
, one will get a <:Tuple
. Is it really a sensible choice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The nice thing about this default is that you don't end up with abstract types unexpectedly but, of course, it might not be friendly to the compiler. I haven't seen any instances of it in our examples and tests, so I am not sure how often such a combination happens in practice.
AD errors on Julia 1.3 are caused by |
Is there anything left that should be addressed? |
This is ready for another round of review, most (all?) unrelated changes are included in the master branch. |
This PR renames
TensorProduct
toKernelTensorProduct
(analogously toKernelSum
andKernelProduct
) and implements and exportsTensorCore.tensor === TensorCore.:⊗
for tensor product kernels (similar to the overloads forKernelSum
andKernelProduct
).IMO the name
TensorProduct
is too generic to be owned and exported by KernelFunctions. The overload is nice to have, I think, especially with the more verboseKernelTensorProduct
. TensorCore is a lightweight package that contains default implementations for tensor operations. Originally, the code was merged in JuliaLang/julia#35150 which was later reverted in JuliaLang/julia#35744 and moved to its own package. It seems other packages with tensor operations such as Kronecker.jl should also overloadTensorCore.tensor
or rename their functions if they are type piracy to avoid conflicts.