Skip to content

Commit

Permalink
Euler angles
Browse files Browse the repository at this point in the history
  • Loading branch information
Romain BRÉGIER committed Apr 19, 2024
1 parent c406cb5 commit 8c7eaa0
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 7 deletions.
7 changes: 6 additions & 1 deletion docsource/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ Mappings
.. automodule:: roma.mappings
:members:

.. automodule:: roma.euler
:members:

Utils
----------
.. automodule:: roma.utils
Expand Down Expand Up @@ -225,12 +228,14 @@ From source repository::

License
=======
*RoMa*, Copyright (c) 2021 NAVER Corp., is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license (see `license <https://github.com/naver/roma/blob/master/LICENSE>`_).
*RoMa*, Copyright (c) 2020 NAVER Corp., is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license (see `license <https://github.com/naver/roma/blob/master/LICENSE>`_).

Bits of code were adapted from SciPy. Documentation is generated, distributed and displayed with the support of Sphinx and other materials (see `notice <https://github.com/naver/roma/blob/master/NOTICE>`_).

Changelog
==========
Version 1.5.0:
- Added Euler angles mappings.
Version 1.4.5:
- 3-Clause BSD Licensing.
Version 1.4.4:
Expand Down
1 change: 1 addition & 0 deletions roma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .mappings import *
from .utils import *
from .transforms import *
from .euler import *
229 changes: 229 additions & 0 deletions roma/euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# RoMa
# Copyright (c) 2020 NAVER Corp.
# 3-Clause BSD License.
import torch
import roma
import numpy as np
import roma.internal

def _elementary_basis_index(axis):
"""
Return the index corresponding to a given axis label.
"""
if axis == 'x':
return 0
elif axis == 'y':
return 1
elif axis == 'z':
return 2
else:
raise ValueError("Invalid axis.")

def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=True, dtype=None, device=None):
"""
Convert Euler angles to unit quaternion representation.
Args:
convention (string): string defining a sequence of rotation axes ('XYZ' or 'xzx' for example).
The sequence of rotation is expressed either with respect to a global 'extrinsic' coordinate system (in which case axes are denoted in lowercase: 'x', 'y', or 'z'),
or with respect to an 'intrinsic' coordinates system attached to the object under rotation (in which case axes are denoted in uppercase: 'X', 'Y', 'Z').
Intrinsic and extrinsic conventions cannot be mixed.
angles (list of floats or list of tensors): a list of angles associated to each axis, expressed in radians by default.
degrees (bool): if True, input angles are assumed to be expressed in degrees.
Returns:
A batch of unit quaternions (...x4 tensor, XYZW convention).
"""
assert len(convention) == len(angles)

extrinsics = convention.islower()
if extrinsics:
# Cast from intrinsics to extrinsics convention
convention = convention.upper()[::-1]
angles = angles[::-1]

unitquats = []
for axis, angle in zip(convention, angles):
angle = torch.as_tensor(angle, device=device, dtype=dtype)
if degrees:
angle = torch.deg2rad(angle)
batch_shape = angle.shape
rotvec = torch.zeros(batch_shape + torch.Size((3,)), device=angle.device, dtype=angle.dtype)
if axis == 'X':
rotvec[...,0] = angle
elif axis == 'Y':
rotvec[...,1] = angle
elif axis == 'Z':
rotvec[...,2] = angle
else:
raise ValueError("Invalid convention (expected format: 'xyz', 'zxz', 'XYZ', etc.).")
q = roma.rotvec_to_unitquat(rotvec)
unitquats.append(q)
return roma.quat_composition(unitquats, normalize=normalize)

def euler_to_rotvec(convention: str, angles : list, degrees=False, dtype=None, device=None):
"""
Convert Euler angles to rotation vector representation.
Args:
convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`.
angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default.
degrees (bool): if True, input angles are assumed to be expressed in degrees.
Returns:
a batch of rotation vectors (...x3 tensor).
"""
return roma.unitquat_to_rotvec(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device))

def euler_to_rotmat(convention: str, angles : list, degrees=False, dtype=None, device=None):
"""
Convert Euler angles to rotation matrix representation.
Args:
convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`.
angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default.
degrees (bool): if True, input angles are assumed to be expressed in degrees.
Returns:
a batch of rotation matrices (...x3x3 tensor).
"""
return roma.unitquat_to_rotmat(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device))

def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7):
"""
Convert unit quaternion to Euler angles representation.
Args:
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations.
Consecutive axes should not be identical.
quat (...x4 tensor, XYZW convention): input batch of unit quaternion.
degrees (bool): if True, returned angles are expressed in degrees.
epsilon (float): a small value used to detect degenerate configurations.
Returns:
A list of 3 tensors corresponding to each Euler angle, expressed by default in radians.
In case of gimbal lock, the third angle is arbitrarily set to 0.
"""
# Code adapted from scipy.spatial.transform.Rotation.
# Reference: https://github.com/scipy/scipy/blob/ac6bcaf00411286271f7cc21e495192c73168ae4/scipy/spatial/transform/_rotation.pyx#L325C12-L325C15
assert len(convention) == 3

pi = np.pi
lamb = np.pi/2

extrinsic = convention.islower()
if not extrinsic:
convention = convention.lower()[::-1]

quat, batch_shape = roma.internal.flatten_batch_dims(quat, end_dim=-2)
N = quat.shape[0]

i = _elementary_basis_index(convention[0])
j = _elementary_basis_index(convention[1])
k = _elementary_basis_index(convention[2])
assert i != j and j != k, "Consecutive axes should not be identical."

symmetric = (i == k)

if symmetric:
# Get third axis
k = 3 - i - j

# Step 0
# Check if permutation is even (+1) or odd (-1)
sign = (i - j) * (j - k) * (k - i) // 2

# Step 1
# Permutate quaternion elements
if symmetric:
a = quat[:,3]
b = quat[:,i]
c = quat[:,j]
d = quat[:,k] * sign
else:
a = quat[:,3] - quat[:,j]
b = quat[:,i] + quat[:,k] * sign
c = quat[:,j] + quat[:,3]
d = quat[:,k] * sign - quat[:,i]


# intrinsic/extrinsic conversion helpers
if extrinsic:
angle_first = 0
angle_third = 2
else:
angle_first = 2
angle_third = 0

# Step 2
# Compute second angle...
angles = [torch.empty(N, device=quat.device, dtype=quat.dtype) for _ in range(3)]

angles[1] = 2 * torch.atan2(torch.hypot(c, d), torch.hypot(a, b))

# ... and check if equal to is 0 or pi, causing a singularity
case1 = torch.abs(angles[1]) <= epsilon
case2 = torch.abs(angles[1] - pi) <= epsilon
case1or2 = torch.logical_or(case1, case2)
# Step 3
# compute first and third angles, according to case
half_sum = torch.atan2(b, a)
half_diff = torch.atan2(d, c)

# no singularities
angles[angle_first] = half_sum - half_diff
angles[angle_third] = half_sum + half_diff

# any degenerate case
angles[2][case1or2] = 0
angles[0][case1] = 2 * half_sum[case1]
angles[0][case2] = 2 * (-1 if extrinsic else 1) * half_diff[case2]

# for Tait-Bryan/asymmetric sequences
if not symmetric:
angles[angle_third] *= sign
angles[1] -= lamb

for idx in range(3):
foo = angles[idx]
foo[foo < -pi] += 2 * pi
foo[foo > pi] -= 2 * pi
if degrees:
foo = torch.rad2deg(foo)
angles[idx] = roma.internal.unflatten_batch_dims(foo, batch_shape)

return angles

def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7):
"""
Convert rotation vector to Euler angles representation.
Args:
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations.
Consecutive axes should not be identical.
rotvec (...x3 tensor): input batch of rotation vectors.
degrees (bool): if True, returned angles are expressed in degrees.
epsilon (float): a small value used to detect degenerate configurations.
Returns:
A list of 3 tensors corresponding to each Euler angle, expressed by default in radians.
In case of gimbal lock, the third angle is arbitrarily set to 0.
"""
return unitquat_to_euler(convention, roma.rotvec_to_unitquat(rotvec), degrees=degrees, epsilon=epsilon)

def rotmat_to_euler(convention : str, rotmat, degrees=False, epsilon=1e-7):
"""
Convert rotation matrix to Euler angles representation.
Args:
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations.
Consecutive axes should not be identical.
rotmat (...x3x3 tensor): input batch of rotation matrices.
degrees (bool): if True, returned angles are expressed in degrees.
epsilon (float): a small value used to detect degenerate configurations.
Returns:
A list of 3 tensors corresponding to each Euler angle, expressed by default in radians.
In case of gimbal lock, the third angle is arbitrarily set to 0.
"""
return unitquat_to_euler(convention, roma.rotmat_to_unitquat(rotmat), degrees=degrees, epsilon=epsilon)
2 changes: 1 addition & 1 deletion roma/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,4 +450,4 @@ def quat_wxyz_to_xyzw(wxyz):
batch of quaternions (...x4 tensor, XYZW convention).
"""
assert wxyz.shape[-1] == 4
return torch.cat((wxyz[...,1:], wxyz[...,0,None]), dim=-1)
return torch.cat((wxyz[...,1:], wxyz[...,0,None]), dim=-1)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="roma",
version="1.4.5",
version="1.5.0",
author="Romain Brégier",
author_email="[email protected]",
description="A lightweight library to deal with 3D rotations in PyTorch.",
Expand Down
86 changes: 86 additions & 0 deletions test/test_euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# RoMa
# Copyright (c) 2020 NAVER Corp.
# 3-Clause BSD License.
import unittest
import torch
import roma
import numpy as np
from test.utils import is_close
import itertools

device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')

class TestEuler(unittest.TestCase):
def test_euler(self):
batch_shape = torch.Size((3,2))
x = torch.randn(batch_shape)
y = torch.randn(batch_shape)
q = roma.euler_to_unitquat('xy', (x, y))
self.assertTrue(q.shape == batch_shape + (4,))

def test_euler_unitquat_consistency(self):
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
for degrees in (True, False):
for batch_shape in [tuple(),
torch.Size((30,)),
torch.Size((50,60))]:
for intrinsics in (True, False):
for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]:
if intrinsics:
convention = convention.upper()
q = roma.random_unitquat(batch_shape, device=device)
angles = roma.unitquat_to_euler(convention, q, degrees=degrees)
self.assertTrue(len(angles) == 3)
self.assertTrue(all([angle.shape == batch_shape for angle in angles]))
if degrees:
self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles]))
else:
self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles]))
q1 = roma.euler_to_unitquat(convention, angles, degrees=degrees)
self.assertTrue(torch.all(roma.unitquat_geodesic_distance(q, q1) < 1e-6))

def test_euler_rotvec_consistency(self):
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float64
for degrees in (True, False):
for batch_shape in [tuple(),
torch.Size((30,)),
torch.Size((50,60))]:
for intrinsics in (True, False):
for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]:
if intrinsics:
convention = convention.upper()
q = roma.random_rotvec(batch_shape, device=device, dtype=dtype)
angles = roma.rotvec_to_euler(convention, q, degrees=degrees)
self.assertTrue(len(angles) == 3)
self.assertTrue(all([angle.shape == batch_shape for angle in angles]))
if degrees:
self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles]))
else:
self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles]))
q1 = roma.euler_to_rotvec(convention, angles, degrees=degrees)
self.assertTrue(torch.all(roma.rotvec_geodesic_distance(q, q1) < 1e-6))

def test_euler_rotmat_consistency(self):
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
for degrees in (True, False):
for batch_shape in [tuple(),
torch.Size((30,)),
torch.Size((50,60))]:
for intrinsics in (True, False):
for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]:
if intrinsics:
convention = convention.upper()
q = roma.random_rotmat(batch_shape, device=device)
angles = roma.rotmat_to_euler(convention, q, degrees=degrees)
self.assertTrue(len(angles) == 3)
self.assertTrue(all([angle.shape == batch_shape for angle in angles]))
if degrees:
self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles]))
else:
self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles]))
q1 = roma.euler_to_rotmat(convention, angles, degrees=degrees)
self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q, q1) < 1e-6))

if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions test/test_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
import numpy as np
from test.utils import is_close

def is_close(A, B, eps1 = 1e-5, eps2 = 1e-5):
return torch.norm(A - B) / (torch.norm(torch.abs(A) + torch.abs(B)) + eps1) < eps2

device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')

class TestMappings(unittest.TestCase):
Expand Down Expand Up @@ -191,6 +188,5 @@ def test_quat_conventions(self):
quat_xyzw_bis = roma.mappings.quat_wxyz_to_xyzw(quat_wxyz)
self.assertTrue(torch.all(quat_xyzw == quat_xyzw_bis))


if __name__ == "__main__":
unittest.main()

0 comments on commit 8c7eaa0

Please sign in to comment.