-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Romain BRÉGIER
committed
Apr 19, 2024
1 parent
c406cb5
commit 8c7eaa0
Showing
7 changed files
with
324 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
from .mappings import * | ||
from .utils import * | ||
from .transforms import * | ||
from .euler import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
# RoMa | ||
# Copyright (c) 2020 NAVER Corp. | ||
# 3-Clause BSD License. | ||
import torch | ||
import roma | ||
import numpy as np | ||
import roma.internal | ||
|
||
def _elementary_basis_index(axis): | ||
""" | ||
Return the index corresponding to a given axis label. | ||
""" | ||
if axis == 'x': | ||
return 0 | ||
elif axis == 'y': | ||
return 1 | ||
elif axis == 'z': | ||
return 2 | ||
else: | ||
raise ValueError("Invalid axis.") | ||
|
||
def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=True, dtype=None, device=None): | ||
""" | ||
Convert Euler angles to unit quaternion representation. | ||
Args: | ||
convention (string): string defining a sequence of rotation axes ('XYZ' or 'xzx' for example). | ||
The sequence of rotation is expressed either with respect to a global 'extrinsic' coordinate system (in which case axes are denoted in lowercase: 'x', 'y', or 'z'), | ||
or with respect to an 'intrinsic' coordinates system attached to the object under rotation (in which case axes are denoted in uppercase: 'X', 'Y', 'Z'). | ||
Intrinsic and extrinsic conventions cannot be mixed. | ||
angles (list of floats or list of tensors): a list of angles associated to each axis, expressed in radians by default. | ||
degrees (bool): if True, input angles are assumed to be expressed in degrees. | ||
Returns: | ||
A batch of unit quaternions (...x4 tensor, XYZW convention). | ||
""" | ||
assert len(convention) == len(angles) | ||
|
||
extrinsics = convention.islower() | ||
if extrinsics: | ||
# Cast from intrinsics to extrinsics convention | ||
convention = convention.upper()[::-1] | ||
angles = angles[::-1] | ||
|
||
unitquats = [] | ||
for axis, angle in zip(convention, angles): | ||
angle = torch.as_tensor(angle, device=device, dtype=dtype) | ||
if degrees: | ||
angle = torch.deg2rad(angle) | ||
batch_shape = angle.shape | ||
rotvec = torch.zeros(batch_shape + torch.Size((3,)), device=angle.device, dtype=angle.dtype) | ||
if axis == 'X': | ||
rotvec[...,0] = angle | ||
elif axis == 'Y': | ||
rotvec[...,1] = angle | ||
elif axis == 'Z': | ||
rotvec[...,2] = angle | ||
else: | ||
raise ValueError("Invalid convention (expected format: 'xyz', 'zxz', 'XYZ', etc.).") | ||
q = roma.rotvec_to_unitquat(rotvec) | ||
unitquats.append(q) | ||
return roma.quat_composition(unitquats, normalize=normalize) | ||
|
||
def euler_to_rotvec(convention: str, angles : list, degrees=False, dtype=None, device=None): | ||
""" | ||
Convert Euler angles to rotation vector representation. | ||
Args: | ||
convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`. | ||
angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default. | ||
degrees (bool): if True, input angles are assumed to be expressed in degrees. | ||
Returns: | ||
a batch of rotation vectors (...x3 tensor). | ||
""" | ||
return roma.unitquat_to_rotvec(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device)) | ||
|
||
def euler_to_rotmat(convention: str, angles : list, degrees=False, dtype=None, device=None): | ||
""" | ||
Convert Euler angles to rotation matrix representation. | ||
Args: | ||
convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`. | ||
angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default. | ||
degrees (bool): if True, input angles are assumed to be expressed in degrees. | ||
Returns: | ||
a batch of rotation matrices (...x3x3 tensor). | ||
""" | ||
return roma.unitquat_to_rotmat(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device)) | ||
|
||
def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7): | ||
""" | ||
Convert unit quaternion to Euler angles representation. | ||
Args: | ||
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. | ||
Consecutive axes should not be identical. | ||
quat (...x4 tensor, XYZW convention): input batch of unit quaternion. | ||
degrees (bool): if True, returned angles are expressed in degrees. | ||
epsilon (float): a small value used to detect degenerate configurations. | ||
Returns: | ||
A list of 3 tensors corresponding to each Euler angle, expressed by default in radians. | ||
In case of gimbal lock, the third angle is arbitrarily set to 0. | ||
""" | ||
# Code adapted from scipy.spatial.transform.Rotation. | ||
# Reference: https://github.com/scipy/scipy/blob/ac6bcaf00411286271f7cc21e495192c73168ae4/scipy/spatial/transform/_rotation.pyx#L325C12-L325C15 | ||
assert len(convention) == 3 | ||
|
||
pi = np.pi | ||
lamb = np.pi/2 | ||
|
||
extrinsic = convention.islower() | ||
if not extrinsic: | ||
convention = convention.lower()[::-1] | ||
|
||
quat, batch_shape = roma.internal.flatten_batch_dims(quat, end_dim=-2) | ||
N = quat.shape[0] | ||
|
||
i = _elementary_basis_index(convention[0]) | ||
j = _elementary_basis_index(convention[1]) | ||
k = _elementary_basis_index(convention[2]) | ||
assert i != j and j != k, "Consecutive axes should not be identical." | ||
|
||
symmetric = (i == k) | ||
|
||
if symmetric: | ||
# Get third axis | ||
k = 3 - i - j | ||
|
||
# Step 0 | ||
# Check if permutation is even (+1) or odd (-1) | ||
sign = (i - j) * (j - k) * (k - i) // 2 | ||
|
||
# Step 1 | ||
# Permutate quaternion elements | ||
if symmetric: | ||
a = quat[:,3] | ||
b = quat[:,i] | ||
c = quat[:,j] | ||
d = quat[:,k] * sign | ||
else: | ||
a = quat[:,3] - quat[:,j] | ||
b = quat[:,i] + quat[:,k] * sign | ||
c = quat[:,j] + quat[:,3] | ||
d = quat[:,k] * sign - quat[:,i] | ||
|
||
|
||
# intrinsic/extrinsic conversion helpers | ||
if extrinsic: | ||
angle_first = 0 | ||
angle_third = 2 | ||
else: | ||
angle_first = 2 | ||
angle_third = 0 | ||
|
||
# Step 2 | ||
# Compute second angle... | ||
angles = [torch.empty(N, device=quat.device, dtype=quat.dtype) for _ in range(3)] | ||
|
||
angles[1] = 2 * torch.atan2(torch.hypot(c, d), torch.hypot(a, b)) | ||
|
||
# ... and check if equal to is 0 or pi, causing a singularity | ||
case1 = torch.abs(angles[1]) <= epsilon | ||
case2 = torch.abs(angles[1] - pi) <= epsilon | ||
case1or2 = torch.logical_or(case1, case2) | ||
# Step 3 | ||
# compute first and third angles, according to case | ||
half_sum = torch.atan2(b, a) | ||
half_diff = torch.atan2(d, c) | ||
|
||
# no singularities | ||
angles[angle_first] = half_sum - half_diff | ||
angles[angle_third] = half_sum + half_diff | ||
|
||
# any degenerate case | ||
angles[2][case1or2] = 0 | ||
angles[0][case1] = 2 * half_sum[case1] | ||
angles[0][case2] = 2 * (-1 if extrinsic else 1) * half_diff[case2] | ||
|
||
# for Tait-Bryan/asymmetric sequences | ||
if not symmetric: | ||
angles[angle_third] *= sign | ||
angles[1] -= lamb | ||
|
||
for idx in range(3): | ||
foo = angles[idx] | ||
foo[foo < -pi] += 2 * pi | ||
foo[foo > pi] -= 2 * pi | ||
if degrees: | ||
foo = torch.rad2deg(foo) | ||
angles[idx] = roma.internal.unflatten_batch_dims(foo, batch_shape) | ||
|
||
return angles | ||
|
||
def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7): | ||
""" | ||
Convert rotation vector to Euler angles representation. | ||
Args: | ||
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. | ||
Consecutive axes should not be identical. | ||
rotvec (...x3 tensor): input batch of rotation vectors. | ||
degrees (bool): if True, returned angles are expressed in degrees. | ||
epsilon (float): a small value used to detect degenerate configurations. | ||
Returns: | ||
A list of 3 tensors corresponding to each Euler angle, expressed by default in radians. | ||
In case of gimbal lock, the third angle is arbitrarily set to 0. | ||
""" | ||
return unitquat_to_euler(convention, roma.rotvec_to_unitquat(rotvec), degrees=degrees, epsilon=epsilon) | ||
|
||
def rotmat_to_euler(convention : str, rotmat, degrees=False, epsilon=1e-7): | ||
""" | ||
Convert rotation matrix to Euler angles representation. | ||
Args: | ||
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. | ||
Consecutive axes should not be identical. | ||
rotmat (...x3x3 tensor): input batch of rotation matrices. | ||
degrees (bool): if True, returned angles are expressed in degrees. | ||
epsilon (float): a small value used to detect degenerate configurations. | ||
Returns: | ||
A list of 3 tensors corresponding to each Euler angle, expressed by default in radians. | ||
In case of gimbal lock, the third angle is arbitrarily set to 0. | ||
""" | ||
return unitquat_to_euler(convention, roma.rotmat_to_unitquat(rotmat), degrees=degrees, epsilon=epsilon) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
|
||
setuptools.setup( | ||
name="roma", | ||
version="1.4.5", | ||
version="1.5.0", | ||
author="Romain Brégier", | ||
author_email="[email protected]", | ||
description="A lightweight library to deal with 3D rotations in PyTorch.", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# RoMa | ||
# Copyright (c) 2020 NAVER Corp. | ||
# 3-Clause BSD License. | ||
import unittest | ||
import torch | ||
import roma | ||
import numpy as np | ||
from test.utils import is_close | ||
import itertools | ||
|
||
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') | ||
|
||
class TestEuler(unittest.TestCase): | ||
def test_euler(self): | ||
batch_shape = torch.Size((3,2)) | ||
x = torch.randn(batch_shape) | ||
y = torch.randn(batch_shape) | ||
q = roma.euler_to_unitquat('xy', (x, y)) | ||
self.assertTrue(q.shape == batch_shape + (4,)) | ||
|
||
def test_euler_unitquat_consistency(self): | ||
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') | ||
for degrees in (True, False): | ||
for batch_shape in [tuple(), | ||
torch.Size((30,)), | ||
torch.Size((50,60))]: | ||
for intrinsics in (True, False): | ||
for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: | ||
if intrinsics: | ||
convention = convention.upper() | ||
q = roma.random_unitquat(batch_shape, device=device) | ||
angles = roma.unitquat_to_euler(convention, q, degrees=degrees) | ||
self.assertTrue(len(angles) == 3) | ||
self.assertTrue(all([angle.shape == batch_shape for angle in angles])) | ||
if degrees: | ||
self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles])) | ||
else: | ||
self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) | ||
q1 = roma.euler_to_unitquat(convention, angles, degrees=degrees) | ||
self.assertTrue(torch.all(roma.unitquat_geodesic_distance(q, q1) < 1e-6)) | ||
|
||
def test_euler_rotvec_consistency(self): | ||
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') | ||
dtype = torch.float64 | ||
for degrees in (True, False): | ||
for batch_shape in [tuple(), | ||
torch.Size((30,)), | ||
torch.Size((50,60))]: | ||
for intrinsics in (True, False): | ||
for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: | ||
if intrinsics: | ||
convention = convention.upper() | ||
q = roma.random_rotvec(batch_shape, device=device, dtype=dtype) | ||
angles = roma.rotvec_to_euler(convention, q, degrees=degrees) | ||
self.assertTrue(len(angles) == 3) | ||
self.assertTrue(all([angle.shape == batch_shape for angle in angles])) | ||
if degrees: | ||
self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles])) | ||
else: | ||
self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) | ||
q1 = roma.euler_to_rotvec(convention, angles, degrees=degrees) | ||
self.assertTrue(torch.all(roma.rotvec_geodesic_distance(q, q1) < 1e-6)) | ||
|
||
def test_euler_rotmat_consistency(self): | ||
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') | ||
for degrees in (True, False): | ||
for batch_shape in [tuple(), | ||
torch.Size((30,)), | ||
torch.Size((50,60))]: | ||
for intrinsics in (True, False): | ||
for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: | ||
if intrinsics: | ||
convention = convention.upper() | ||
q = roma.random_rotmat(batch_shape, device=device) | ||
angles = roma.rotmat_to_euler(convention, q, degrees=degrees) | ||
self.assertTrue(len(angles) == 3) | ||
self.assertTrue(all([angle.shape == batch_shape for angle in angles])) | ||
if degrees: | ||
self.assertTrue(all([torch.all(angle > -180.) and torch.all(angle <= 180) for angle in angles])) | ||
else: | ||
self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) | ||
q1 = roma.euler_to_rotmat(convention, angles, degrees=degrees) | ||
self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q, q1) < 1e-6)) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters