From 9a157d682b1778a12358da57fed65e09587c9be9 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 17 Jan 2025 12:07:13 +0100 Subject: [PATCH] initial sketch of rand with types --- src/groups/special_euclidean_group.jl | 6 +++++ src/interface.jl | 35 ++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/groups/special_euclidean_group.jl b/src/groups/special_euclidean_group.jl index dc0265c9..4666d38d 100644 --- a/src/groups/special_euclidean_group.jl +++ b/src/groups/special_euclidean_group.jl @@ -102,6 +102,12 @@ struct SpecialEuclideanMatrixPoint{T} <: AbstractLieGroupPoint value::T end +function ManifoldsBase.allocate_on( + M::SpecialEuclideanGroup, ::Type{SpecialEuclideanMatrixPoint} +) + return SpecialEuclideanMatrixPoint(Matrix(undef, representation_size(M)...)) +end + """ SpecialEuclideanMatrixTVector <: AbstractLieGroupPoint diff --git a/src/interface.jl b/src/interface.jl index f415b0f7..8c3c6fde 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -904,8 +904,31 @@ Random.rand(::LieGroup; kwargs...) # New in LIeGroups – maybe move to ManifoldsBase at some point @doc "$(_doc_rand)" -function Random.rand(G::LieGroup, T::Type; vector_at=nothing, kwargs...) - # TODO conitnue here and allocate the right thing. +Random.rand(G::LieGroup, T::Type; vector_at=nothing, kwargs...) + +function Random.rand(M::AbstractManifold, T::Type, d::Integer; kwargs...) + return [rand(M, T; kwargs...) for _ in 1:d] +end +function Random.rand(rng::AbstractRNG, M::AbstractManifold, T::Type, d::Integer; kwargs...) + return [rand(rng, M, T; kwargs...) for _ in 1:d] +end +function Random.rand(M::AbstractManifold, T; vector_at=nothing, kwargs...) + if vector_at === nothing + pX = allocate_on(M, T) + else + pX = allocate_on(M, TangentSpaceType(), T) + end + rand!(M, pX; vector_at=vector_at, kwargs...) + return pX +end +function Random.rand(rng::AbstractRNG, M::AbstractManifold, T; vector_at=nothing, kwargs...) + if vector_at === nothing + pX = allocate_on(M, T) + else + pX = allocate_on(M, TangentSpaceType(), T) + end + rand!(rng, M, pX; vector_at=vector_at, kwargs...) + return pX end @doc "$(_doc_rand)" @@ -997,6 +1020,12 @@ end # # Allocation hints - mainly pass-through, especially for power manifolds + +ManifoldsBase.allocate_on(G::LieGroup, T::Type) = ManifoldsBase.allocate_on(G.manifold, T) +function ManifoldsBase.allocate_on(M::LieGroup, T::Type{<:AbstractArray}) + return ManifoldsBase.allocate_on(M.manifold, T) +end + function ManifoldsBase.allocate_result( G::LieGroup, f::Union{typeof(compose),typeof(inv),typeof(conjugate),typeof(exp)}, @@ -1032,7 +1061,7 @@ macro default_lie_group_fallbacks(TG, TP, TV, pfield::Symbol, vfield::Symbol) return LieGroups.adjoint(M, g.$pfield, X.$vfield) end - function LieGroups.adjoint!(M::$TM, Y::$TV, g::$TP, X::$TV) + function LieGroups.adjoint!(M::$TG, Y::$TV, g::$TP, X::$TV) LieGroups.adjoint!(M, Y.$vfield, g.$pfield, X.$vfield) return Y end