From 8c7eaa0329c92bcd0a2d4fafd3882ee2ce2e398e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20BR=C3=89GIER?= Date: Fri, 19 Apr 2024 17:42:31 +0200 Subject: [PATCH] Euler angles --- docsource/source/index.rst | 7 +- roma/__init__.py | 1 + roma/euler.py | 229 +++++++++++++++++++++++++++++++++++++ roma/mappings.py | 2 +- setup.py | 2 +- test/test_euler.py | 86 ++++++++++++++ test/test_mappings.py | 4 - 7 files changed, 324 insertions(+), 7 deletions(-) create mode 100644 roma/euler.py create mode 100644 test/test_euler.py diff --git a/docsource/source/index.rst b/docsource/source/index.rst index 57486b7..b21051d 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -195,6 +195,9 @@ Mappings .. automodule:: roma.mappings :members: +.. automodule:: roma.euler + :members: + Utils ---------- .. automodule:: roma.utils @@ -225,12 +228,14 @@ From source repository:: License ======= -*RoMa*, Copyright (c) 2021 NAVER Corp., is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license (see `license `_). +*RoMa*, Copyright (c) 2020 NAVER Corp., is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license (see `license `_). Bits of code were adapted from SciPy. Documentation is generated, distributed and displayed with the support of Sphinx and other materials (see `notice `_). Changelog ========== +Version 1.5.0: + - Added Euler angles mappings. Version 1.4.5: - 3-Clause BSD Licensing. Version 1.4.4: diff --git a/roma/__init__.py b/roma/__init__.py index eebbdce..ee21d16 100644 --- a/roma/__init__.py +++ b/roma/__init__.py @@ -4,3 +4,4 @@ from .mappings import * from .utils import * from .transforms import * +from .euler import * diff --git a/roma/euler.py b/roma/euler.py new file mode 100644 index 0000000..b1b2005 --- /dev/null +++ b/roma/euler.py @@ -0,0 +1,229 @@ +# RoMa +# Copyright (c) 2020 NAVER Corp. +# 3-Clause BSD License. +import torch +import roma +import numpy as np +import roma.internal + +def _elementary_basis_index(axis): + """ + Return the index corresponding to a given axis label. + """ + if axis == 'x': + return 0 + elif axis == 'y': + return 1 + elif axis == 'z': + return 2 + else: + raise ValueError("Invalid axis.") + +def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=True, dtype=None, device=None): + """ + Convert Euler angles to unit quaternion representation. + + Args: + convention (string): string defining a sequence of rotation axes ('XYZ' or 'xzx' for example). + The sequence of rotation is expressed either with respect to a global 'extrinsic' coordinate system (in which case axes are denoted in lowercase: 'x', 'y', or 'z'), + or with respect to an 'intrinsic' coordinates system attached to the object under rotation (in which case axes are denoted in uppercase: 'X', 'Y', 'Z'). + Intrinsic and extrinsic conventions cannot be mixed. + angles (list of floats or list of tensors): a list of angles associated to each axis, expressed in radians by default. + degrees (bool): if True, input angles are assumed to be expressed in degrees. + + Returns: + A batch of unit quaternions (...x4 tensor, XYZW convention). + """ + assert len(convention) == len(angles) + + extrinsics = convention.islower() + if extrinsics: + # Cast from intrinsics to extrinsics convention + convention = convention.upper()[::-1] + angles = angles[::-1] + + unitquats = [] + for axis, angle in zip(convention, angles): + angle = torch.as_tensor(angle, device=device, dtype=dtype) + if degrees: + angle = torch.deg2rad(angle) + batch_shape = angle.shape + rotvec = torch.zeros(batch_shape + torch.Size((3,)), device=angle.device, dtype=angle.dtype) + if axis == 'X': + rotvec[...,0] = angle + elif axis == 'Y': + rotvec[...,1] = angle + elif axis == 'Z': + rotvec[...,2] = angle + else: + raise ValueError("Invalid convention (expected format: 'xyz', 'zxz', 'XYZ', etc.).") + q = roma.rotvec_to_unitquat(rotvec) + unitquats.append(q) + return roma.quat_composition(unitquats, normalize=normalize) + +def euler_to_rotvec(convention: str, angles : list, degrees=False, dtype=None, device=None): + """ + Convert Euler angles to rotation vector representation. + + Args: + convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`. + angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default. + degrees (bool): if True, input angles are assumed to be expressed in degrees. + + Returns: + a batch of rotation vectors (...x3 tensor). + """ + return roma.unitquat_to_rotvec(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device)) + +def euler_to_rotmat(convention: str, angles : list, degrees=False, dtype=None, device=None): + """ + Convert Euler angles to rotation matrix representation. + + Args: + convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`. + angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default. + degrees (bool): if True, input angles are assumed to be expressed in degrees. + + Returns: + a batch of rotation matrices (...x3x3 tensor). + """ + return roma.unitquat_to_rotmat(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device)) + +def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7): + """ + Convert unit quaternion to Euler angles representation. + + Args: + convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. + Consecutive axes should not be identical. + quat (...x4 tensor, XYZW convention): input batch of unit quaternion. + degrees (bool): if True, returned angles are expressed in degrees. + epsilon (float): a small value used to detect degenerate configurations. + + Returns: + A list of 3 tensors corresponding to each Euler angle, expressed by default in radians. + In case of gimbal lock, the third angle is arbitrarily set to 0. + """ + # Code adapted from scipy.spatial.transform.Rotation. + # Reference: https://github.com/scipy/scipy/blob/ac6bcaf00411286271f7cc21e495192c73168ae4/scipy/spatial/transform/_rotation.pyx#L325C12-L325C15 + assert len(convention) == 3 + + pi = np.pi + lamb = np.pi/2 + + extrinsic = convention.islower() + if not extrinsic: + convention = convention.lower()[::-1] + + quat, batch_shape = roma.internal.flatten_batch_dims(quat, end_dim=-2) + N = quat.shape[0] + + i = _elementary_basis_index(convention[0]) + j = _elementary_basis_index(convention[1]) + k = _elementary_basis_index(convention[2]) + assert i != j and j != k, "Consecutive axes should not be identical." + + symmetric = (i == k) + + if symmetric: + # Get third axis + k = 3 - i - j + + # Step 0 + # Check if permutation is even (+1) or odd (-1) + sign = (i - j) * (j - k) * (k - i) // 2 + + # Step 1 + # Permutate quaternion elements + if symmetric: + a = quat[:,3] + b = quat[:,i] + c = quat[:,j] + d = quat[:,k] * sign + else: + a = quat[:,3] - quat[:,j] + b = quat[:,i] + quat[:,k] * sign + c = quat[:,j] + quat[:,3] + d = quat[:,k] * sign - quat[:,i] + + + # intrinsic/extrinsic conversion helpers + if extrinsic: + angle_first = 0 + angle_third = 2 + else: + angle_first = 2 + angle_third = 0 + + # Step 2 + # Compute second angle... + angles = [torch.empty(N, device=quat.device, dtype=quat.dtype) for _ in range(3)] + + angles[1] = 2 * torch.atan2(torch.hypot(c, d), torch.hypot(a, b)) + + # ... and check if equal to is 0 or pi, causing a singularity + case1 = torch.abs(angles[1]) <= epsilon + case2 = torch.abs(angles[1] - pi) <= epsilon + case1or2 = torch.logical_or(case1, case2) + # Step 3 + # compute first and third angles, according to case + half_sum = torch.atan2(b, a) + half_diff = torch.atan2(d, c) + + # no singularities + angles[angle_first] = half_sum - half_diff + angles[angle_third] = half_sum + half_diff + + # any degenerate case + angles[2][case1or2] = 0 + angles[0][case1] = 2 * half_sum[case1] + angles[0][case2] = 2 * (-1 if extrinsic else 1) * half_diff[case2] + + # for Tait-Bryan/asymmetric sequences + if not symmetric: + angles[angle_third] *= sign + angles[1] -= lamb + + for idx in range(3): + foo = angles[idx] + foo[foo < -pi] += 2 * pi + foo[foo > pi] -= 2 * pi + if degrees: + foo = torch.rad2deg(foo) + angles[idx] = roma.internal.unflatten_batch_dims(foo, batch_shape) + + return angles + +def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7): + """ + Convert rotation vector to Euler angles representation. + + Args: + convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. + Consecutive axes should not be identical. + rotvec (...x3 tensor): input batch of rotation vectors. + degrees (bool): if True, returned angles are expressed in degrees. + epsilon (float): a small value used to detect degenerate configurations. + + Returns: + A list of 3 tensors corresponding to each Euler angle, expressed by default in radians. + In case of gimbal lock, the third angle is arbitrarily set to 0. + """ + return unitquat_to_euler(convention, roma.rotvec_to_unitquat(rotvec), degrees=degrees, epsilon=epsilon) + +def rotmat_to_euler(convention : str, rotmat, degrees=False, epsilon=1e-7): + """ + Convert rotation matrix to Euler angles representation. + + Args: + convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. + Consecutive axes should not be identical. + rotmat (...x3x3 tensor): input batch of rotation matrices. + degrees (bool): if True, returned angles are expressed in degrees. + epsilon (float): a small value used to detect degenerate configurations. + + Returns: + A list of 3 tensors corresponding to each Euler angle, expressed by default in radians. + In case of gimbal lock, the third angle is arbitrarily set to 0. + """ + return unitquat_to_euler(convention, roma.rotmat_to_unitquat(rotmat), degrees=degrees, epsilon=epsilon) \ No newline at end of file diff --git a/roma/mappings.py b/roma/mappings.py index 1b63321..c9fbeeb 100644 --- a/roma/mappings.py +++ b/roma/mappings.py @@ -450,4 +450,4 @@ def quat_wxyz_to_xyzw(wxyz): batch of quaternions (...x4 tensor, XYZW convention). """ assert wxyz.shape[-1] == 4 - return torch.cat((wxyz[...,1:], wxyz[...,0,None]), dim=-1) + return torch.cat((wxyz[...,1:], wxyz[...,0,None]), dim=-1) \ No newline at end of file diff --git a/setup.py b/setup.py index 3301e4c..4ce2312 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="roma", - version="1.4.5", + version="1.5.0", author="Romain Brégier", author_email="romain.bregier@naverlabs.com", description="A lightweight library to deal with 3D rotations in PyTorch.", diff --git a/test/test_euler.py b/test/test_euler.py new file mode 100644 index 0000000..a99c8eb --- /dev/null +++ b/test/test_euler.py @@ -0,0 +1,86 @@ +# RoMa +# Copyright (c) 2020 NAVER Corp. +# 3-Clause BSD License. +import unittest +import torch +import roma +import numpy as np +from test.utils import is_close +import itertools + +device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') + +class TestEuler(unittest.TestCase): + def test_euler(self): + batch_shape = torch.Size((3,2)) + x = torch.randn(batch_shape) + y = torch.randn(batch_shape) + q = roma.euler_to_unitquat('xy', (x, y)) + self.assertTrue(q.shape == batch_shape + (4,)) + + def test_euler_unitquat_consistency(self): + device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') + for degrees in (True, False): + for batch_shape in [tuple(), + torch.Size((30,)), + torch.Size((50,60))]: + for intrinsics in (True, False): + for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: + if intrinsics: + convention = convention.upper() + q = roma.random_unitquat(batch_shape, device=device) + angles = roma.unitquat_to_euler(convention, q, degrees=degrees) + self.assertTrue(len(angles) == 3) + self.assertTrue(all([angle.shape == batch_shape for angle in angles])) + if degrees: + self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles])) + else: + self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) + q1 = roma.euler_to_unitquat(convention, angles, degrees=degrees) + self.assertTrue(torch.all(roma.unitquat_geodesic_distance(q, q1) < 1e-6)) + + def test_euler_rotvec_consistency(self): + device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') + dtype = torch.float64 + for degrees in (True, False): + for batch_shape in [tuple(), + torch.Size((30,)), + torch.Size((50,60))]: + for intrinsics in (True, False): + for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: + if intrinsics: + convention = convention.upper() + q = roma.random_rotvec(batch_shape, device=device, dtype=dtype) + angles = roma.rotvec_to_euler(convention, q, degrees=degrees) + self.assertTrue(len(angles) == 3) + self.assertTrue(all([angle.shape == batch_shape for angle in angles])) + if degrees: + self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles])) + else: + self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) + q1 = roma.euler_to_rotvec(convention, angles, degrees=degrees) + self.assertTrue(torch.all(roma.rotvec_geodesic_distance(q, q1) < 1e-6)) + + def test_euler_rotmat_consistency(self): + device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') + for degrees in (True, False): + for batch_shape in [tuple(), + torch.Size((30,)), + torch.Size((50,60))]: + for intrinsics in (True, False): + for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: + if intrinsics: + convention = convention.upper() + q = roma.random_rotmat(batch_shape, device=device) + angles = roma.rotmat_to_euler(convention, q, degrees=degrees) + self.assertTrue(len(angles) == 3) + self.assertTrue(all([angle.shape == batch_shape for angle in angles])) + if degrees: + self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles])) + else: + self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) + q1 = roma.euler_to_rotmat(convention, angles, degrees=degrees) + self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q, q1) < 1e-6)) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/test_mappings.py b/test/test_mappings.py index 3df6536..368efb0 100644 --- a/test/test_mappings.py +++ b/test/test_mappings.py @@ -7,9 +7,6 @@ import numpy as np from test.utils import is_close -def is_close(A, B, eps1 = 1e-5, eps2 = 1e-5): - return torch.norm(A - B) / (torch.norm(torch.abs(A) + torch.abs(B)) + eps1) < eps2 - device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') class TestMappings(unittest.TestCase): @@ -191,6 +188,5 @@ def test_quat_conventions(self): quat_xyzw_bis = roma.mappings.quat_wxyz_to_xyzw(quat_wxyz) self.assertTrue(torch.all(quat_xyzw == quat_xyzw_bis)) - if __name__ == "__main__": unittest.main() \ No newline at end of file