diff --git a/test/segmentation/test_stitching.py b/test/segmentation/test_stitching.py index 9a34ffd..2d9a85e 100644 --- a/test/segmentation/test_stitching.py +++ b/test/segmentation/test_stitching.py @@ -12,9 +12,9 @@ def get_data(self, size=1024, ndim=2): data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2, n_dim=ndim) return data - def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): + def get_tiled_data(self, tile_shape, size=1024, ndim=2): data = self.get_data(size=size, ndim=ndim) - data = label(data) # Ensure all inputs are instances (the blobs are semantic labels) + original_data = label(data) # Ensure all inputs are instances (the blobs are semantic labels) # Create tiles out of the data for testing label stitching. # Ensure offset for objects per tile to get individual ids per object per tile. @@ -36,7 +36,7 @@ def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): labels[bb] = tile - return labels, data # returns the stitched labels and original labels + return labels, original_data # returns the stitched labels and original labels def test_stitch_segmentation(self): from elf.segmentation.stitching import stitch_segmentation @@ -77,7 +77,7 @@ def test_stitch_tiled_segmentation(self): tile_shapes = [(224, 224), (256, 256), (512, 512)] for tile_shape in tile_shapes: # Get the tiled segmentation with unmerged instances at tile interfaces. - labels, original_labels = self.get_tiled_data() + labels, original_labels = self.get_tiled_data(tile_shape=tile_shape, size=1000) stitched_labels = stitch_tiled_segmentation(segmentation=labels, tile_shape=tile_shape) self.assertEqual(labels.shape, stitched_labels.shape) # self.assertEqual(len(np.unique(original_labels)), len(np.unique(stitched_labels)))