-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors] Start
TensorAlgebra
module, new TTGT implementation (#1265
- Loading branch information
Showing
24 changed files
with
393 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
module TensorAlgebra | ||
using LinearAlgebra: mul! | ||
using ..NDTensors: Algorithm, @Algorithm_str | ||
|
||
include("bipartitionedpermutation.jl") | ||
include("fusedims.jl") | ||
include("contract/contract.jl") | ||
include("contract/output_labels.jl") | ||
include("contract/allocate_output.jl") | ||
include("contract/contract_matricize/contract.jl") | ||
end |
17 changes: 17 additions & 0 deletions
17
NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
struct BipartitionedPermutation{P1,P2} | ||
partition1::P1 | ||
partition2::P2 | ||
end | ||
|
||
function Base.getindex(biperm::BipartitionedPermutation, i) | ||
if i == 1 | ||
return biperm.partition1 | ||
elseif i == 2 | ||
return biperm.partition2 | ||
end | ||
return error("Only 2 partitions") | ||
end | ||
|
||
function flatten(biperm::BipartitionedPermutation) | ||
return (biperm[1]..., biperm[2]...) | ||
end |
49 changes: 49 additions & 0 deletions
49
NDTensors/src/TensorAlgebra/src/contract/allocate_output.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
function allocate_output( | ||
::typeof(contract), | ||
alg::Algorithm, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β, | ||
) | ||
axes_dest = output_axes(contract, alg, labels_dest, axes(a1), labels1, axes(a2), labels2) | ||
# TODO: Define `output_type(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`. | ||
# TODO: Define `output_structure(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`. | ||
# TODO: Define `allocate(type, structure)`. | ||
return Array{promote_type(eltype(a1), eltype(a2))}(undef, length.(axes_dest)) | ||
end | ||
|
||
# TODO: Generalize to `output_structure`. | ||
function output_axes( | ||
f::typeof(contract), alg::Algorithm, labels_dest, axes1, labels1, axes2, labels2 | ||
) | ||
biperm_dest, biperm1, biperm2 = bipartitioned_permutations( | ||
f, labels_dest, labels1, labels2 | ||
) | ||
return output_axes(f, alg, biperm_dest, axes1, biperm1, axes2, biperm2) | ||
end | ||
|
||
# TODO: Generalize to `output_structure`. | ||
function output_axes( | ||
f::typeof(contract), | ||
alg::Algorithm, | ||
biperm_dest::BipartitionedPermutation, | ||
axes1, | ||
biperm1::BipartitionedPermutation, | ||
axes2, | ||
biperm2::BipartitionedPermutation, | ||
) | ||
perm_dest = flatten(biperm_dest) | ||
nuncontracted1 = length(biperm1[1]) | ||
axes_dest = map(perm_dest) do i | ||
return if i <= nuncontracted1 | ||
axes1[biperm1[1][i]] | ||
else | ||
axes2[biperm2[2][i - nuncontracted1]] | ||
end | ||
end | ||
return axes_dest | ||
end |
17 changes: 17 additions & 0 deletions
17
NDTensors/src/TensorAlgebra/src/contract/bipartitionedpermutations.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
struct BipartitionedPermutation{P1,P2} | ||
partition1::P1 | ||
partition2::P2 | ||
end | ||
|
||
function Base.getindex(biperm::BipartitionedPermutation, i) | ||
if i == 1 | ||
return biperm.partition1 | ||
elseif i == 2 | ||
return biperm.partition2 | ||
end | ||
return error("Only 2 partitions") | ||
end | ||
|
||
function flatten(biperm::BipartitionedPermutation) | ||
return (biperm[1]..., biperm[2]...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# TODO: Add `contract!!` definitions as pass-throughs to `contract!`. | ||
|
||
default_contract_alg() = Algorithm"matricize"() | ||
|
||
function contract(a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...) | ||
return contract(a1, labels1, a2, labels2, true, false; kwargs...) | ||
end | ||
|
||
function contract( | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β; | ||
alg=default_contract_alg(), | ||
kwargs..., | ||
) | ||
return contract(Algorithm(alg), a1, labels1, a2, labels2, α, β; kwargs...) | ||
end | ||
|
||
function contract( | ||
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... | ||
) | ||
return contract(alg, a1, labels1, a2, labels2, true, false; kwargs...) | ||
end | ||
|
||
function contract( | ||
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β; kwargs... | ||
) | ||
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α, β; kwargs...) | ||
return contract(alg, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...), labels_dest | ||
end | ||
|
||
function contract( | ||
labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... | ||
) | ||
return contract( | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
true, | ||
false; | ||
kwargs..., | ||
) | ||
end | ||
|
||
function contract( | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β; | ||
alg=default_contract_alg(), | ||
kwargs..., | ||
) | ||
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α, β; kwargs...) | ||
end | ||
|
||
function contract( | ||
alg::Algorithm, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2; | ||
kwargs..., | ||
) | ||
return contract(alg, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) | ||
end | ||
|
||
function contract( | ||
alg::Algorithm, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β; | ||
kwargs..., | ||
) | ||
biperm_dest, biperm1, biperm2 = bipartitioned_permutations( | ||
contract, alg, labels_dest, labels1, labels2 | ||
) | ||
a_dest = allocate_output( | ||
contract, alg, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs... | ||
) | ||
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...) | ||
return a_dest | ||
end | ||
|
||
function contract!( | ||
a_dest::AbstractArray, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2; | ||
kwargs..., | ||
) | ||
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) | ||
return a_dest | ||
end | ||
|
||
function contract!( | ||
a_dest::AbstractArray, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β; | ||
alg=default_contract_alg(), | ||
kwargs..., | ||
) | ||
contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...) | ||
return a_dest | ||
end | ||
|
||
function contract!( | ||
alg::Algorithm, | ||
a_dest::AbstractArray, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2; | ||
kwargs..., | ||
) | ||
contract!(alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) | ||
return a_dest | ||
end | ||
|
||
function contract!( | ||
alg::Algorithm, | ||
a_dest::AbstractArray, | ||
labels_dest, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β; | ||
kwargs..., | ||
) | ||
return error("Not implemented") | ||
end |
22 changes: 22 additions & 0 deletions
22
NDTensors/src/TensorAlgebra/src/contract/contract_matricize/contract.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
function contract!( | ||
alg::Algorithm"matricize", | ||
a_dest::AbstractArray, | ||
biperm_dest::BipartitionedPermutation, | ||
a1::AbstractArray, | ||
biperm1::BipartitionedPermutation, | ||
a2::AbstractArray, | ||
biperm2::BipartitionedPermutation, | ||
α, | ||
β, | ||
) | ||
a_dest_matricized = matricize(a_dest, biperm_dest) | ||
a1_matricized = matricize(a1, biperm1) | ||
a2_matricized = matricize(a2, biperm2) | ||
mul!(a_dest_matricized, a1_matricized, a2_matricized, α, β) | ||
perm_dest = flatten(biperm_dest) | ||
# TODO: Create a function `unmatricize` or `unfusedims`. | ||
# unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest) | ||
a_dest_copy = reshape(a_dest_matricized, axes(a_dest)) | ||
permutedims!(a_dest, a_dest_copy, perm_dest) | ||
return a_dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
function output_labels( | ||
f::typeof(contract), | ||
alg::Algorithm, | ||
a1::AbstractArray, | ||
labels1, | ||
a2::AbstractArray, | ||
labels2, | ||
α, | ||
β, | ||
) | ||
return output_labels(f, alg, labels1, labels2) | ||
end | ||
|
||
function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2) | ||
return output_labels(f, labels1, labels2) | ||
end | ||
|
||
function output_labels(::typeof(contract), labels1, labels2) | ||
return symdiff(labels1, labels2) | ||
end | ||
|
||
function bipartitioned_permutations( | ||
f::typeof(contract), alg::Algorithm, labels_dest, labels1, labels2 | ||
) | ||
return bipartitioned_permutations(f, labels_dest, labels1, labels2) | ||
end | ||
|
||
function bipartitioned_permutations(::typeof(contract), labels_dest, labels1, labels2) | ||
labels12 = (labels1..., labels2...) | ||
if isodd(length(labels12) - length(labels_dest)) | ||
error("Can't contract $labels1 and $labels2 to $labels_dest") | ||
end | ||
labels_contracted = unique(setdiff(labels12, labels_dest)) | ||
labels1_uncontracted = setdiff(labels1, labels_contracted) | ||
labels2_uncontracted = setdiff(labels2, labels_contracted) | ||
# Positions of labels. | ||
pos_dest_1 = map(l -> findfirst(isequal(l), labels_dest), labels1_uncontracted) | ||
pos_dest_2 = map(l -> findfirst(isequal(l), labels_dest), labels2_uncontracted) | ||
pos1_contracted = map(l -> findfirst(isequal(l), labels1), labels_contracted) | ||
pos2_contracted = map(l -> findfirst(isequal(l), labels2), labels_contracted) | ||
pos1_uncontracted = map(l -> findfirst(isequal(l), labels1), labels1_uncontracted) | ||
pos2_uncontracted = map(l -> findfirst(isequal(l), labels2), labels2_uncontracted) | ||
# Bipartitioned permutations. | ||
biperm_dest = BipartitionedPermutation(pos_dest_1, pos_dest_2) | ||
biperm1 = BipartitionedPermutation(pos1_uncontracted, pos1_contracted) | ||
biperm2 = BipartitionedPermutation(pos2_contracted, pos2_uncontracted) | ||
return biperm_dest, biperm1, biperm2 | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
fuse(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) * length(a2)) | ||
fuse(a...) = foldl(fuse, a) | ||
|
||
matricize(a::AbstractArray, biperm) = matricize(a, BipartitionedPermutation(biperm...)) | ||
|
||
function matricize(a::AbstractArray, biperm::BipartitionedPermutation) | ||
# Permute and fuse the axes | ||
axes_src = axes(a) | ||
axes_codomain = map(i -> axes_src[i], biperm[1]) | ||
axes_domain = map(i -> axes_src[i], biperm[2]) | ||
axis_codomain_fused = fuse(axes_codomain...) | ||
axis_domain_fused = fuse(axes_domain...) | ||
# Permute the array | ||
perm = flatten(biperm) | ||
a_permuted = permutedims(a, perm) | ||
return reshape(a_permuted, (axis_codomain_fused, axis_domain_fused)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[deps] | ||
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" | ||
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" |
Oops, something went wrong.