Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename TensorProduct and implement TensorCore.tensor #232

Merged
merged 13 commits into from
Jan 20, 2021
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.16"
version = "0.8.17"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -13,6 +13,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -24,5 +25,6 @@ Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
TensorCore = "0.1"
ZygoteRules = "0.2"
julia = "1.3"
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"

[compat]
Documenter = "0.23, 0.24, 0.25, 0.26"
PDMats = "0.10"
9 changes: 8 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ using KernelFunctions
DocMeta.setdocmeta!(
KernelFunctions,
:DocTestSetup,
:(using KernelFunctions, LinearAlgebra, Random);
quote
using KernelFunctions
using LinearAlgebra
using Random
using PDMats: PDMats
end;
devmotion marked this conversation as resolved.
Show resolved Hide resolved
recursive=true,
)

Expand All @@ -23,6 +28,8 @@ makedocs(;
"Custom Kernels" => "create_kernel.md",
"API" => "api.md",
],
strict=true,
checkdocs=:exports,
devmotion marked this conversation as resolved.
Show resolved Hide resolved
)

deploydocs(;
Expand Down
13 changes: 11 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ EyeKernel
ZeroKernel
CosineKernel
SqExponentialKernel
GaussianKernel
RBFKernel
SEKernel
ExponentialKernel
LaplacianKernel
GammaExponentialKernel
ExponentiatedKernel
FBMKernel
GaborKernel
MaternKernel
Matern12Kernel
Matern32Kernel
Matern52Kernel
NeuralNetworkKernel
Expand All @@ -41,6 +46,9 @@ spectral_mixture_kernel
spectral_mixture_product_kernel
PeriodicKernel
WienerKernel
MOKernel
IndependentMOKernel
LatentFactorMOKernel
```

## Composite Kernels
Expand All @@ -50,7 +58,7 @@ TransformedKernel
ScaledKernel
KernelSum
KernelProduct
TensorProduct
KernelTensorProduct
```

## Transforms
Expand All @@ -64,6 +72,7 @@ LinearTransform
FunctionTransform
SelectTransform
ChainTransform
PeriodicTransform
```

## Functions
Expand All @@ -74,7 +83,6 @@ kernelmatrix!
kerneldiagmatrix
kerneldiagmatrix!
kernelpdmat
kernelkronmat
devmotion marked this conversation as resolved.
Show resolved Hide resolved
nystrom
transform
```
Expand All @@ -84,6 +92,7 @@ transform
```@docs
ColVecs
RowVecs
MOInput
NystromFact
```

Expand Down
2 changes: 1 addition & 1 deletion docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ The [`KernelProduct`](@ref) is defined as a product of kernels:

### Tensor Product

The [`TensorProduct`](@ref) is defined as:
The [`KernelTensorProduct`](@ref) is defined as a tensor product of kernels:

```math
k(x,x';\{k_i\}) = \prod_i k_i(x_i,x'_i)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/userguide.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ To evaluate the kernel function on two vectors you simply call the kernel object
## Creating a kernel matrix

Kernel matrices can be created via the `kernelmatrix` function or `kerneldiagmatrix` for only the diagonal.
An important argument to give is the data layout of the input `obsdim`. It specifies whether the number of observed data points is along the first dimension (`obsdim=1`, i.e. the matrix shape is number of samples times number of features) or along the second dimension (`obsdim=2`, i.e. the matrix shape is number of features times number of samples), similarly to [Distances.jl](https://github.com/JuliaStats/Distances.jl). If not given explicitly, `obsdim` defaults to [`defaultobs`](@ref).
An important argument to give is the data layout of the input `obsdim`. It specifies whether the number of observed data points is along the first dimension (`obsdim=1`, i.e. the matrix shape is number of samples times number of features) or along the second dimension (`obsdim=2`, i.e. the matrix shape is number of features times number of samples), similarly to [Distances.jl](https://github.com/JuliaStats/Distances.jl). If not given explicitly, `obsdim` defaults to `2`.
devmotion marked this conversation as resolved.
Show resolved Hide resolved
theogf marked this conversation as resolved.
Show resolved Hide resolved
For example:
```julia
k = SqExponentialKernel()
Expand Down
10 changes: 7 additions & 3 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ export LinearKernel, PolynomialKernel
export RationalQuadraticKernel, GammaRationalQuadraticKernel
export GaborKernel, PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct
export KernelSum, KernelProduct, KernelTensorProduct
export TransformedKernel, ScaledKernel
export TensorProduct

export Transform,
SelectTransform,
Expand All @@ -56,6 +55,9 @@ export ColVecs, RowVecs
export MOInput
export IndependentMOKernel, LatentFactorMOKernel

# Reexports
export tensor, ⊗

using Compat
using Requires
using Distances, LinearAlgebra
Expand All @@ -65,6 +67,7 @@ using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using StatsBase
using TensorCore

abstract type Kernel end
abstract type SimpleKernel <: Kernel end
Expand Down Expand Up @@ -104,7 +107,8 @@ include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("matrix", "kernelmatrix.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "tensorproduct.jl"))
include(joinpath("kernels", "kerneltensorproduct.jl"))
include(joinpath("kernels", "overloads.jl"))
include(joinpath("approximations", "nystrom.jl"))
include("generic.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/basekernels/sm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function spectral_mixture_product_kernel(
if !(size(αs) == size(γs) == size(ωs))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
end
return TensorProduct(
return KernelTensorProduct(
spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
(α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
)
Expand Down
5 changes: 5 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
)

@deprecate TensorProduct(kernels) KernelTensorProduct(kernels)
@deprecate TensorProduct(kernel::Kernel, kernels::Kernel...) KernelTensorProduct(
kernel, kernels...
)
25 changes: 0 additions & 25 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,6 @@ end

@functor KernelProduct

Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)

function Base.:*(
k1::KernelProduct{<:AbstractVector{<:Kernel}},
k2::KernelProduct{<:AbstractVector{<:Kernel}},
)
return KernelProduct(vcat(k1.kernels, k2.kernels))
end

function Base.:*(k1::KernelProduct, k2::KernelProduct)
return KernelProduct(k1.kernels..., k2.kernels...)
end

function Base.:*(k::Kernel, ks::KernelProduct{<:AbstractVector{<:Kernel}})
return KernelProduct(vcat(k, ks.kernels))
end

Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)

function Base.:*(ks::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
return KernelProduct(vcat(ks.kernels, k))
end

Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)

Base.length(k::KernelProduct) = length(k.kernels)

(κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
Expand Down
22 changes: 0 additions & 22 deletions src/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,6 @@ end

@functor KernelSum

Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)

function Base.:+(
k1::KernelSum{<:AbstractVector{<:Kernel}}, k2::KernelSum{<:AbstractVector{<:Kernel}}
)
return KernelSum(vcat(k1.kernels, k2.kernels))
end

Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)

function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
return KernelSum(vcat(k, ks.kernels))
end

Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)

function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
return KernelSum(vcat(ks.kernels, k))
end

Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)

Base.length(k::KernelSum) = length(k.kernels)

(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)
Expand Down
145 changes: 145 additions & 0 deletions src/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only changed TensorProduct to KernelTensorProduct and added a new docstring. I renamed the file in a separate PR but somehow, in contrast to the tests, Github does not display the changes nicely.

KernelTensorProduct <: Kernel

Tensor product of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
product of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor`
operator or its alias `⊗` (can be typed by `\\otimes<tab>`).
```jldoctest tensorproduct
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorProduct` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorProduct` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest tensorproduct
julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2
true

julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2
true

julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2
true
```
"""
struct KernelTensorProduct{K} <: Kernel
kernels::K
end

function KernelTensorProduct(kernel::Kernel, kernels::Kernel...)
return KernelTensorProduct((kernel, kernels...))
end

@functor KernelTensorProduct

Base.length(kernel::KernelTensorProduct) = length(kernel.kernels)

function (kernel::KernelTensorProduct)(x, y)
if !(length(x) == length(y) == length(kernel))
throw(DimensionMismatch("number of kernels and number of features
are not consistent"))
end
return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

function validate_domain(k::KernelTensorProduct, x::AbstractVector)
return dim(x) == length(k) ||
error("number of kernels and groups of features are not consistent")
end

# Utility for slicing up inputs.
slices(x::AbstractVector{<:Real}) = (x,)
slices(x::ColVecs) = eachrow(x.X)
slices(x::RowVecs) = eachcol(x.X)

function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi)
end

return K
end

function kernelmatrix!(
K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
validate_inplace_dims(K, x, y)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi, yi)
end

return K
end

function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kerneldiagmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kerneldiagmatrix(k, xi)
end

return K
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
end

function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
end

Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)

function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
return (
length(x.kernels) == length(y.kernels) &&
all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
)
end

function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
print(io, "Tensor product of ", length(kernel), " kernels:")
for k in kernel.kernels
print(io, "\n")
for _ in 1:(shift + 1)
print(io, "\t")
end
print(io, "- ")
printshifted(io, k, shift + 2)
end
end
Loading