Skip to content

Commit

Permalink
[NDTensors] Start TensorAlgebra module, new TTGT implementation (#1265
Browse files Browse the repository at this point in the history
)
  • Loading branch information
mtfishman authored Nov 17, 2023
1 parent 408516d commit 3f1afb8
Show file tree
Hide file tree
Showing 24 changed files with 393 additions and 2 deletions.
2 changes: 2 additions & 0 deletions NDTensors/src/BlockSparseArrays/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
2 changes: 2 additions & 0 deletions NDTensors/src/DiagonalArrays/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
8 changes: 6 additions & 2 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module NDTensors
# TODO: List types, macros, and functions being used.
using Adapt
using Base.Threads
using Compat
Expand All @@ -19,9 +20,12 @@ using TimerOutputs
using TupleTools

# TODO: Define an `AlgorithmSelection` module
# TODO: List types, macros, and functions being used.
include("algorithm.jl")
include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("TensorAlgebra/src/TensorAlgebra.jl")
using .TensorAlgebra: TensorAlgebra
include("DiagonalArrays/src/DiagonalArrays.jl")
using .DiagonalArrays
include("BlockSparseArrays/src/BlockSparseArrays.jl")
Expand Down Expand Up @@ -76,8 +80,8 @@ include("dims.jl")
include("tensor/set_types.jl")
include("tensor/similar.jl")
include("adapt.jl")
include("tensoralgebra/generic_tensor_operations.jl")
include("tensoralgebra/contraction_logic.jl")
include("tensoroperations/generic_tensor_operations.jl")
include("tensoroperations/contraction_logic.jl")
include("abstractarray/tensoralgebra/contract.jl")

#####################################
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/SetParameters/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
2 changes: 2 additions & 0 deletions NDTensors/src/SmallVectors/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
2 changes: 2 additions & 0 deletions NDTensors/src/SortedSets/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
2 changes: 2 additions & 0 deletions NDTensors/src/TagSets/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
2 changes: 2 additions & 0 deletions NDTensors/src/TensorAlgebra/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "blue"
indent = 2
11 changes: 11 additions & 0 deletions NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module TensorAlgebra
using LinearAlgebra: mul!
using ..NDTensors: Algorithm, @Algorithm_str

include("bipartitionedpermutation.jl")
include("fusedims.jl")
include("contract/contract.jl")
include("contract/output_labels.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")
end
17 changes: 17 additions & 0 deletions NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
struct BipartitionedPermutation{P1,P2}
partition1::P1
partition2::P2
end

function Base.getindex(biperm::BipartitionedPermutation, i)
if i == 1
return biperm.partition1
elseif i == 2
return biperm.partition2
end
return error("Only 2 partitions")
end

function flatten(biperm::BipartitionedPermutation)
return (biperm[1]..., biperm[2]...)
end
49 changes: 49 additions & 0 deletions NDTensors/src/TensorAlgebra/src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
function allocate_output(
::typeof(contract),
alg::Algorithm,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β,
)
axes_dest = output_axes(contract, alg, labels_dest, axes(a1), labels1, axes(a2), labels2)
# TODO: Define `output_type(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`.
# TODO: Define `output_structure(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`.
# TODO: Define `allocate(type, structure)`.
return Array{promote_type(eltype(a1), eltype(a2))}(undef, length.(axes_dest))
end

# TODO: Generalize to `output_structure`.
function output_axes(
f::typeof(contract), alg::Algorithm, labels_dest, axes1, labels1, axes2, labels2
)
biperm_dest, biperm1, biperm2 = bipartitioned_permutations(
f, labels_dest, labels1, labels2
)
return output_axes(f, alg, biperm_dest, axes1, biperm1, axes2, biperm2)
end

# TODO: Generalize to `output_structure`.
function output_axes(
f::typeof(contract),
alg::Algorithm,
biperm_dest::BipartitionedPermutation,
axes1,
biperm1::BipartitionedPermutation,
axes2,
biperm2::BipartitionedPermutation,
)
perm_dest = flatten(biperm_dest)
nuncontracted1 = length(biperm1[1])
axes_dest = map(perm_dest) do i
return if i <= nuncontracted1
axes1[biperm1[1][i]]
else
axes2[biperm2[2][i - nuncontracted1]]
end
end
return axes_dest
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
struct BipartitionedPermutation{P1,P2}
partition1::P1
partition2::P2
end

function Base.getindex(biperm::BipartitionedPermutation, i)
if i == 1
return biperm.partition1
elseif i == 2
return biperm.partition2
end
return error("Only 2 partitions")
end

function flatten(biperm::BipartitionedPermutation)
return (biperm[1]..., biperm[2]...)
end
153 changes: 153 additions & 0 deletions NDTensors/src/TensorAlgebra/src/contract/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# TODO: Add `contract!!` definitions as pass-throughs to `contract!`.

default_contract_alg() = Algorithm"matricize"()

function contract(a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...)
return contract(a1, labels1, a2, labels2, true, false; kwargs...)
end

function contract(
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β;
alg=default_contract_alg(),
kwargs...,
)
return contract(Algorithm(alg), a1, labels1, a2, labels2, α, β; kwargs...)
end

function contract(
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
)
return contract(alg, a1, labels1, a2, labels2, true, false; kwargs...)
end

function contract(
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β; kwargs...
)
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α, β; kwargs...)
return contract(alg, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...), labels_dest
end

function contract(
labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
)
return contract(
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
true,
false;
kwargs...,
)
end

function contract(
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β;
alg=default_contract_alg(),
kwargs...,
)
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
end

function contract(
alg::Algorithm,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2;
kwargs...,
)
return contract(alg, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
end

function contract(
alg::Algorithm,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β;
kwargs...,
)
biperm_dest, biperm1, biperm2 = bipartitioned_permutations(
contract, alg, labels_dest, labels1, labels2
)
a_dest = allocate_output(
contract, alg, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...
)
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
return a_dest
end

function contract!(
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2;
kwargs...,
)
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
return a_dest
end

function contract!(
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β;
alg=default_contract_alg(),
kwargs...,
)
contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
return a_dest
end

function contract!(
alg::Algorithm,
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2;
kwargs...,
)
contract!(alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
return a_dest
end

function contract!(
alg::Algorithm,
a_dest::AbstractArray,
labels_dest,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β;
kwargs...,
)
return error("Not implemented")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function contract!(
alg::Algorithm"matricize",
a_dest::AbstractArray,
biperm_dest::BipartitionedPermutation,
a1::AbstractArray,
biperm1::BipartitionedPermutation,
a2::AbstractArray,
biperm2::BipartitionedPermutation,
α,
β,
)
a_dest_matricized = matricize(a_dest, biperm_dest)
a1_matricized = matricize(a1, biperm1)
a2_matricized = matricize(a2, biperm2)
mul!(a_dest_matricized, a1_matricized, a2_matricized, α, β)
perm_dest = flatten(biperm_dest)
# TODO: Create a function `unmatricize` or `unfusedims`.
# unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest)
a_dest_copy = reshape(a_dest_matricized, axes(a_dest))
permutedims!(a_dest, a_dest_copy, perm_dest)
return a_dest
end
48 changes: 48 additions & 0 deletions NDTensors/src/TensorAlgebra/src/contract/output_labels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function output_labels(
f::typeof(contract),
alg::Algorithm,
a1::AbstractArray,
labels1,
a2::AbstractArray,
labels2,
α,
β,
)
return output_labels(f, alg, labels1, labels2)
end

function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2)
return output_labels(f, labels1, labels2)
end

function output_labels(::typeof(contract), labels1, labels2)
return symdiff(labels1, labels2)
end

function bipartitioned_permutations(
f::typeof(contract), alg::Algorithm, labels_dest, labels1, labels2
)
return bipartitioned_permutations(f, labels_dest, labels1, labels2)
end

function bipartitioned_permutations(::typeof(contract), labels_dest, labels1, labels2)
labels12 = (labels1..., labels2...)
if isodd(length(labels12) - length(labels_dest))
error("Can't contract $labels1 and $labels2 to $labels_dest")
end
labels_contracted = unique(setdiff(labels12, labels_dest))
labels1_uncontracted = setdiff(labels1, labels_contracted)
labels2_uncontracted = setdiff(labels2, labels_contracted)
# Positions of labels.
pos_dest_1 = map(l -> findfirst(isequal(l), labels_dest), labels1_uncontracted)
pos_dest_2 = map(l -> findfirst(isequal(l), labels_dest), labels2_uncontracted)
pos1_contracted = map(l -> findfirst(isequal(l), labels1), labels_contracted)
pos2_contracted = map(l -> findfirst(isequal(l), labels2), labels_contracted)
pos1_uncontracted = map(l -> findfirst(isequal(l), labels1), labels1_uncontracted)
pos2_uncontracted = map(l -> findfirst(isequal(l), labels2), labels2_uncontracted)
# Bipartitioned permutations.
biperm_dest = BipartitionedPermutation(pos_dest_1, pos_dest_2)
biperm1 = BipartitionedPermutation(pos1_uncontracted, pos1_contracted)
biperm2 = BipartitionedPermutation(pos2_contracted, pos2_uncontracted)
return biperm_dest, biperm1, biperm2
end
17 changes: 17 additions & 0 deletions NDTensors/src/TensorAlgebra/src/fusedims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
fuse(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) * length(a2))
fuse(a...) = foldl(fuse, a)

matricize(a::AbstractArray, biperm) = matricize(a, BipartitionedPermutation(biperm...))

function matricize(a::AbstractArray, biperm::BipartitionedPermutation)
# Permute and fuse the axes
axes_src = axes(a)
axes_codomain = map(i -> axes_src[i], biperm[1])
axes_domain = map(i -> axes_src[i], biperm[2])
axis_codomain_fused = fuse(axes_codomain...)
axis_domain_fused = fuse(axes_domain...)
# Permute the array
perm = flatten(biperm)
a_permuted = permutedims(a, perm)
return reshape(a_permuted, (axis_codomain_fused, axis_domain_fused))
end
3 changes: 3 additions & 0 deletions NDTensors/src/TensorAlgebra/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Loading

0 comments on commit 3f1afb8

Please sign in to comment.