diff --git a/test/test_euler.py b/test/test_euler.py index a99c8eb..d29214e 100644 --- a/test/test_euler.py +++ b/test/test_euler.py @@ -80,7 +80,16 @@ def test_euler_rotmat_consistency(self): else: self.assertTrue(all([torch.all(angle > -np.pi) and torch.all(angle <= np.pi) for angle in angles])) q1 = roma.euler_to_rotmat(convention, angles, degrees=degrees) - self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q, q1) < 1e-6)) + self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q, q1) < 1e-6)) + + def test_euler_backward(self): + for intrinsics in (True, False): + for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: + if intrinsics: + convention = convention.upper() + rotvec = torch.randn(3, requires_grad=True) + angles = roma.rotvec_to_euler('xyz', rotvec) + sum(angles).backward() if __name__ == "__main__": unittest.main() \ No newline at end of file