Skip to content
This repository has been archived by the owner on Jan 20, 2025. It is now read-only.

Commit

Permalink
Add support for expanding singleton dimensions (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jan 5, 2025
1 parent 3e63462 commit 7eb11cf
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 34 deletions.
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
name = "BroadcastMapConversion"
uuid = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.2"
version = "0.1.3"

[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

[compat]
BlockArrays = "1.3.0"
Compat = "4.16.0"
FillArrays = "1.13.0"
julia = "1.10"
151 changes: 120 additions & 31 deletions src/BroadcastMapConversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@ module BroadcastMapConversion

using Base.Broadcast:
Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate
using Compat: allequal

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}

# Get the arguments of the map expression that
# is equivalent to the broadcast expression.
function map_args(bc::Broadcasted, rest...)
return (map_args(bc.args...)..., map_args(rest...)...)
function map_args(bc::Broadcasted)
return map_args_flatten(bc)
end
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...)
map_args(a, rest...) = map_args(rest...)
map_args() = ()

function map_args_flatten(bc::Broadcasted, args_rest...)
return (map_args_flatten(bc.args...)..., map_args_flatten(args_rest...)...)
end
function map_args_flatten(arg1::AbstractArray, args_rest...)
return (arg1, map_args_flatten(args_rest...)...)
end
map_args_flatten(arg1, args_rest...) = map_args_flatten(args_rest...)
map_args_flatten() = ()

struct MapFunction{F,Args<:Tuple} <: Function
f::F
Expand All @@ -27,54 +34,136 @@ struct Arg end
# is equivalent to the broadcast expression.
# Returns a `MapFunction`.
function map_function(bc::Broadcasted)
args = map_function_tuple(bc.args)
return MapFunction(bc.f, args)
return map_function_arg(bc)
end
map_function_args(args::Tuple{}) = args
function map_function_args(args::Tuple)
return (map_function_arg(args[1]), map_function_args(Base.tail(args))...)
end
map_function_tuple(t::Tuple{}) = t
map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...)
map_function(a::WrappedScalarArgs) = a[]
map_function(a::AbstractArray) = Arg()
map_function(a) = a
function map_function_arg(bc::Broadcasted)
return MapFunction(bc.f, map_function_args(bc.args))
end
map_function_arg(a::WrappedScalarArgs) = a[]
map_function_arg(a::AbstractArray) = Arg()
map_function_arg(a) = a

# Evaluate MapFunction
(f::MapFunction)(args...) = apply(f, args)[1]
function apply(f::MapFunction, args)
args, newargs = apply_tuple(f.args, args)
return f.f(args...), newargs
(f::MapFunction)(args...) = apply_arg(f, args)[1]
function apply_arg(f::MapFunction, args)
mapfunction_args, args′ = apply_args(f.args, args)
return f.f(mapfunction_args...), args′
end
apply(a::Arg, args::Tuple) = args[1], Base.tail(args)
apply(a, args) = a, args
apply_tuple(t::Tuple{}, args) = t, args
function apply_tuple(t::Tuple, args)
t1, newargs1 = apply(t[1], args)
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
return (t1, ttail...), newargs
apply_arg(mapfunction_arg::Arg, args) = args[1], Base.tail(args)
apply_arg(mapfunction_arg, args) = mapfunction_arg, args
function apply_args(mapfunction_args::Tuple, args)
mapfunction_args1, args′ = apply_arg(mapfunction_args[1], args)
mapfunction_args_rest, args′′ = apply_args(Base.tail(mapfunction_args), args′)
return (mapfunction_args1, mapfunction_args_rest...), args′′
end
apply_args(mapfunction_args::Tuple{}, args) = mapfunction_args, args

is_map_expr_or_arg(arg::AbstractArray) = true
is_map_expr_or_arg(arg::Any) = false
function is_map_expr_or_arg(bc::Broadcasted)
return all(is_map_expr_or_arg, bc.args)
end
function is_map_expr(bc::Broadcasted)
return is_map_expr_or_arg(bc)
end

abstract type ExprStyle end
struct MapExpr <: ExprStyle end
struct NotMapExpr <: ExprStyle end

ExprStyle(bc::Broadcasted) = is_map_expr(bc) ? MapExpr() : NotMapExpr()

abstract type AbstractMapped <: Base.AbstractBroadcasted end

function check_shape(::Type{Bool}, args...)
return allequal(axes, args)
end
function check_shape(args...)
if !check_shape(Bool, args...)
throw(DimensionMismatch("Mismatched shapes $(axes.(args))."))
end
return nothing
end

# Promote the shape of the arguments to support broadcasting
# over dimensions by expanding singleton dimensions.
function promote_shape(ax, args::AbstractArray...)
if allequal((ax, axes.(args)...))
return args
end
return promote_shape_tile(ax, args...)
end
function promote_shape_tile(common_axes, args::AbstractArray...)
return map(arg -> tile_to_shape(arg, common_axes), args)
end

using BlockArrays: mortar
using FillArrays: Fill

# Extend by repeating value up to length.
function extend(t::Tuple, value, length)
return ntuple(i -> get(t, i, value), length)
end

# Handles logic of expanding singleton dimensions
# to match an array shape in broadcasting.
function tile_to_shape(a::AbstractArray, ax)
axes(a) == ax && return a
# Must be one-based for now.
@assert all(isone, first.(ax))
@assert all(isone, first.(axes(a)))
ndim = length(ax)
size′ = extend(size(a), 1, ndim)
a′ = reshape(a, size′)
target_size = length.(ax)
fillsize = ntuple(ndim) do dim
size′[dim] == target_size[dim] && return 1
isone(size′[dim]) && return target_size[dim]
return throw(DimensionMismatch("Dimensions $(axes(a)) and $ax don't match."))
end
return mortar(Fill(a′, fillsize))
end

struct Mapped{Style<:Union{Nothing,BroadcastStyle},Axes,F,Args<:Tuple} <: AbstractMapped
style::Style
f::F
args::Args
axes::Axes
function Mapped(style, f, args, axes)
check_shape(args...)
return new{typeof(style),typeof(axes),typeof(f),typeof(args)}(style, f, args, axes)
end
end

function Mapped(bc::Broadcasted)
return Mapped(bc.style, map_function(bc), map_args(bc), bc.axes)
return Mapped(ExprStyle(bc), bc)
end
function Mapped(::NotMapExpr, bc::Broadcasted)
f = map_function(bc)
ax = axes(bc)
args = promote_shape(ax, map_args(bc)...)
return Mapped(bc.style, f, args, ax)
end
function Mapped(::MapExpr, bc::Broadcasted)
f = bc.f
ax = axes(bc)
args = promote_shape(ax, bc.args...)
return Mapped(bc.style, f, args, ax)
end

function Broadcast.Broadcasted(m::Mapped)
return Broadcasted(m.style, m.f, m.args, m.axes)
end

# Convert `Broadcasted` to `Mapped` when `Broadcasted`
# is known to already be a map expression.
function map_broadcast_to_mapped(bc::Broadcasted)
return Mapped(bc.style, bc.f, bc.args, bc.axes)
function mapped(f, args...)
check_shape(args...)
return Mapped(broadcasted(f, args...))
end

mapped(f, args...) = Mapped(broadcasted(f, args...))

Base.similar(m::Mapped, elt::Type) = similar(Broadcasted(m), elt)
Base.similar(m::Mapped, elt::Type, ax::Tuple) = similar(Broadcasted(m), elt, ax)
Base.axes(m::Mapped) = axes(Broadcasted(m))
Expand All @@ -88,6 +177,6 @@ function Base.copy(m::Mapped)
return copyto!(similar(m, elt), m)
end
Base.copyto!(dest::AbstractArray, m::Mapped) = map!(m.f, dest, m.args...)
Broadcast.instantiate(m::Mapped) = map_broadcast_to_mapped(instantiate(Broadcasted(m)))
Broadcast.instantiate(m::Mapped) = Mapped(instantiate(Broadcasted(m)))

end
26 changes: 24 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Base.Broadcast: broadcasted
using BroadcastMapConversion: Mapped, mapped
using Test: @test, @testset
using BroadcastMapConversion: Mapped, is_map_expr, mapped
using Test: @inferred, @test, @test_throws, @testset

@testset "BroadcastMapConversion (eltype=$elt)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
Expand All @@ -22,4 +22,26 @@ using Test: @test, @testset
@test copyto!(similar(m, elt), m) ref
@test copyto!(similar(m′, elt), m) ref
end

@test @inferred is_map_expr(
Broadcast.broadcasted(+, [2], Broadcast.broadcasted(sin, [2]))
)
@test @inferred !is_map_expr(Broadcast.broadcasted(+, 2, Broadcast.broadcasted(sin, [2])))

# Logic handling singleton dimensions in broadcasting.
for (a, b) in (
(randn(elt, 2, 2), randn(elt, 2)),
(randn(elt, 2, 2), randn(elt, 1, 2)),
(randn(elt, 2, 1), randn(elt, 1, 2)),
(randn(elt, 2, 2, 2), randn(elt, 2)),
(randn(elt, 2, 2, 2), randn(elt, 1, 2)),
(randn(elt, 2, 2, 2), randn(elt, 1, 1, 2)),
(randn(elt, 2, 2, 2), randn(elt, 2, 2)),
(randn(elt, 2, 2, 2), randn(elt, 1, 2, 2)),
)
@test_throws DimensionMismatch mapped(+, a, b)
bc = broadcasted(+, a, b)
m = Mapped(bc)
@test copy(m) == copy(bc)
end
end

0 comments on commit 7eb11cf

Please sign in to comment.