Skip to content

Commit

Permalink
[TensorAlgebra] Matricized QR tensor decomposition (#1266)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 17, 2023
1 parent 3f1afb8 commit c47eb7c
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module LinearAlgebraExtensions
using LinearAlgebra: LinearAlgebra, qr
using ..TensorAlgebra:
TensorAlgebra,
BipartitionedPermutation,
bipartition,
bipartitioned_permutations,
matricize,
unmatricize

include("qr.jl")
end
21 changes: 21 additions & 0 deletions NDTensors/src/TensorAlgebra/src/LinearAlgebraExtensions/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
function LinearAlgebra.qr(a::AbstractArray, labels_a, labels_q, labels_r)
return qr(a, bipartitioned_permutations(qr, labels_a, labels_q, labels_r)...)
end

function LinearAlgebra.qr(a::AbstractArray, biperm::BipartitionedPermutation)
# TODO: Use a thin QR, define `qr_thin`.
a_matricized = matricize(a, biperm)
q_matricized, r_matricized = qr(a_matricized)
q_matricized_thin = typeof(a_matricized)(q_matricized)
axes_codomain, axes_domain = bipartition(axes(a), biperm)
q = unmatricize(q_matricized_thin, axes_codomain, (axes(q_matricized_thin, 2),))
r = unmatricize(r_matricized, (axes(r_matricized, 1),), axes_domain)
return q, r
end

function TensorAlgebra.bipartitioned_permutations(qr, labels_a, labels_q, labels_r)
# TODO: Use something like `findall`?
pos_q = map(l -> findfirst(isequal(l), labels_a), labels_q)
pos_r = map(l -> findfirst(isequal(l), labels_a), labels_r)
return (BipartitionedPermutation(pos_q, pos_r),)
end
1 change: 1 addition & 0 deletions NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ include("contract/contract.jl")
include("contract/output_labels.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")
include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl")
end
9 changes: 9 additions & 0 deletions NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ end
function flatten(biperm::BipartitionedPermutation)
return (biperm[1]..., biperm[2]...)
end

# Bipartition a vector according to the
# bipartitioned permutation.
function bipartition(v, biperm::BipartitionedPermutation)
# TODO: Use `TupleTools.getindices`.
v1 = map(i -> v[i], biperm[1])
v2 = map(i -> v[i], biperm[2])
return v1, v2
end
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function contract!(
perm_dest = flatten(biperm_dest)
# TODO: Create a function `unmatricize` or `unfusedims`.
# unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest)
a_dest_copy = reshape(a_dest_matricized, axes(a_dest))
permutedims!(a_dest, a_dest_copy, perm_dest)
a_dest_copy = reshape(a_dest_matricized, map(i -> axes(a_dest, i), perm_dest))
permutedims!(a_dest, a_dest_copy, invperm(perm_dest))
return a_dest
end
6 changes: 6 additions & 0 deletions NDTensors/src/TensorAlgebra/src/fusedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ fuse(a...) = foldl(fuse, a)

matricize(a::AbstractArray, biperm) = matricize(a, BipartitionedPermutation(biperm...))

# TODO: Make this more generic, i.e. for `BlockSparseArray`.
function matricize(a::AbstractArray, biperm::BipartitionedPermutation)
# Permute and fuse the axes
axes_src = axes(a)
Expand All @@ -15,3 +16,8 @@ function matricize(a::AbstractArray, biperm::BipartitionedPermutation)
a_permuted = permutedims(a, perm)
return reshape(a_permuted, (axis_codomain_fused, axis_domain_fused))
end

# TODO: Make this more generic, i.e. for `BlockSparseArray`.
function unmatricize(a::AbstractArray, axes_codomain, axes_domain)
return reshape(a, (axes_codomain..., axes_domain...))
end
62 changes: 43 additions & 19 deletions NDTensors/src/TensorAlgebra/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,51 @@
using Combinatorics: permutations
using LinearAlgebra: qr
using NDTensors.TensorAlgebra: TensorAlgebra
using TensorOperations: TensorOperations
using Test: @test, @testset
using Test: @test, @test_broken, @testset

@testset "TensorAlgebra" begin
dims = (2, 3, 4, 5)
labels = (:a, :b, :c, :d)
for (d1s, d2s) in (((1, 2), (2, 3)), ((1, 2, 3), (2, 3, 4)), ((1, 2, 3), (3, 4)))
a1 = randn(map(i -> dims[i], d1s))
labels1 = map(i -> labels[i], d1s)
a2 = randn(map(i -> dims[i], d2s))
labels2 = map(i -> labels[i], d2s)
for perm1 in permutations(1:ndims(a1)), perm2 in permutations(1:ndims(a2))
a1′ = permutedims(a1, perm1)
a2′ = permutedims(a2, perm2)
labels1′ = map(i -> labels1[i], perm1)
labels2′ = map(i -> labels2[i], perm2)
a_dest, labels_dest = TensorAlgebra.contract(a1′, labels1′, a2′, labels2′)
@test labels_dest == symdiff(labels1′, labels2′)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1′, labels1′, a2′, labels2′
)
@test a_dest a_dest_tensoroperations
elts = (Float32, ComplexF32, Float64, ComplexF64)
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
dims = (2, 3, 4, 5)
labels = (:a, :b, :c, :d)
for (d1s, d2s) in (((1, 2), (2, 3)), ((1, 2, 3), (2, 3, 4)), ((1, 2, 3), (3, 4)))
a1 = randn(elt1, map(i -> dims[i], d1s))
labels1 = map(i -> labels[i], d1s)
a2 = randn(elt2, map(i -> dims[i], d2s))
labels2 = map(i -> labels[i], d2s)
for perm1 in permutations(1:ndims(a1)), perm2 in permutations(1:ndims(a2))
a1′ = permutedims(a1, perm1)
a2′ = permutedims(a2, perm2)
labels1′ = map(i -> labels1[i], perm1)
labels2′ = map(i -> labels2[i], perm2)
a_dest, labels_dest = TensorAlgebra.contract(a1′, labels1′, a2′, labels2′)
@test labels_dest == symdiff(labels1′, labels2′)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1′, labels1′, a2′, labels2′
)
@test a_dest a_dest_tensoroperations
end
end
end
@testset "contract broken" begin
a1 = randn(3, 5, 8)
a2 = randn(8, 2, 4)
labels_dest = (:a, :b, :c, :d)
labels1 = (:c, :a, :x)
labels2 = (:x, :d, :b)
@test_broken a′ = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
end
@testset "qr" begin
a = randn(5, 4, 3, 2)
labels_a = (:a, :b, :c, :d)
labels_q = (:b, :a)
labels_r = (:d, :c)
q, r = qr(a, labels_a, labels_q, labels_r)
label_qr = :qr
a′ = TensorAlgebra.contract(
labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)
)
@test a a′
end
end

0 comments on commit c47eb7c

Please sign in to comment.