diff --git a/NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl b/NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl new file mode 100644 index 0000000000..e63cc8f4f8 --- /dev/null +++ b/NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl @@ -0,0 +1,12 @@ +module LinearAlgebraExtensions +using LinearAlgebra: LinearAlgebra, qr +using ..TensorAlgebra: + TensorAlgebra, + BipartitionedPermutation, + bipartition, + bipartitioned_permutations, + matricize, + unmatricize + +include("qr.jl") +end diff --git a/NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/qr.jl b/NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/qr.jl new file mode 100644 index 0000000000..c058041dc3 --- /dev/null +++ b/NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/qr.jl @@ -0,0 +1,21 @@ +function LinearAlgebra.qr(a::AbstractArray, labels_a, labels_q, labels_r) + return qr(a, bipartitioned_permutations(qr, labels_a, labels_q, labels_r)...) +end + +function LinearAlgebra.qr(a::AbstractArray, biperm::BipartitionedPermutation) + # TODO: Use a thin QR, define `qr_thin`. + a_matricized = matricize(a, biperm) + q_matricized, r_matricized = qr(a_matricized) + q_matricized_thin = typeof(a_matricized)(q_matricized) + axes_codomain, axes_domain = bipartition(axes(a), biperm) + q = unmatricize(q_matricized_thin, axes_codomain, (axes(q_matricized_thin, 2),)) + r = unmatricize(r_matricized, (axes(r_matricized, 1),), axes_domain) + return q, r +end + +function TensorAlgebra.bipartitioned_permutations(qr, labels_a, labels_q, labels_r) + # TODO: Use something like `findall`? + pos_q = map(l -> findfirst(isequal(l), labels_a), labels_q) + pos_r = map(l -> findfirst(isequal(l), labels_a), labels_r) + return (BipartitionedPermutation(pos_q, pos_r),) +end diff --git a/NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl b/NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl index 7fdc62fcac..dcd053f944 100644 --- a/NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl +++ b/NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl @@ -8,4 +8,5 @@ include("contract/contract.jl") include("contract/output_labels.jl") include("contract/allocate_output.jl") include("contract/contract_matricize/contract.jl") +include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl") end diff --git a/NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl b/NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl index 19f56d1919..88b65e888f 100644 --- a/NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl +++ b/NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl @@ -15,3 +15,12 @@ end function flatten(biperm::BipartitionedPermutation) return (biperm[1]..., biperm[2]...) end + +# Bipartition a vector according to the +# bipartitioned permutation. +function bipartition(v, biperm::BipartitionedPermutation) + # TODO: Use `TupleTools.getindices`. + v1 = map(i -> v[i], biperm[1]) + v2 = map(i -> v[i], biperm[2]) + return v1, v2 +end diff --git a/NDTensors/src/TensorAlgebra/src/contract/contract_matricize/contract.jl b/NDTensors/src/TensorAlgebra/src/contract/contract_matricize/contract.jl index 9b212f7485..959e09598e 100644 --- a/NDTensors/src/TensorAlgebra/src/contract/contract_matricize/contract.jl +++ b/NDTensors/src/TensorAlgebra/src/contract/contract_matricize/contract.jl @@ -16,7 +16,7 @@ function contract!( perm_dest = flatten(biperm_dest) # TODO: Create a function `unmatricize` or `unfusedims`. # unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest) - a_dest_copy = reshape(a_dest_matricized, axes(a_dest)) - permutedims!(a_dest, a_dest_copy, perm_dest) + a_dest_copy = reshape(a_dest_matricized, map(i -> axes(a_dest, i), perm_dest)) + permutedims!(a_dest, a_dest_copy, invperm(perm_dest)) return a_dest end diff --git a/NDTensors/src/TensorAlgebra/src/fusedims.jl b/NDTensors/src/TensorAlgebra/src/fusedims.jl index 47d0dafbf9..39220dd687 100644 --- a/NDTensors/src/TensorAlgebra/src/fusedims.jl +++ b/NDTensors/src/TensorAlgebra/src/fusedims.jl @@ -3,6 +3,7 @@ fuse(a...) = foldl(fuse, a) matricize(a::AbstractArray, biperm) = matricize(a, BipartitionedPermutation(biperm...)) +# TODO: Make this more generic, i.e. for `BlockSparseArray`. function matricize(a::AbstractArray, biperm::BipartitionedPermutation) # Permute and fuse the axes axes_src = axes(a) @@ -15,3 +16,8 @@ function matricize(a::AbstractArray, biperm::BipartitionedPermutation) a_permuted = permutedims(a, perm) return reshape(a_permuted, (axis_codomain_fused, axis_domain_fused)) end + +# TODO: Make this more generic, i.e. for `BlockSparseArray`. +function unmatricize(a::AbstractArray, axes_codomain, axes_domain) + return reshape(a, (axes_codomain..., axes_domain...)) +end diff --git a/NDTensors/src/TensorAlgebra/test/runtests.jl b/NDTensors/src/TensorAlgebra/test/runtests.jl index a21a108ee9..d5182591a8 100644 --- a/NDTensors/src/TensorAlgebra/test/runtests.jl +++ b/NDTensors/src/TensorAlgebra/test/runtests.jl @@ -1,27 +1,51 @@ using Combinatorics: permutations +using LinearAlgebra: qr using NDTensors.TensorAlgebra: TensorAlgebra using TensorOperations: TensorOperations -using Test: @test, @testset +using Test: @test, @test_broken, @testset @testset "TensorAlgebra" begin - dims = (2, 3, 4, 5) - labels = (:a, :b, :c, :d) - for (d1s, d2s) in (((1, 2), (2, 3)), ((1, 2, 3), (2, 3, 4)), ((1, 2, 3), (3, 4))) - a1 = randn(map(i -> dims[i], d1s)) - labels1 = map(i -> labels[i], d1s) - a2 = randn(map(i -> dims[i], d2s)) - labels2 = map(i -> labels[i], d2s) - for perm1 in permutations(1:ndims(a1)), perm2 in permutations(1:ndims(a2)) - a1′ = permutedims(a1, perm1) - a2′ = permutedims(a2, perm2) - labels1′ = map(i -> labels1[i], perm1) - labels2′ = map(i -> labels2[i], perm2) - a_dest, labels_dest = TensorAlgebra.contract(a1′, labels1′, a2′, labels2′) - @test labels_dest == symdiff(labels1′, labels2′) - a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest, a1′, labels1′, a2′, labels2′ - ) - @test a_dest ≈ a_dest_tensoroperations + elts = (Float32, ComplexF32, Float64, ComplexF64) + @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts + dims = (2, 3, 4, 5) + labels = (:a, :b, :c, :d) + for (d1s, d2s) in (((1, 2), (2, 3)), ((1, 2, 3), (2, 3, 4)), ((1, 2, 3), (3, 4))) + a1 = randn(elt1, map(i -> dims[i], d1s)) + labels1 = map(i -> labels[i], d1s) + a2 = randn(elt2, map(i -> dims[i], d2s)) + labels2 = map(i -> labels[i], d2s) + for perm1 in permutations(1:ndims(a1)), perm2 in permutations(1:ndims(a2)) + a1′ = permutedims(a1, perm1) + a2′ = permutedims(a2, perm2) + labels1′ = map(i -> labels1[i], perm1) + labels2′ = map(i -> labels2[i], perm2) + a_dest, labels_dest = TensorAlgebra.contract(a1′, labels1′, a2′, labels2′) + @test labels_dest == symdiff(labels1′, labels2′) + a_dest_tensoroperations = TensorOperations.tensorcontract( + labels_dest, a1′, labels1′, a2′, labels2′ + ) + @test a_dest ≈ a_dest_tensoroperations + end end end + @testset "contract broken" begin + a1 = randn(3, 5, 8) + a2 = randn(8, 2, 4) + labels_dest = (:a, :b, :c, :d) + labels1 = (:c, :a, :x) + labels2 = (:x, :d, :b) + @test_broken a′ = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2) + end + @testset "qr" begin + a = randn(5, 4, 3, 2) + labels_a = (:a, :b, :c, :d) + labels_q = (:b, :a) + labels_r = (:d, :c) + q, r = qr(a, labels_a, labels_q, labels_r) + label_qr = :qr + a′ = TensorAlgebra.contract( + labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...) + ) + @test a ≈ a′ + end end