-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BlockSparseArrays] Direct sum/
cat
(#1579)
* [BlockSparseArrays] Direct sum/`cat` * [NDTensors] Bump to v0.3.64
- Loading branch information
Showing
10 changed files
with
184 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "NDTensors" | ||
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" | ||
authors = ["Matthew Fishman <[email protected]>"] | ||
version = "0.3.63" | ||
version = "0.3.64" | ||
|
||
[deps] | ||
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# TODO: Change to `AnyAbstractBlockSparseArray`. | ||
function Base.cat(as::BlockSparseArrayLike...; dims) | ||
# TODO: Use `sparse_cat` instead, currently | ||
# that erroneously allocates too many blocks that are | ||
# zero and shouldn't be stored. | ||
return blocksparse_cat(as...; dims) | ||
end |
26 changes: 26 additions & 0 deletions
26
NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths | ||
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat! | ||
|
||
# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`. | ||
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`. | ||
function SparseArrayInterface.axis_cat( | ||
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange | ||
) | ||
return blockedrange(vcat(blocklengths(a1), blocklengths(a2))) | ||
end | ||
|
||
# that erroneously allocates too many blocks that are | ||
# zero and shouldn't be stored. | ||
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) | ||
sparse_cat!(blocks(a_dest), blocks.(as)...; dims) | ||
return a_dest | ||
end | ||
|
||
# TODO: Delete this in favor of `sparse_cat`, currently | ||
# that erroneously allocates too many blocks that are | ||
# zero and shouldn't be stored. | ||
function blocksparse_cat(as::AbstractArray...; dims) | ||
a_dest = allocate_cat_output(as...; dims) | ||
blocksparse_cat!(a_dest, as...; dims) | ||
return a_dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# TODO: Change to `AnyAbstractSparseArray`. | ||
function Base.cat(as::SparseArrayLike...; dims) | ||
return sparse_cat(as...; dims) | ||
end |
64 changes: 64 additions & 0 deletions
64
NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
unval(x) = x | ||
unval(::Val{x}) where {x} = x | ||
|
||
# TODO: Assert that `a1` and `a2` start at one. | ||
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) | ||
function axis_cat( | ||
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... | ||
) | ||
return axis_cat(axis_cat(a1, a2), a_rest...) | ||
end | ||
function cat_axes(as::AbstractArray...; dims) | ||
return ntuple(length(first(axes.(as)))) do dim | ||
return if dim in unval(dims) | ||
axis_cat(map(axes -> axes[dim], axes.(as))...) | ||
else | ||
axes(first(as))[dim] | ||
end | ||
end | ||
end | ||
|
||
function allocate_cat_output(as::AbstractArray...; dims) | ||
eltype_dest = promote_type(eltype.(as)...) | ||
axes_dest = cat_axes(as...; dims) | ||
# TODO: Promote the block types of the inputs rather than using | ||
# just the first input. | ||
# TODO: Make this customizable with `cat_similar`. | ||
# TODO: Base the zero element constructor on those of the inputs, | ||
# for example block sparse arrays. | ||
return similar(first(as), eltype_dest, axes_dest...) | ||
end | ||
|
||
# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857 | ||
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation | ||
# This is very similar to the `Base.cat` implementation but handles zero values better. | ||
function cat_offset!( | ||
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims | ||
) | ||
inds = ntuple(ndims(a_dest)) do dim | ||
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim) | ||
end | ||
a_dest[inds...] = a1 | ||
new_offsets = ntuple(ndims(a_dest)) do dim | ||
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim] | ||
end | ||
cat_offset!(a_dest, new_offsets, a_rest...; dims) | ||
return a_dest | ||
end | ||
function cat_offset!(a_dest::AbstractArray, offsets; dims) | ||
return a_dest | ||
end | ||
|
||
# TODO: Define a generic `cat!` function. | ||
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) | ||
offsets = ntuple(zero, ndims(a_dest)) | ||
# TODO: Fill `a_dest` with zeros if needed. | ||
cat_offset!(a_dest, offsets, as...; dims) | ||
return a_dest | ||
end | ||
|
||
function sparse_cat(as::AbstractArray...; dims) | ||
a_dest = allocate_cat_output(as...; dims) | ||
sparse_cat!(a_dest, as...; dims) | ||
return a_dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57994ff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register subdir=NDTensors
57994ff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/119367
Tip: Release Notes
Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.
To add them here just re-invoke and the PR will be updated.
Tagging
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: