Skip to content

Commit

Permalink
Merge pull request #721 from ANTsX/fix_slicing
Browse files Browse the repository at this point in the history
BUG: Fix slicing. More consistent with numpy
  • Loading branch information
cookpa authored Oct 22, 2024
2 parents 88437eb + b5380b4 commit 6074d85
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 49 deletions.
30 changes: 22 additions & 8 deletions ants/core/ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,23 +507,37 @@ def __getitem__(self, idx):
raise ValueError('images do not occupy same physical space')
return self.numpy().__getitem__(idx.numpy().astype('bool'))

ndim = len(idx)
# convert idx to tuple if it is not, eg im[10] or im[10:20]
if not isinstance(idx, tuple):
idx = (idx,)

ndim = len(self.shape)

if len(idx) > ndim:
raise ValueError('Too many indices for image')
if len(idx) < ndim:
# If not all dimensions are indexed, assume the rest are full slices
# eg im[10] -> im[10, :, :]
idx = idx + (slice(None),) * (ndim - len(idx))

sizes = list(self.shape)
starts = [0] * ndim

stops = list(self.shape)
for i in range(ndim):
ti = idx[i]
if isinstance(ti, slice):
if ti.start:
starts[i] = ti.start
if ti.stop:
sizes[i] = ti.stop - starts[i]
else:
sizes[i] = self.shape[i] - starts[i]
if ti.stop < 0:
stops[i] = self.shape[i] + ti.stop
else:
stops[i] = ti.stop

sizes[i] = stops[i] - starts[i]

if ti.stop and ti.start:
if ti.stop < ti.start:
raise Exception('Reverse indexing is not supported.')
if stops[i] < starts[i]:
raise ValueError('Reverse indexing is not supported.')

elif isinstance(ti, int):
starts[i] = ti
Expand Down
117 changes: 76 additions & 41 deletions tests/test_core_ants_image_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@


class TestClass_AntsImageIndexing(unittest.TestCase):

def setUp(self):
pass
def tearDown(self):
pass

def test_pixeltype_2d(self):
img = ants.image_read(ants.get_data('r16'))
for ptype in ['unsigned char', 'unsigned int', 'float', 'double']:
img = img.clone(ptype)
self.assertEqual(img.pixeltype, ptype)
img2 = img[:10,:10]
self.assertEqual(img2.pixeltype, ptype)

def test_pixeltype_3d(self):
img = ants.image_read(ants.get_data('mni'))
for ptype in ['unsigned char', 'unsigned int', 'float', 'double']:
Expand All @@ -43,41 +43,41 @@ def test_pixeltype_3d(self):
self.assertEqual(img2.pixeltype, ptype)
img3 = img[:10,:10,10]
self.assertEqual(img3.pixeltype, ptype)

def test_2d(self):
img = ants.image_read(ants.get_data('r16'))

img2 = img[:10,:10]
self.assertEqual(img2.dimension, 2)
img2 = img[:5,:5]
self.assertEqual(img2.dimension, 2)

img2 = img[1:20,1:10]
self.assertEqual(img2.dimension, 2)
img2 = img[:5,:5]
self.assertEqual(img2.dimension, 2)
img2 = img[:5,4:5]
self.assertEqual(img2.dimension, 2)

img2 = img[5:5,5:5]

# down to 1d
arr = img[10,:]
self.assertTrue(isinstance(arr, np.ndarray))

# single value
arr = img[10,10]

def test_2d_image_index(self):
img = ants.image_read(ants.get_data('r16'))
idx = img > 200

# acts like a mask
img2 = img[idx]

def test_3d(self):
img = ants.image_read(ants.get_data('mni'))

img2 = img[:10,:10,:10]
self.assertEqual(img2.dimension, 3)
img2 = img[:5,:5,:5]
Expand All @@ -86,7 +86,7 @@ def test_3d(self):
self.assertEqual(img2.dimension, 3)
img2 = img[:5,:5,:5]
self.assertEqual(img2.dimension, 3)

# down to 2d
img2 = img[10,:,:]
self.assertEqual(img2.dimension, 2)
Expand All @@ -100,110 +100,145 @@ def test_3d(self):
self.assertEqual(img2.dimension, 2)
img2 = img[2:20,3:30,10]
self.assertEqual(img2.dimension, 2)

# down to 1d
arr = img[10,:,10]
self.assertTrue(isinstance(arr, np.ndarray))
arr = img[10,:,5]
self.assertTrue(isinstance(arr, np.ndarray))

# single value
arr = img[10,10,10]

def test_double_indexing(self):
img = ants.image_read(ants.get_data('mni'))
img2 = img[20:,:,:]
self.assertEqual(img2.shape, (162,218,182))

img3 = img[0,:,:]
self.assertEqual(img3.shape, (218,182))

def test_reverse_error(self):
img = ants.image_read(ants.get_data('mni'))
with self.assertRaises(Exception):
img2 = img[20:10,:,:]

def test_2d_vector(self):
img = ants.image_read(ants.get_data('r16'))
img2 = img[:10,:10]

img_v = ants.merge_channels([img])
img_v2 = img_v[:10,:10]

self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[0]))

def test_2d_vector_multi(self):
img = ants.image_read(ants.get_data('r16'))
img2 = img[:10,:10]

img_v = ants.merge_channels([img,img,img])
img_v2 = img_v[:10,:10]

self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[0]))
self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[1]))
self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[2]))

def test_setting_3d(self):
img = ants.image_read(ants.get_data('mni'))
img2d = img[100,:,:]

# setting a sub-image with an image
img2 = img + 10
img2[100,:,:] = img2d

self.assertFalse(ants.allclose(img, img2))
self.assertTrue(ants.allclose(img2d, img2[100,:,:]))

# setting a sub-image with an array
img2 = img + 10
img2[100,:,:] = img2d.numpy()

self.assertFalse(ants.allclose(img, img2))
self.assertTrue(ants.allclose(img2d, img2[100,:,:]))

def test_setting_2d(self):
img = ants.image_read(ants.get_data('r16'))
img2d = img[100,:]

# setting a sub-image with an image
img2 = img + 10
img2[100,:] = img2d

self.assertFalse(ants.allclose(img, img2))
self.assertTrue(np.allclose(img2d, img2[100,:]))


def test_setting_2d_sub_image(self):
img = ants.image_read(ants.get_data('r16'))
img2d = img[10:30,10:30]

# setting a sub-image with an image
img2 = img + 10
img2[10:30,10:30] = img2d

self.assertFalse(ants.allclose(img, img2))
self.assertTrue(ants.allclose(img2d, img2[10:30,10:30]))

# setting a sub-image with an array
img2 = img + 10
img2[10:30,10:30] = img2d.numpy()

self.assertFalse(ants.allclose(img, img2))
self.assertTrue(ants.allclose(img2d, img2[10:30,10:30]))

def test_setting_correctness(self):

img = ants.image_read(ants.get_data('r16')) * 0
self.assertEqual(img.sum(), 0)

img2 = img[10:30,10:30]
img2 = img2 + 10
self.assertEqual(img2.mean(), 10)

img[:20,:20] = img2
self.assertEqual(img.sum(), img2.sum())
self.assertEqual(img.numpy()[:20,:20].sum(), img2.sum())
self.assertNotEqual(img.numpy()[10:30,10:30].sum(), img2.sum())


def test_slicing_3d(self):
img = ants.image_read(ants.get_data('mni'))
img2 = img[:10,:10,:10]
img3 = img[10:20,10:20,10:20]

self.assertTrue(ants.allclose(img2, img3))

img_np = img.numpy()

self.assertTrue(np.allclose(img2.numpy(), img_np[:10,:10,:10]))
self.assertTrue(np.allclose(img[20].numpy(), img_np[20]))
self.assertTrue(np.allclose(img[:,20:40].numpy(), img_np[:,20:40]))
self.assertTrue(np.allclose(img[:,:,20:-2].numpy(), img_np[:,:,20:-2]))
self.assertTrue(np.allclose(img[0:-1,].numpy(), img_np[0:-1,]))
self.assertTrue(np.allclose(img[100,10:100,0:-1].numpy(), img_np[100,10:100,0:-1]))
self.assertTrue(np.allclose(img[:,10:,30:].numpy(), img_np[:,10:,30:]))
# if the slice returns 1D, it should be a numpy array already
self.assertTrue(np.allclose(img[100:-1,30,40], img_np[100:-1,30,40]))

def test_slicing_2d(self):
img = ants.image_read(ants.get_data('r16'))

img2 = img[:10,:10]

img_np = img.numpy()

self.assertTrue(np.allclose(img2.numpy(), img_np[:10,:10]))
self.assertTrue(np.allclose(img[:,20:40].numpy(), img_np[:,20:40]))
self.assertTrue(np.allclose(img[0:-1,].numpy(), img_np[0:-1,]))
self.assertTrue(np.allclose(img[50:,10:-3].numpy(), img_np[50:,10:-3]))
# if the slice returns 1D, it should be a numpy array already
self.assertTrue(np.allclose(img[20], img_np[20]))
self.assertTrue(np.allclose(img[100:-1,30], img_np[100:-1,30]))

if __name__ == '__main__':
run_tests()

0 comments on commit 6074d85

Please sign in to comment.