Skip to content

Commit

Permalink
improve getall/setall for staticarrays (#108)
Browse files Browse the repository at this point in the history
* getall() on static arrays returns tuples
* fix hetero tuples
* keep static arrays in setall()
* fix tests on 1.9-
* test different eltypes
  • Loading branch information
aplavin authored May 17, 2023
1 parent 29a4537 commit 79ee865
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ AxisKeys = "0.1,0.2" # extension tests only pass for AxisKeys 0.2, but 0.1 comp
Compat = "3.18, 4"
CompositionsBase = "0.1"
ConstructionBase = "1.5"
ConstructionBaseExtras = "0.1"
IntervalSets = "0.5,0.7" # extension tests only pass for IntervalSets 0.7, but 0.5 compat is required to run tests on Julia 1.3
InverseFunctions = "0.1.5"
MacroTools = "0.4.4, 0.5"
Expand All @@ -44,6 +45,7 @@ julia = "1.3"
[extras]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ConstructionBaseExtras = "914cd950-b775-4282-9f32-54fc4544c321"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
Expand All @@ -57,4 +59,5 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

[targets]
test = ["Test", "PerformanceTestTools", "QuickTypes", "InverseFunctions", "AxisKeys", "IntervalSets", "StaticArrays", "StructArrays", "BenchmarkTools", "InteractiveUtils", "StaticNumbers", "Unitful","Aqua"]
test = ["Test", "PerformanceTestTools", "QuickTypes", "InverseFunctions", "AxisKeys", "IntervalSets", "StaticArrays", "StructArrays", "BenchmarkTools", "InteractiveUtils", "StaticNumbers", "Unitful", "Aqua", "ConstructionBaseExtras"]

4 changes: 4 additions & 0 deletions ext/AccessorsStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ import Accessors: setindex, delete, insert

Accessors.set(obj::StaticArrays.SVector, ::Type{Tuple}, val::Tuple) = StaticArrays.SVector(val)

Accessors.getall(obj::StaticArrays.StaticArray, ::Elements) = Tuple(obj)
Accessors.setall(obj::StaticArrays.StaticArray, ::Elements, vs::AbstractArray) = constructorof(typeof(obj))(vs...) # just for disambiguation
Accessors.setall(obj::StaticArrays.StaticArray, ::Elements, vs) = constructorof(typeof(obj))(vs...)

end
2 changes: 1 addition & 1 deletion src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ _reduce_concat(xs::AbstractVector) = reduce(append!, xs; init=eltype(eltype(xs))
_reduce_concat(xs::Tuple{AbstractVector, Vararg{AbstractVector}}) = reduce(vcat, xs)
_reduce_concat(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)

_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
_staticlength(::NTuple{N, Any}) where {N} = Val(N)
_staticlength(x::AbstractVector) = length(x)

getall_lengths(obj, optics::Tuple{Any}) = _staticlength(getall(obj, only(optics)))
Expand Down
31 changes: 24 additions & 7 deletions test/test_getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ using Accessors
using StaticNumbers
using StaticArrays

@static if VERSION < v"1.9-"
# for StaticArrays constructorof: ConstructionBase itself only supports it through 1.9 extension
using ConstructionBaseExtras
end

if VERSION >= v"1.6" # for ComposedFunction
@testset "getall" begin
Expand Down Expand Up @@ -37,17 +41,16 @@ if VERSION >= v"1.6" # for ComposedFunction
@test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only)

# trickier types for Elements():
obj = (a=("ab", "c"), b=([1 2; 3 4],), c=(SVector(1), SVector(2, 3)))
obj = (a=("ab", "c"), b=([1 2; 3 4],), c=(SVector(1.), SVector(2, 3)))
@test ['b', 'c', 'd'] == @inferred getall(obj, @optic _.a |> Elements() |> Elements() |> _ + 1)
@test [2, 4, 3, 5] == @inferred getall(obj, @optic _.b |> Elements() |> Elements() |> _ + 1)
@test SVector(1, 2, 3) === @inferred getall(obj, @optic _.c |> Elements() |> Elements())
@test [2, 3, 4] == @inferred getall(obj, @optic _.c |> Elements() |> Elements() |> _ + 1)
@test_broken SVector(2, 3, 4) === getall(obj, @optic _.c |> Elements() |> Elements() |> _ + 1)
@test (1., 2, 3) === @inferred getall(obj, @optic _.c |> Elements() |> Elements())
@test (2., 3, 4) === @inferred getall(obj, @optic _.c |> Elements() |> Elements() |> _ + 1)

# composition order should not matter:
@test [2, 3, 4] == @inferred getall(obj, (@optic(_ + 1) Elements() Elements()) @optic(_.c))
@test [2, 3, 4] == @inferred getall(obj, (@optic(_ + 1) Elements()) (Elements() @optic(_.c)))
@test [2, 3, 4] == @inferred getall(obj, @optic(_ + 1) (Elements() Elements() @optic(_.c)))
@test (2., 3, 4) === @inferred getall(obj, (@optic(_ + 1) Elements() Elements()) @optic(_.c))
@test (2., 3, 4) === @inferred getall(obj, (@optic(_ + 1) Elements()) (Elements() @optic(_.c)))
@test (2., 3, 4) === @inferred getall(obj, @optic(_ + 1) (Elements() Elements() @optic(_.c)))

obj = ()
@test () === @inferred getall(obj, @optic _ |> Elements() |> _ + 1)
Expand Down Expand Up @@ -114,6 +117,20 @@ end
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, (@optic(_ * 3) Properties()) (Elements() @optic(_.b)), [-9, -12, -15, -18])
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_ * 3) (Properties() Elements() @optic(_.b)), [-9, -12, -15, -18])

# SVectors and nested Elements:
obj = (c=(SVector(1.), SVector(2, 3)),)
@test setall(obj.c[1], Elements(), (5, 6)) === SVector(5, 6)
@test setall(obj.c[1], Elements(), (5,)) === SVector(5)
@test setall(obj.c[1], Elements(), [5, 6]) === SVector(5, 6)
@test setall(obj.c[1], Elements(), [5]) === SVector(5)
@testset for o in (
(@optic _.c |> Elements() |> Elements()),
(@optic _.c |> Elements() |> Elements() |> _ + 1),
)
@test setall(obj, o, getall(obj, o)) === obj
@test setall(obj, o, collect(getall(obj, o))) === obj
end

obj = ([1, 2], 3:5, (6,))
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
Expand Down

0 comments on commit 79ee865

Please sign in to comment.