Skip to content

Commit

Permalink
Merge pull request #84 from constantinpape/parallel-edt
Browse files Browse the repository at this point in the history
Parallel distance transform implementation
  • Loading branch information
constantinpape authored Oct 22, 2023
2 parents 40c2e09 + 0c4f79d commit 37646ea
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 15 deletions.
File renamed without changes.
20 changes: 6 additions & 14 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,16 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.9]
python-version: ["3.10"]

steps:
- name: Checkout
uses: actions/checkout@v2

- name: Setup miniconda
uses: conda-incubator/setup-miniconda@v2
uses: actions/checkout@v4
- name: Setup micromamba
uses: mamba-org/setup-micromamba@v1
with:
activate-environment: elf-dev
mamba-version: "*"
auto-update-conda: true
channels: conda-forge
environment-file: .github/workflows/environment.yaml
python-version: ${{ matrix.python-version }}
auto-activate-base: false
env:
ACTIONS_ALLOW_UNSECURE_COMMANDS: true
environment-file: .github/environment.yaml

- name: Install package
shell: bash -l {0}
Expand Down
1 change: 1 addition & 0 deletions elf/parallel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .copy_dataset import copy_dataset
from .distance_transform import distance_transform
from .operations import (apply_operation, add, divide, multiply, subtract,
greater, greater_equal, less, less_equal,
minimum, maximum, isin)
Expand Down
48 changes: 48 additions & 0 deletions elf/parallel/distance_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing
# would be nice to use dask, so that we can also run this on the cluster
from concurrent import futures

import numpy as np
from scipy.ndimage import distance_transform_edt
from tqdm import tqdm

from .common import get_blocking


# TODO other distance transform arguments
def distance_transform(
data,
halo,
out=None,
block_shape=None,
n_threads=None,
verbose=False,
roi=None,
):
n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
blocking = get_blocking(data, block_shape, roi, n_threads)

if out is None:
out = np.zeros(data.shape, dtype="float32")

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def dist_block(block_id):
block = blocking.getBlockWithHalo(block_id, list(halo))
outer_bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end))
block_data = data[outer_bb]
dist = distance_transform_edt(block_data)
inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
local_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end))
out[inner_bb] = dist[local_bb]

n_blocks = blocking.numberOfBlocks
with futures.ThreadPoolExecutor(n_threads) as tp:
list(tqdm(
tp.map(dist_block, range(n_blocks)), total=n_blocks,
desc="Compute distance transform", disable=not verbose
))

return out
2 changes: 1 addition & 1 deletion test/io_tests/test_intern_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@unittest.skipIf(array is None, "Needs intern (pip install intern)")
@unittest.expectedFailure
# @unittest.expectedFailure
class TestInternWrapper(unittest.TestCase):

# the address is currently not available
Expand Down
33 changes: 33 additions & 0 deletions test/parallel/test_distance_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

import numpy as np
from skimage.data import binary_blobs
from scipy.ndimage import distance_transform_edt


class TestDistanceTransform(unittest.TestCase):
def _check_result(self, data, result, tolerance):
expected = distance_transform_edt(data)

tolerance_mask = expected < tolerance
self.assertTrue(np.allclose(result[tolerance_mask], expected[tolerance_mask]))

def test_distance_transform_2d(self):
from elf.parallel import distance_transform

tolerance = 64
data = binary_blobs(length=512, n_dim=2, volume_fraction=0.2)
result = distance_transform(data, halo=(tolerance, tolerance), block_shape=(128, 128))
self._check_result(data, result, tolerance)

def test_distance_transform_3d(self):
from elf.parallel import distance_transform

tolerance = 16
data = binary_blobs(length=128, n_dim=3, volume_fraction=0.2)
result = distance_transform(data, halo=(tolerance, tolerance, tolerance), block_shape=(64, 64, 64))
self._check_result(data, result, tolerance)


if __name__ == "__main__":
unittest.main()

0 comments on commit 37646ea

Please sign in to comment.