diff --git a/Project.toml b/Project.toml index 6f15963..c9ffe6e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,15 @@ name = "BroadcastMapConversion" uuid = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" authors = ["ITensor developers and contributors"] -version = "0.1.2" +version = "0.1.3" + +[deps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" [compat] +BlockArrays = "1.3.0" +Compat = "4.16.0" +FillArrays = "1.13.0" julia = "1.10" diff --git a/src/BroadcastMapConversion.jl b/src/BroadcastMapConversion.jl index 0b6bb5e..d59ae21 100644 --- a/src/BroadcastMapConversion.jl +++ b/src/BroadcastMapConversion.jl @@ -5,17 +5,24 @@ module BroadcastMapConversion using Base.Broadcast: Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate +using Compat: allequal const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}} # Get the arguments of the map expression that # is equivalent to the broadcast expression. -function map_args(bc::Broadcasted, rest...) - return (map_args(bc.args...)..., map_args(rest...)...) +function map_args(bc::Broadcasted) + return map_args_flatten(bc) end -map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...) -map_args(a, rest...) = map_args(rest...) -map_args() = () + +function map_args_flatten(bc::Broadcasted, args_rest...) + return (map_args_flatten(bc.args...)..., map_args_flatten(args_rest...)...) +end +function map_args_flatten(arg1::AbstractArray, args_rest...) + return (arg1, map_args_flatten(args_rest...)...) +end +map_args_flatten(arg1, args_rest...) = map_args_flatten(args_rest...) +map_args_flatten() = () struct MapFunction{F,Args<:Tuple} <: Function f::F @@ -27,54 +34,136 @@ struct Arg end # is equivalent to the broadcast expression. # Returns a `MapFunction`. function map_function(bc::Broadcasted) - args = map_function_tuple(bc.args) - return MapFunction(bc.f, args) + return map_function_arg(bc) +end +map_function_args(args::Tuple{}) = args +function map_function_args(args::Tuple) + return (map_function_arg(args[1]), map_function_args(Base.tail(args))...) end -map_function_tuple(t::Tuple{}) = t -map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...) -map_function(a::WrappedScalarArgs) = a[] -map_function(a::AbstractArray) = Arg() -map_function(a) = a +function map_function_arg(bc::Broadcasted) + return MapFunction(bc.f, map_function_args(bc.args)) +end +map_function_arg(a::WrappedScalarArgs) = a[] +map_function_arg(a::AbstractArray) = Arg() +map_function_arg(a) = a # Evaluate MapFunction -(f::MapFunction)(args...) = apply(f, args)[1] -function apply(f::MapFunction, args) - args, newargs = apply_tuple(f.args, args) - return f.f(args...), newargs +(f::MapFunction)(args...) = apply_arg(f, args)[1] +function apply_arg(f::MapFunction, args) + mapfunction_args, args′ = apply_args(f.args, args) + return f.f(mapfunction_args...), args′ end -apply(a::Arg, args::Tuple) = args[1], Base.tail(args) -apply(a, args) = a, args -apply_tuple(t::Tuple{}, args) = t, args -function apply_tuple(t::Tuple, args) - t1, newargs1 = apply(t[1], args) - ttail, newargs = apply_tuple(Base.tail(t), newargs1) - return (t1, ttail...), newargs +apply_arg(mapfunction_arg::Arg, args) = args[1], Base.tail(args) +apply_arg(mapfunction_arg, args) = mapfunction_arg, args +function apply_args(mapfunction_args::Tuple, args) + mapfunction_args1, args′ = apply_arg(mapfunction_args[1], args) + mapfunction_args_rest, args′′ = apply_args(Base.tail(mapfunction_args), args′) + return (mapfunction_args1, mapfunction_args_rest...), args′′ end +apply_args(mapfunction_args::Tuple{}, args) = mapfunction_args, args + +is_map_expr_or_arg(arg::AbstractArray) = true +is_map_expr_or_arg(arg::Any) = false +function is_map_expr_or_arg(bc::Broadcasted) + return all(is_map_expr_or_arg, bc.args) +end +function is_map_expr(bc::Broadcasted) + return is_map_expr_or_arg(bc) +end + +abstract type ExprStyle end +struct MapExpr <: ExprStyle end +struct NotMapExpr <: ExprStyle end + +ExprStyle(bc::Broadcasted) = is_map_expr(bc) ? MapExpr() : NotMapExpr() abstract type AbstractMapped <: Base.AbstractBroadcasted end +function check_shape(::Type{Bool}, args...) + return allequal(axes, args) +end +function check_shape(args...) + if !check_shape(Bool, args...) + throw(DimensionMismatch("Mismatched shapes $(axes.(args)).")) + end + return nothing +end + +# Promote the shape of the arguments to support broadcasting +# over dimensions by expanding singleton dimensions. +function promote_shape(ax, args::AbstractArray...) + if allequal((ax, axes.(args)...)) + return args + end + return promote_shape_tile(ax, args...) +end +function promote_shape_tile(common_axes, args::AbstractArray...) + return map(arg -> tile_to_shape(arg, common_axes), args) +end + +using BlockArrays: mortar +using FillArrays: Fill + +# Extend by repeating value up to length. +function extend(t::Tuple, value, length) + return ntuple(i -> get(t, i, value), length) +end + +# Handles logic of expanding singleton dimensions +# to match an array shape in broadcasting. +function tile_to_shape(a::AbstractArray, ax) + axes(a) == ax && return a + # Must be one-based for now. + @assert all(isone, first.(ax)) + @assert all(isone, first.(axes(a))) + ndim = length(ax) + size′ = extend(size(a), 1, ndim) + a′ = reshape(a, size′) + target_size = length.(ax) + fillsize = ntuple(ndim) do dim + size′[dim] == target_size[dim] && return 1 + isone(size′[dim]) && return target_size[dim] + return throw(DimensionMismatch("Dimensions $(axes(a)) and $ax don't match.")) + end + return mortar(Fill(a′, fillsize)) +end + struct Mapped{Style<:Union{Nothing,BroadcastStyle},Axes,F,Args<:Tuple} <: AbstractMapped style::Style f::F args::Args axes::Axes + function Mapped(style, f, args, axes) + check_shape(args...) + return new{typeof(style),typeof(axes),typeof(f),typeof(args)}(style, f, args, axes) + end end function Mapped(bc::Broadcasted) - return Mapped(bc.style, map_function(bc), map_args(bc), bc.axes) + return Mapped(ExprStyle(bc), bc) +end +function Mapped(::NotMapExpr, bc::Broadcasted) + f = map_function(bc) + ax = axes(bc) + args = promote_shape(ax, map_args(bc)...) + return Mapped(bc.style, f, args, ax) end +function Mapped(::MapExpr, bc::Broadcasted) + f = bc.f + ax = axes(bc) + args = promote_shape(ax, bc.args...) + return Mapped(bc.style, f, args, ax) +end + function Broadcast.Broadcasted(m::Mapped) return Broadcasted(m.style, m.f, m.args, m.axes) end -# Convert `Broadcasted` to `Mapped` when `Broadcasted` -# is known to already be a map expression. -function map_broadcast_to_mapped(bc::Broadcasted) - return Mapped(bc.style, bc.f, bc.args, bc.axes) +function mapped(f, args...) + check_shape(args...) + return Mapped(broadcasted(f, args...)) end -mapped(f, args...) = Mapped(broadcasted(f, args...)) - Base.similar(m::Mapped, elt::Type) = similar(Broadcasted(m), elt) Base.similar(m::Mapped, elt::Type, ax::Tuple) = similar(Broadcasted(m), elt, ax) Base.axes(m::Mapped) = axes(Broadcasted(m)) @@ -88,6 +177,6 @@ function Base.copy(m::Mapped) return copyto!(similar(m, elt), m) end Base.copyto!(dest::AbstractArray, m::Mapped) = map!(m.f, dest, m.args...) -Broadcast.instantiate(m::Mapped) = map_broadcast_to_mapped(instantiate(Broadcasted(m))) +Broadcast.instantiate(m::Mapped) = Mapped(instantiate(Broadcasted(m))) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 4ad29d0..8e8d174 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,6 @@ using Base.Broadcast: broadcasted -using BroadcastMapConversion: Mapped, mapped -using Test: @test, @testset +using BroadcastMapConversion: Mapped, is_map_expr, mapped +using Test: @inferred, @test, @test_throws, @testset @testset "BroadcastMapConversion (eltype=$elt)" for elt in ( Float32, Float64, Complex{Float32}, Complex{Float64} @@ -22,4 +22,26 @@ using Test: @test, @testset @test copyto!(similar(m, elt), m) ≈ ref @test copyto!(similar(m′, elt), m) ≈ ref end + + @test @inferred is_map_expr( + Broadcast.broadcasted(+, [2], Broadcast.broadcasted(sin, [2])) + ) + @test @inferred !is_map_expr(Broadcast.broadcasted(+, 2, Broadcast.broadcasted(sin, [2]))) + + # Logic handling singleton dimensions in broadcasting. + for (a, b) in ( + (randn(elt, 2, 2), randn(elt, 2)), + (randn(elt, 2, 2), randn(elt, 1, 2)), + (randn(elt, 2, 1), randn(elt, 1, 2)), + (randn(elt, 2, 2, 2), randn(elt, 2)), + (randn(elt, 2, 2, 2), randn(elt, 1, 2)), + (randn(elt, 2, 2, 2), randn(elt, 1, 1, 2)), + (randn(elt, 2, 2, 2), randn(elt, 2, 2)), + (randn(elt, 2, 2, 2), randn(elt, 1, 2, 2)), + ) + @test_throws DimensionMismatch mapped(+, a, b) + bc = broadcasted(+, a, b) + m = Mapped(bc) + @test copy(m) == copy(bc) + end end