Fix issues in parallel.label caused by boost union find
constantinpape committed Dec 29, 2024
1 parent 2eda443 commit 7e4380b
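
The core change: instead of constructing the union-find from an explicit label array via `nufd.boost_ufd(old_labels)`, `merge_blocks` now constructs it from the number of elements via `nufd.ufd(n_elements)` and relabels afterwards. A minimal sketch of that merge step, using only the calls visible in the diff (`nufd.ufd`, `ufd.merge`, `ufd.find`) and assuming `nufd` is `nifty.ufd` and `relabel_sequential` comes from `skimage.segmentation`:

```python
import numpy as np
import nifty.ufd as nufd  # assumed alias for the `nufd` used in the diff
from skimage.segmentation import relabel_sequential  # assumed import

max_id = 5
n_elements = max_id + 1

# Pairs of (offset-shifted) block labels that touch across a block face
# and therefore belong to the same connected component.
merge_labels = np.array([[1, 4], [2, 5]], dtype="uint64")

ufd = nufd.ufd(n_elements)  # new: sized by the number of elements
ufd.merge(merge_labels)     # union all face-wise label pairs at once

old_labels = np.arange(n_elements, dtype="uint64")
new_labels = ufd.find(old_labels)            # representative id for every label
mapping = relabel_sequential(new_labels)[0]  # make the ids consecutive again
print(mapping)  # labels 1/4 and 2/5 now map to the same consecutive id
```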
Showing 2 changed files with 81 additions and 66 deletions.
146 changes: 80 additions & 66 deletions elf/parallel/label.py
@@ -1,5 +1,6 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits
from typing import Optional, Tuple

import multiprocessing
# would be nice to use dask, so that we can also run this on the cluster
Expand All @@ -14,32 +15,33 @@
from .common import get_blocking

import numpy as np
from numpy.typing import ArrayLike


def cc_blocks(data, out, mask, blocking, with_background,
n_threads, verbose):
def cc_blocks(data, out, mask, blocking, with_background, n_threads, verbose):
"""@private
"""
n_blocks = blocking.numberOfBlocks

# compute the connected component for one block
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
# Compute the connected component for one block.
@threadpool_limits.wrap(limits=1) # Restrict the numpy threadpool to 1 to avoid oversubscription.
def _cc_block(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))

# check if we have a mask and if we do if we
# have pixels in the mask
        # Check if we have a mask and, if so, whether it contains any foreground pixels.
if mask is not None:
m = mask[bb].astype("bool")
if m.sum() == 0:
return 0

# load the data from this block
# Load the data from this block.
d = data[bb].copy()

# determine the background value
# Determine the background value.
bg_val = 0 if with_background else int(d.max() + 1)

# set mask to background value
# Set mask to background value.
if mask is not None:
d[~m] = bg_val

@@ -48,7 +50,7 @@ def _cc_block(block_id):
out[bb] = d
return int(d.max())

# compute connected components for all blocks in parallel
# Compute connected components for all blocks in parallel.
with futures.ThreadPoolExecutor(n_threads) as tp:
block_max_labels = list(tqdm(
tp.map(_cc_block, range(n_blocks)), total=n_blocks, desc="Label all sub-blocks", disable=not verbose
@@ -57,72 +59,71 @@ def _cc_block(block_id):
return out, block_max_labels


def merge_blocks(data, out, mask, offsets,
blocking, max_id, with_background,
n_threads, verbose):
def merge_blocks(data, out, mask, offsets, blocking, max_id, with_background, n_threads, verbose):
"""@private
"""
n_blocks = blocking.numberOfBlocks
ndim = out.ndim

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
@threadpool_limits.wrap(limits=1) # Restrict the numpy threadpool to 1 to avoid oversubscription.
def _merge_block_faces(block_id):
block = blocking.getBlock(block_id)
offset_block = offsets[block_id]

merge_labels = []
# for each axis, load the face with the lower block neighbor and compute the merge labels
# For each axis, load the face with the lower block neighbor and compute the merge labels.
for axis in range(ndim):
ngb_id = blocking.getNeighborId(block_id, axis, lower=True)
if ngb_id == -1:
continue
ngb_block = blocking.getBlock(ngb_id)

# make the bounding box for both faces and load the segmentation for it
# Make the bounding box for both faces and load the segmentation for it.
face = tuple(slice(beg, end) if d != axis else slice(beg, beg + 1)
for d, (beg, end) in enumerate(zip(block.begin, block.end)))
ngb_face = tuple(slice(beg, end) if d != axis else slice(end - 1, end)
for d, (beg, end) in enumerate(zip(ngb_block.begin,
ngb_block.end)))

# load and combine the mask for bot faces
            # Load and combine the mask for both faces.
if mask is not None:
m = np.logical_and(out[face], out[ngb_face])
if m.sum() == 0:
continue

# load the initial labels for both faces
# Load the initial labels for both faces.
d, d_ngb = data[face], data[ngb_face]
assert d.shape == d_ngb.shape

# load the intermediate result for both faces
# Load the intermediate result for both faces.
o, o_ngb = out[face], out[ngb_face]
assert o.shape == o_ngb.shape == d.shape

# allocate full mask if we don't have a mask dataset
# Allocate full mask if we don't have a mask dataset.
if mask is None:
m = np.ones_like(d, dtype="bool")

# mask zero label if we have background
# Mask zero label if we have background.
if with_background:
m[d == 0] = 0
m[d_ngb == 0] = 0

# mask pixels of the face where d != d_ngb, these should not be merged!
            # Mask pixels of the face where d != d_ngb; these should not be merged.
m[d != d_ngb] = 0

# is there anything left to merge?
# Is there anything left to merge?
if m.sum() == 0:
continue
offset_ngb = offsets[ngb_id]

# apply the mask to the labels
# Apply the mask to the labels.
o, o_ngb = o[m], o_ngb[m]

# apply the offsets
# Apply the offsets.
o += offset_block
o_ngb += offset_ngb

# compute the merge labels for this face
# by concatenation and unique
# Compute the merge labels for this face by concatenation and unique.
to_merge = np.concatenate([o[:, None], o_ngb[:, None]], axis=1)
to_merge = np.unique(to_merge, axis=0)
if to_merge.size > 0:
@@ -133,50 +134,52 @@ def _merge_block_faces(block_id):
else:
return None

# compute the merge ids across all block faces
# Compute the merge ids across all block faces.
with futures.ThreadPoolExecutor(n_threads) as tp:
merge_labels = list(tqdm(
tp.map(_merge_block_faces, range(n_blocks)), total=n_blocks,
desc="Merge labels across block faces", disable=not verbose
))

n_elements = max_id + 1
merge_labels = [res for res in merge_labels if res is not None]
if len(merge_labels) == 0:
return np.arange(max_id + 1, dtype=out.dtype)
return np.arange(n_elements, dtype=out.dtype)

merge_labels = np.concatenate(merge_labels, axis=0)
# merge labels via union find
old_labels = np.arange(max_id + 1, dtype=out.dtype)
ufd = nufd.boost_ufd(old_labels)

# Merge labels via union find.
ufd = nufd.ufd(n_elements)
ufd.merge(merge_labels)

# get the new labels from the ufd
# Get the new labels from the ufd.
old_labels = np.arange(n_elements, dtype=out.dtype)
new_labels = ufd.find(old_labels)
if with_background:
assert new_labels[0] == 0
# relabel the new labels consecutively and return them
# Relabel the new labels consecutively and return them.
return relabel_sequential(new_labels)[0]


def write_mapping(out, mask, offsets, mapping,
with_background, blocking,
n_threads, verbose):
def write_mapping(out, mask, offsets, mapping, with_background, blocking, n_threads, verbose):
"""@private
"""
n_blocks = blocking.numberOfBlocks

# compute the connected component for one block
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
# Compute the connected component for one block.
@threadpool_limits.wrap(limits=1) # Restrict the numpy threadpool to 1 to avoid oversubscription.
def _write_block(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))

# check if we have a mask and if we do if we
# have pixels in the mask
        # Check if we have a mask and, if so, whether it contains any foreground pixels.
if mask is not None:
m = mask[bb].astype("bool")
if m.sum() == 0:
return None
offset = offsets[block_id]

# load the data from this block
# Load the data from this block.
d = out[bb]
if mask is None:
if with_background:
@@ -193,7 +196,7 @@ def _write_block(block_id):

out[bb] = d

# compute connected components for all blocks in parallel
# Compute connected components for all blocks in parallel.
with futures.ThreadPoolExecutor(n_threads) as tp:
list(tqdm(
tp.map(_write_block, range(n_blocks)), total=n_blocks, desc="Write blocks", disable=not verbose
@@ -202,26 +205,37 @@ def _write_block(block_id):
return out


def label(data, out=None, with_background=True, block_shape=None,
n_threads=None, mask=None, verbose=False, roi=None, connectivity=1):
"""Label the data in parallel by applying blockwise connected component and
merging the results over block boundaries.
Arguments:
data [array_like] - input data, numpy array or similar like h5py or zarr dataset
out [array_like] - output data (label cannot be applied inplace)
with_background [bool] - whether to treat zero as background label (default: True)
block_shape [tuple] - shape of the blocks used for parallelisation,
by default chunks of the input will be used, if available (default: None)
n_threads [int] - number of threads, by default all are used (default: None)
mask [array_like] - mask to exclude data from the computation.
Data not in the mask will be set to zero in the result. (default: None)
verbose [bool] - verbosity flag (default: False)
roi [tuple[slice]] - region of interest for this computation (default: None)
connectivity [int] - the number of nearest neighbor hops to consider for connection.
Currently only supports connectivity of 1. (default: 1)
def label(
data: ArrayLike,
out: Optional[ArrayLike] = None,
with_background: bool = True,
block_shape: Optional[Tuple[int, ...]] = None,
n_threads: Optional[int] = None,
mask: Optional[ArrayLike] = None,
verbose: bool = False,
roi: Optional[Tuple[slice, ...]] = None,
connectivity: int = 1,
) -> ArrayLike:
"""Label the input data in parallel.
    Applies blockwise connected components and merges the results over block boundaries.
Args:
data: Input data, numpy array or similar like h5py or zarr dataset.
out: Output data. Note that `label` cannot be applied inplace.
with_background: Whether to treat zero as background label.
block_shape: Shape of the blocks to use for parallelisation,
by default chunks of the input will be used, if available.
n_threads: Number of threads, by default all available threads are used.
mask: Mask to exclude data from the computation.
Data not in the mask will be set to zero in the result.
verbose: Verbosity flag.
roi: Region of interest for this computation.
connectivity: The number of nearest neighbor hops to consider for connection.
Currently only supports connectivity of 1.
Returns:
array_like - the output data
The labeled data.
"""
if connectivity != 1:
raise NotImplementedError(
@@ -237,22 +251,22 @@ def label(data, out=None, with_background=True, block_shape=None,
n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
blocking = get_blocking(data, block_shape, roi, n_threads)

# 1.) compute connected components for all blocks
# 1.) Compute connected components for all blocks.
out, offsets = cc_blocks(data, out, mask, blocking, with_background, n_threads=n_threads, verbose=verbose)

# turn block max labels into offsets
# Turn block max labels into offsets.
last_block_val = offsets[-1]
offsets = np.roll(offsets, 1)
offsets[0] = 0
offsets = np.cumsum(offsets)
max_id = offsets[-1] + last_block_val

# 2.) merge the connected components along block boundaries
# 2.) Merge the connected components along block boundaries.
mapping = merge_blocks(data, out, mask, offsets,
blocking, max_id, with_background,
n_threads=n_threads, verbose=verbose)

# 3.) write the new new pixel labeling
    # 3.) Write the new pixel labeling.
out = write_mapping(out, mask, offsets,
mapping, with_background,
blocking, n_threads=n_threads, verbose=verbose)
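
For reference, a minimal usage sketch of the updated function, based only on the signature and docstring shown above (the array shapes and block shape are made up for illustration):

```python
import numpy as np
from elf.parallel import label

data = (np.random.rand(128, 128, 128) > 0.7).astype("uint8")
out = np.zeros(data.shape, dtype="uint64")  # label cannot be applied inplace
seg = label(data, out=out, with_background=True, block_shape=(64, 64, 64))
print(int(seg.max()), "components")
```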
1 change: 1 addition & 0 deletions test/parallel/test_label.py
@@ -46,6 +46,7 @@ def test_label_with_background(self):
# stress test of the label function with binary blobs
def _test_blobs(self, ndim, size):
from elf.parallel import label

block_shape = (size // 8,) * ndim
for volume_fraction in (0.05, 0.1, 0.25, 0.5):
data = binary_blobs(length=size, n_dim=ndim, volume_fraction=volume_fraction)
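
The stress test touched above exercises `label` on binary blobs; a sketch of the kind of check such a test could perform (comparing against `skimage.measure.label` is an assumption here, not necessarily what the test does):

```python
import numpy as np
from skimage.data import binary_blobs
from skimage.measure import label as label_reference
from elf.parallel import label

size, ndim = 128, 3
data = binary_blobs(length=size, n_dim=ndim, volume_fraction=0.25)
out = np.zeros(data.shape, dtype="uint64")
result = label(data, out=out, block_shape=(size // 8,) * ndim)
# Component ids may differ between implementations,
# but the number of components should agree.
assert int(result.max()) == int(label_reference(data).max())
```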
