diff --git a/elf/label_multiset/create.py b/elf/label_multiset/create.py
index 656f991..1886813 100644
--- a/elf/label_multiset/create.py
+++ b/elf/label_multiset/create.py
@@ -1,15 +1,19 @@
-from math import ceil
+from typing import Sequence, Tuple
+
 import numpy as np
 import nifty.tools as nt
 from .label_multiset import LabelMultiset
 from ..util import downscale_shape


-def create_multiset_from_labels(labels):
-    """ Create label multiset from a regular label array.
+def create_multiset_from_labels(labels: np.ndarray) -> LabelMultiset:
+    """Create label multiset from a regular label array.
+
+    Args:
+        labels: Label array to summarize in the label multiset.

-    Arguments:
-        labels [np.ndarray] - label array to summarize.
+    Returns:
+        The label multiset.
     """
     # argmaxs per block = labels in our case
     argmax = labels.flatten()
@@ -18,19 +22,24 @@ def create_multiset_from_labels(labels):
     ids, offsets = np.unique(labels, return_inverse=True)

     # counts (1 by definiition)
-    counts = np.ones(len(ids), dtype='int32')
+    counts = np.ones(len(ids), dtype="int32")

-    multiset = LabelMultiset(argmax, offsets, ids, counts, labels.shape)
-    return multiset
+    return LabelMultiset(argmax, offsets, ids, counts, labels.shape)


-def downsample_multiset(multiset, scale_factor, restrict_set=-1):
-    """ Downsample label multiset from other multiset.
+def downsample_multiset(
+    multiset: LabelMultiset, scale_factor: Tuple[int, ...], restrict_set: int = -1
+) -> LabelMultiset:
+    """Downsample a label multiset.

-    Arguments:
-        multiset [LabelMultiset] - input label multiset.
-        scale_factor [list] - factor for downscaling.
-        restrict_set [int] - restrict entry length of down-sampled multiset (default: -1).
+    Args:
+        multiset: The input label multiset.
+        scale_factor: The scale factor for downsampling.
+        restrict_set: The maximum entry length of the downsampled multiset.
+            The default value (-1) means that the entry length is not restricted.
+
+    Returns:
+        The downsampled label multiset.
     """
     if not isinstance(multiset, LabelMultiset):
         raise ValueError("Expect input derived from MultisetBase, got %s" % type(multiset))
@@ -38,33 +47,41 @@ def downsample_multiset(multiset, scale_factor, restrict_set=-1):
     shape = multiset.shape
     blocking = nt.blocking([0] * len(shape), shape, scale_factor)

-    argmax, offsets, ids, counts = nt.downsampleMultiset(blocking,
-                                                         multiset.offsets, multiset.entry_sizes, multiset.entry_offsets,
-                                                         multiset.ids, multiset.counts, restrict_set)
+    argmax, offsets, ids, counts = nt.downsampleMultiset(
+        blocking, multiset.offsets, multiset.entry_sizes, multiset.entry_offsets,
+        multiset.ids, multiset.counts, restrict_set
+    )
     new_shape = downscale_shape(shape, scale_factor)
     return LabelMultiset(argmax, offsets, ids, counts, new_shape)


-def merge_multisets(multisets, grid_positions, shape, chunks):
-    """ Merge label multisets aranged in grid.
+def merge_multisets(
+    multisets: Sequence[LabelMultiset],
+    grid_positions: Sequence[Tuple[int, ...]],
+    shape: Tuple[int, ...],
+    chunks: Tuple[int, ...],
+) -> LabelMultiset:
+    """Merge label multisets arranged in a grid.
+
+    Args:
+        multisets: List of label multisets arranged in a grid that will be merged.
+        grid_positions: Grid coordinates of the input multisets.
+        shape: Shape of the resulting multiset / grid.
+        chunks: Chunk shape, i.e. the default shape of the input multisets.

-    Arguments:
-        multisets [listlike[LabelMultiset]] - list of label multisets aranged in grid.
-        grid_positions [list] - list of grid coordinates of the input list.
-        shape [tuple] - shape of the resulting multiset / grid.
-        chunks [tuple] - chunk shape = default shape of input multiset.
+    Returns:
+        The merged label multiset.
     """
     if not isinstance(multisets, (tuple, list)) and\
        not all(isinstance(ms, LabelMultiset) for ms in multisets):
         raise ValueError("Expect list or tuple of LabelMultiset")

     # arrange multisets according to the grid
-    multisets, blocking = _compute_multiset_vector(multisets, grid_positions,
-                                                   shape, chunks)
+    multisets, blocking = _compute_multiset_vector(multisets, grid_positions, shape, chunks)

     new_size = int(np.prod(shape))
-    argmax = np.zeros(new_size, dtype='uint64')
-    offsets = np.zeros(new_size, dtype='uint64')
+    argmax = np.zeros(new_size, dtype="uint64")
+    offsets = np.zeros(new_size, dtype="uint64")

     def get_indices(block_id):
         block = blocking.getBlock(block_id)
@@ -98,7 +115,7 @@ def get_indices(block_id):


 def _compute_multiset_vector(multisets, grid_positions, shape, chunks):
-    """ Arange the multisets in c-order.
+    """Arrange the multisets in c-order.
     """
     n_sets = len(multisets)
     ndim = len(shape)
@@ -110,8 +127,7 @@ def _compute_multiset_vector(multisets, grid_positions, shape, chunks):
         raise ValueError("Invalid grid: %i, %i" % (n_blocks, n_sets))

     # get the c-order positions
-    positions = np.array([[gp[i] for gp in grid_positions] for i in range(ndim)],
-                         dtype='int')
+    positions = np.array([[gp[i] for gp in grid_positions] for i in range(ndim)], dtype="int")
     grid_shape = tuple(blocking.blocksPerAxis)
     positions = np.ravel_multi_index(positions, grid_shape)
     if any(pos >= n_sets for pos in positions):
@@ -122,8 +138,7 @@
         mset = multisets[pos]
         block_shape = tuple(blocking.getBlock(pos).shape)
         if mset.shape != block_shape:
-            raise ValueError("Invalid multiset shape: %s, %s" % (str(mset.shape),
-                                                                 str(block_shape)))
+            raise ValueError("Invalid multiset shape: %s, %s" % (str(mset.shape), str(block_shape)))
         multiset_vector[pos] = mset

     if any(ms is None for ms in multiset_vector):
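A usage sketch for the creation and downsampling API above. The import paths mirror the module layout in this diff, the shapes and values are illustrative only, and `nifty` must be installed:

```python
import numpy as np
from elf.label_multiset.create import create_multiset_from_labels, downsample_multiset

# each pixel of the label array becomes a multiset entry with count 1
labels = np.random.randint(0, 5, size=(16, 16, 16)).astype("uint64")
multiset = create_multiset_from_labels(labels)
print(multiset.shape)  # (16, 16, 16)

# downsample by a factor of 2 per axis, so that each entry summarizes a 2x2x2 block
downsampled = downsample_multiset(multiset, scale_factor=[2, 2, 2])
print(downsampled.shape)  # (8, 8, 8)
```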
diff --git a/elf/label_multiset/label_multiset.py b/elf/label_multiset/label_multiset.py
index 1d030a7..f9118c1 100644
--- a/elf/label_multiset/label_multiset.py
+++ b/elf/label_multiset/label_multiset.py
@@ -1,11 +1,12 @@
+from typing import Tuple
+
 import numpy as np
 import nifty.tools as nt
 from ..util import normalize_index


 class LabelMultiset:
-    """ Implement label multiset similar to
-    https://github.com/saalfeldlab/imglib2-label-multisets.
+    """Implements the label multiset defined in https://github.com/saalfeldlab/imglib2-label-multisets.

     Label multisets summarize the ids and counts of label arrays.
     This implementation uses flat arrays to store `ids` and `counts`.
@@ -13,24 +14,23 @@ class LabelMultiset:
     of pixels. Further, `n_elements` refers to the number of elements
     (= len(ids) / len(counts)), `n_entries` refers to the number of unique multi-set entries.

-    Arguments:
-        argmax [np.ndarray] - flat array of len `size` holding max labels per set.
-        offsets [np.ndarray] - flat array of len `size` holding offsets into
-        `ids`/`counts` for each set.
-        ids [np.ndarray] - flat array holding the summarized label ids.
-        counts [np.ndarray] - flat array holding the summarized label counts.
+    Args:
+        argmax: Flat array of length `size` holding max labels per set.
+        offsets: Flat array of length `size` holding offsets into `ids`/`counts` for each set.
+        ids: Flat array holding the summarized label ids.
+        counts: Flat array holding the summarized label counts.
+        shape: Shape of the label multiset.
     """
-
-    def __init__(self, argmax, offsets, ids, counts, shape):
+    def __init__(
+        self, argmax: np.ndarray, offsets: np.ndarray, ids: np.ndarray, counts: np.ndarray, shape: Tuple[int, ...]
+    ):
         self._shape = tuple(shape)
         self._size = int(np.prod(list(shape)))

         if len(argmax) != len(offsets) != self.size:
-            raise ValueError("Shape, argmax and offset do not match: %i %i %i" % (len(argmax),
-                                                                                  len(offsets),
-                                                                                  self.size))
+            raise ValueError("Shape, argmax and offset do not match: %i %i %i" % (len(argmax), len(offsets), self.size))
         self.argmax = argmax
-        self.offsets = offsets.astype('uint64')
+        self.offsets = offsets.astype("uint64")

         if len(ids) != len(counts):
             raise ValueError("Ids and counts do not match: %i, %i" % (len(ids), len(counts)))
@@ -42,16 +42,13 @@ def __init__(self, argmax, offsets, ids, counts, shape):
         # w.r.t entries instead of elements
         unique_offsets, self.entry_offsets = np.unique(self.offsets, return_inverse=True)
         if unique_offsets[-1] >= self.n_elements:
-            raise ValueError("Elements and offsets do not match: %i, %i" % (self.n_elements,
-                                                                            unique_offsets[-1]))
+            raise ValueError("Elements and offsets do not match: %i, %i" % (self.n_elements, unique_offsets[-1]))
         self.n_entries = len(unique_offsets)

         # compute size of the entries from unique offsets
-        unique_offsets = np.concatenate([unique_offsets,
-                                         np.array([self.n_elements])]).astype('uint64')
+        unique_offsets = np.concatenate([unique_offsets, np.array([self.n_elements])]).astype("uint64")
         self.entry_sizes = np.diff(unique_offsets)

     def __getitem__(self, key):
-        # get the flattened entry indices
         index = normalize_index(key, self._shape)[0]
         index = np.array([ax.flatten() for ax in np.mgrid[index]])
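To make the flat storage layout documented in the class docstring concrete, here is a hand-constructed sketch. The values mirror what `create_multiset_from_labels` would produce for a 2x2 image with labels [[2, 2], [2, 3]]; note that several sets can share one entry via the offsets:

```python
import numpy as np
from elf.label_multiset.label_multiset import LabelMultiset

argmax = np.array([2, 2, 2, 3], dtype="uint64")   # max label per pixel set
offsets = np.array([0, 0, 0, 1], dtype="uint64")  # per-set offsets into ids / counts
ids = np.array([2, 3], dtype="uint64")            # flat array of summarized label ids
counts = np.array([1, 1], dtype="int32")          # flat array of summarized label counts

multiset = LabelMultiset(argmax, offsets, ids, counts, shape=(2, 2))
print(multiset.n_elements, multiset.n_entries)  # 2 2
```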
diff --git a/elf/label_multiset/serialize.py b/elf/label_multiset/serialize.py
index bf3151c..d54853b 100644
--- a/elf/label_multiset/serialize.py
+++ b/elf/label_multiset/serialize.py
@@ -1,54 +1,62 @@
 import struct
+from typing import Tuple
+
 import numpy as np
 from .label_multiset import LabelMultiset


-def deserialize_labels(serialization, shape):
-    """ Deserialize summarized label array from multiset serialization.
+def deserialize_labels(serialization: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
+    """Deserialize summarized label array from multiset serialization.
+
+    Args:
+        serialization: Flat byte array with multiset serialization.
+        shape: Shape of the multiset.

-    Arguments:
-        serialization [np.ndarray] - flat byte array with multiset serialization.
-        shape [tuple] - shape of the multiset.
+    Returns:
+        The labels that were summarized by the multiset.
     """
     # number of sets is encoded as integer in the first 4 bytes
     pos = 0
     next_pos = 4
-    size = struct.unpack('>i', serialization[pos:next_pos].tobytes())[0]
+    size = struct.unpack(">i", serialization[pos:next_pos].tobytes())[0]

     # the argmax vector is encoded as long in the next 8 * size bytes
     pos = next_pos
     next_pos += 8 * size
     argmax = serialization[pos:next_pos]
-    argmax = np.frombuffer(argmax.tobytes(), dtype='>q')
+    argmax = np.frombuffer(argmax.tobytes(), dtype=">q")
     return argmax.reshape(shape)


-def deserialize_multiset(serialization, shape):
-    """ Deserialize label multiset.
+def deserialize_multiset(serialization: np.ndarray, shape: Tuple[int, ...]) -> LabelMultiset:
+    """Deserialize label multiset.

-    Arguments:
-        serialization [np.ndarray] - flat byte array with multiset serialization.
-        shape [tuple] - shape of the multiset.
+    Args:
+        serialization: Flat byte array with multiset serialization.
+        shape: Shape of the multiset.
+
+    Returns:
+        The deserialized label multiset.
     """
     # number of sets is encoded as integer in the first 4 bytes
     pos = 0
     next_pos = 4
-    size = struct.unpack('>i', serialization[pos:next_pos].tobytes())[0]
+    size = struct.unpack(">i", serialization[pos:next_pos].tobytes())[0]

     # the argmax vector is encoded as long in the next 8 * size bytes
     pos = next_pos
     next_pos += 8 * size
     argmax = serialization[pos:next_pos]
-    argmax = np.frombuffer(argmax.tobytes(), dtype='>q')
+    argmax = np.frombuffer(argmax.tobytes(), dtype=">q")

     # the byte offset vector is encoded as long in the next 4 * size bytes
     pos = next_pos
     next_pos += 4 * size
     offsets = serialization[pos:next_pos]
-    offsets = np.frombuffer(offsets.tobytes(), dtype='>i')
+    offsets = np.frombuffer(offsets.tobytes(), dtype=">i")

     # compute the unique byte offsets and the inverse mapping
     byte_offsets, inverse_offsets = np.unique(offsets, return_inverse=True)
@@ -68,13 +76,12 @@ def deserialize_entry(entry):
         # extract the ids and counts
         entry = entry.reshape((n_elements, 12))
         ids = entry[:, :8].flatten()
-        ids = np.frombuffer(ids.tobytes(), dtype='<q')
+        ids = np.frombuffer(ids.tobytes(), dtype="<q")
         counts = entry[:, 8:].flatten()
-        counts = np.frombuffer(counts.tobytes(), dtype='<i')
+        counts = np.frombuffer(counts.tobytes(), dtype="<i")
         return ids, counts


-def serialize_multiset(multiset):
-    """ Serialize label multiset serialization in imglib format.
+def serialize_multiset(multiset: LabelMultiset) -> np.ndarray:
+    """Serialize a label multiset in imglib format.

     The multiset is serialized as follows:
     1.) number of sets / cells encoded as integer (4 bytes)
     2.) max label id for each set encoded as long (8 bytes * num_cells)
     3.) offset in bytes into the data array for each set encoded as int (4 bytes * num cells)
     4.) the data storing label ids / counts encoded as long / int (datalen in bytes)
-    cf. https://github.com/saalfeldlab/imglib2-label-multisets/blob/master/src/main/java/net
-    /imglib2/type/label/LabelMultisetTypeDownscaler.java#L176
+    See also:
+    https://github.com/saalfeldlab/imglib2-label-multisets/blob/master/src/main/java/net/imglib2/type/label/LabelMultisetTypeDownscaler.java#L176
+
+    Args:
+        multiset: The label multiset to serialize.

-    Arguments:
-        multiset [LabelMultiset] - the label multiset to serialze.
+    Returns:
+        The serialized label multiset as a flat binary array.
     """
     size, n_entries, n_elements = multiset.size, multiset.n_entries, multiset.n_elements
     argmax, offsets, ids, counts = (multiset.argmax, multiset.offsets, multiset.ids, multiset.counts)

     # encode the argmax vector
-    argmax = np.array(argmax, dtype='>q').tobytes()
+    argmax = np.array(argmax, dtype=">q").tobytes()

     # merge and encode ids and counts.
     # the ids are stored as long, the counts as int (both little endian).
-    ids = [struct.pack('<q', i) for i in ids]
-    counts = [struct.pack('<i', c) for c in counts]
+    ids = [struct.pack("<q", i) for i in ids]
+    counts = [struct.pack("<i", c) for c in counts]
     # interleave the encoded ids and counts into the data array
     data = b"".join([i + c for i, c in zip(ids, counts)])

     # encode the byte offsets; each element takes 12 bytes (long id + int count)
-    offsets = np.array(offsets * 12, dtype='>i').tobytes()
+    offsets = np.array(offsets * 12, dtype=">i").tobytes()

     # encode the number of sets
-    size = struct.pack('>i', size)
+    size = struct.pack(">i", size)

     # combine to byte buffer for serialization
     serialization = size + argmax + offsets + data
-    serialization = np.frombuffer(serialization, dtype='uint8')
+    serialization = np.frombuffer(serialization, dtype="uint8")
     return serialization
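A round-trip sketch for the serialization functions above (input values are arbitrary). Since a multiset created from a label array stores each pixel's label as its argmax, `deserialize_labels` recovers the original labels here:

```python
import numpy as np
from elf.label_multiset.create import create_multiset_from_labels
from elf.label_multiset.serialize import serialize_multiset, deserialize_labels

labels = np.random.randint(0, 4, size=(8, 8)).astype("uint64")
multiset = create_multiset_from_labels(labels)

serialization = serialize_multiset(multiset)  # flat uint8 array in imglib format
recovered = deserialize_labels(serialization, labels.shape)
assert np.array_equal(recovered, labels)
```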
diff --git a/elf/mesh/io.py b/elf/mesh/io.py
index f462a2c..5a57030 100644
--- a/elf/mesh/io.py
+++ b/elf/mesh/io.py
@@ -1,25 +1,48 @@
+import os
+from typing import Optional, Tuple, Union
+
 import numpy as np


-def read_numpy(path):
-    """ Read mesh from compressed numpy format
+def read_numpy(path: Union[os.PathLike, str]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Read mesh from compressed numpy format.
+
+    Args:
+        path: The file with the mesh saved as a compressed numpy file.
+
+    Returns:
+        The vertices of the mesh.
+        The faces of the mesh.
+        The normals of the mesh.
     """
     mesh = np.load(path)
     return mesh["verts"], mesh["faces"], mesh["normals"]


-def write_numpy(path, verts, faces, normals):
-    """ Write mesh to compressed numpy format
+def write_numpy(path: Union[os.PathLike, str], verts: np.ndarray, faces: np.ndarray, normals: np.ndarray):
+    """Write mesh to compressed numpy format.
+
+    Args:
+        path: The path for saving the mesh.
+        verts: The vertices of the mesh.
+        faces: The faces of the mesh.
+        normals: The normals of the mesh.
     """
-    np.savez_compressed(path,
-                        verts=verts,
-                        faces=faces,
-                        normals=normals)
+    np.savez_compressed(path, verts=verts, faces=faces, normals=normals)


 # TODO support different format for faces
-def read_obj(path):
-    """ Read mesh from obj
+def read_obj(path: Union[os.PathLike, str]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    """Read mesh from an obj file.
+
+    Args:
+        path: The file with the mesh saved as obj.
+
+    Returns:
+        The vertices of the mesh.
+        The faces of the mesh.
+        The normals of the mesh.
+        The face normals of the mesh.
     """
     verts = []
     faces = []
@@ -51,8 +74,23 @@
 # TODO support different format for faces
-def write_obj(path, verts, faces, normals=None, face_normals=None, zero_based_face_index=False):
-    """ Write mesh to obj
+def write_obj(
+    path: Union[os.PathLike, str],
+    verts: np.ndarray,
+    faces: np.ndarray,
+    normals: Optional[np.ndarray] = None,
+    face_normals: Optional[np.ndarray] = None,
+    zero_based_face_index: bool = False,
+):
+    """Write mesh to an obj file.
+
+    Args:
+        path: The path for saving the mesh.
+        verts: The vertices of the mesh.
+        faces: The faces of the mesh.
+        normals: The normals of the mesh.
+        face_normals: The face normals of the mesh.
+        zero_based_face_index: Whether to use 0- or 1-based indexing for the faces.
     """
     with open(path, "w") as f:
         for vert in verts:
@@ -79,13 +117,19 @@ def write_obj(path, verts, faces, normals=None, face_normals=None, zero_based_fa
             f.write("\n")
     else:
         for face, normal in zip(faces, face_normals):
-            f.write(" ".join(["f"] + ["/".join([str(fa), "1", str(no)])
-                              for fa, no in zip(face, normal)]))
+            f.write(" ".join(["f"] + ["/".join([str(fa), "1", str(no)]) for fa, no in zip(face, normal)]))
             f.write("\n")


-def read_ply(path):
+def read_ply(path: Union[os.PathLike, str]) -> Tuple[np.ndarray, np.ndarray]:
     """Read mesh from ply data format.
+
+    Args:
+        path: The file with the mesh saved as ply.
+
+    Returns:
+        The vertices of the mesh.
+        The faces of the mesh.
     """
     verts = []
     faces = []
@@ -124,8 +168,13 @@
 # https://web.archive.org/web/20161221115231/http://www.cs.virginia.edu/~gfx/Courses/2001/Advanced.spring.01/plylib/Ply.txt
-def write_ply(path, verts, faces):
+def write_ply(path: Union[os.PathLike, str], verts: np.ndarray, faces: np.ndarray):
     """Write mesh to ply data format.
+
+    Args:
+        path: The path for saving the mesh.
+        verts: The vertices of the mesh.
+        faces: The faces of the mesh.
     """

     header = f"""ply
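A minimal obj round trip with the io functions above. The file name is hypothetical, and the faces are given 1-based to match the default `zero_based_face_index=False`:

```python
import numpy as np
from elf.mesh.io import read_obj, write_obj

# a single triangle with 1-based face indices
verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[1, 2, 3]])

write_obj("triangle.obj", verts, faces)
verts_read, faces_read, normals, face_normals = read_obj("triangle.obj")
```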
""" mesh = np.load(path) return mesh["verts"], mesh["faces"], mesh["normals"] -def write_numpy(path, verts, faces, normals): - """ Write mesh to compressed numpy format +def write_numpy(path: Union[os.PathLike, str], verts: np.ndarray, faces: np.ndarray, normals: np.ndarray): + """Write mesh to compressed numpy format. + + Args: + path: The path for saving the mesh. + verts: The vertices of the mesh. + faces: The faces of the mesh. + normals: The normals of the mesh. """ - np.savez_compressed(path, - verts=verts, - faces=faces, - normals=normals) + np.savez_compressed(path, verts=verts, faces=faces, normals=normals) # TODO support different format for faces -def read_obj(path): - """ Read mesh from obj +def read_obj(path: Union[os.PathLike, str]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Read mesh from an obj file. + + Args: + path: The file with the mesh saved as obj. + + Returns: + The vertices of the mesh. + The faces of the mesh. + The normals of the mesh. + The face normals of the mesh. """ verts = [] faces = [] @@ -51,8 +74,23 @@ def read_obj(path): # TODO support different format for faces -def write_obj(path, verts, faces, normals=None, face_normals=None, zero_based_face_index=False): - """ Write mesh to obj +def write_obj( + path: Union[os.PathLike, str], + verts: np.ndarray, + faces: np.ndarray, + normals: Optional[np.ndarray] = None, + face_normals: Optional[np.ndarray] = None, + zero_based_face_index: bool = False, +): + """Write mesh to an obj file. + + Args: + path: The path for saving the mesh. + verts: The vertices of the mesh. + faces: The faces of the mesh. + normals: The normals of the mesh. + face_normals: The face normals of the mesh. + zero_based_face_index: Whether to use 0- or 1- based indexing for the faces. """ with open(path, "w") as f: for vert in verts: @@ -79,13 +117,19 @@ def write_obj(path, verts, faces, normals=None, face_normals=None, zero_based_fa f.write("\n") else: for face, normal in zip(faces, face_normals): - f.write(" ".join(["f"] + ["/".join([str(fa), "1", str(no)]) - for fa, no in zip(face, normal)])) + f.write(" ".join(["f"] + ["/".join([str(fa), "1", str(no)]) for fa, no in zip(face, normal)])) f.write("\n") -def read_ply(path): +def read_ply(path: Union[os.PathLike, str]) -> Tuple[np.ndarray, np.ndarray]: """Read mesh from ply data format. + + Args: + path: The file with the mesh saved as ply. + + Returns: + The vertices of the mesh. + The faces of the mesh. """ verts = [] faces = [] @@ -124,8 +168,13 @@ def read_ply(path): # https://web.archive.org/web/20161221115231/http://www.cs.virginia.edu/~gfx/Courses/2001/Advanced.spring.01/plylib/Ply.txt -def write_ply(path, verts, faces): +def write_ply(path: Union[os.PathLike, str], verts: np.ndarray, faces: np.ndarray): """Write mesh to ply data format. + + Args: + path: The path for saving the mesh. + verts: The vertices of the mesh. + faces: The faces of the mesh. 
""" header = f"""ply diff --git a/elf/mesh/mesh.py b/elf/mesh/mesh.py index 4dc89b6..84fd03c 100644 --- a/elf/mesh/mesh.py +++ b/elf/mesh/mesh.py @@ -1,39 +1,56 @@ +from typing import Optional, Tuple + import nifty import numpy as np -from skimage.measure import marching_cubes as marching_cubes_lewiner +from skimage.measure import marching_cubes as marching_cubes_impl -def marching_cubes(obj, smoothing_iterations=0, resolution=None): +def marching_cubes( + obj: np.ndarray, + smoothing_iterations: int = 0, + resolution: Optional[Tuple[float, float, float]] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Compute mesh via marching cubes. This is a wrapper around the skimage marching cubes implementation that provides - additional mesh smoothing + additional mesh smoothing. + + Args: + obj: Volume containing the object to be meshed. + smoothing_iterations: Number of mesh smoothing iterations. + resolution: Resolution of the data. - Arguments: - obj [np.ndarray] - volume containing the object to be meshed - smoothing_iterations [int] - number of mesh smoothing iterations (default: 0) - resolution[listlike[int]] - resolution of the data (default: None) + Returns: + The vertices of the mesh. + The faces of the mesh. + The normals of the mesh. """ - resolution = (1., 1., 1.) if resolution is None else resolution + resolution = (1.0, 1.0, 1.0) if resolution is None else resolution if len(resolution) != 3: raise ValueError(f"Invalid resolution argument: {resolution}") resolution = tuple(resolution) - verts, faces, normals, _ = marching_cubes_lewiner(obj, spacing=resolution) + verts, faces, normals, _ = marching_cubes_impl(obj, spacing=resolution) if smoothing_iterations > 0: verts, normals = smooth_mesh(verts, normals, faces, smoothing_iterations) return verts, faces, normals -def smooth_mesh(verts, normals, faces, iterations): - """ Smooth mesh surfacee via laplacian smoothing. +def smooth_mesh( + verts: np.ndarray, normals: np.ndarray, faces: np.ndarray, iterations: int +) -> Tuple[np.ndarray, np.ndarray]: + """Smooth mesh surface via laplacian smoothing. + + Args: + verts: The mesh vertices. + normals: The mesh normals. + faces: The mesh faces. + iterations: The number of smoothing iterations. - Arguments: - verts [np.ndarray] - mesh vertices - normals [np.ndarray] - mesh normals - faces [np.ndarray] - mesh faces - iterations [int] - number of smoothing rounds + Returns: + The vertices after smoothing. + The normals after smoothing. """ n_verts = len(verts) g = nifty.graph.undirectedGraph(n_verts) @@ -46,7 +63,7 @@ def smooth_mesh(verts, normals, faces, iterations): new_verts = np.zeros_like(verts, dtype=verts.dtype) new_normals = np.zeros_like(normals, dtype=normals.dtype) - # TODO implement this directly in nifty for speed up + # Implement this directly in nifty for speed up? 
-def mesh_to_segmentation(mesh_file, resolution=[1.0, 1.0, 1.0],
-                         reverse_coordinates=False, shape=None, verbose=False,
-                         block_shape=None):
-    """ Compute segmentation volume from mesh.
+def mesh_to_segmentation(
+    mesh_file: str,
+    resolution: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+    reverse_coordinates: bool = False,
+    shape: Optional[Tuple[int, int, int]] = None,
+    verbose: bool = False,
+    block_shape: Optional[Tuple[int, int, int]] = None,
+) -> np.ndarray:
+    """Transform a mesh into a volumetric binary segmentation mask.

     Requires madcad and pywavefront as dependency.

-    Arguments:
-        mesh_file [str] - path to mesh in obj format
-        resolution [list[float]] - pixel resolution of the vertex coordinates
-        reverse_coordinates [bool] - whether to reverse the coordinate order (default: False)
-        shape [tuple[int]] - shape of the output volume.
-        If None, the maximal extent of the mesh coordinates will be used as shape (default: None)
-        verbose [bool] - whether to activate verbose output (default: False)
-        block_shape [tuple[int]] - block_shape to parallelize the computation (default: None)
+    Args:
+        mesh_file: Path to the mesh in obj format.
+        resolution: Pixel resolution of the vertex coordinates.
+        reverse_coordinates: Whether to reverse the coordinate order.
+        shape: Shape of the output volume. If None, the maximal extent of the mesh coordinates will be used.
+        verbose: Whether to activate verbose output.
+        block_shape: Block shape used to parallelize the computation.
+
+    Returns:
+        The binary segmentation mask.
     """
     if PositionMap is None:
         raise RuntimeError("Need madcad dependency for mesh_to_seg functionality.")
@@ -52,12 +62,7 @@ def mesh_to_segmentation(mesh_file, resolution=[1.0, 1.0, 1.0],
         voxel = hasher.keysfor(mesh.facepoints(face))
         if reverse_coordinates:
             voxel = [vox[::-1] for vox in voxel]
-        voxel = [
-            tuple(
-                int(vv / res) for vv, res in zip(vox, resolution)
-            )
-            for vox in voxel
-        ]
+        voxel = [tuple(int(vv / res) for vv, res in zip(vox, resolution)) for vox in voxel]
         voxels.update(voxel)

     voxels = np.array(list(voxels))
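Finally, a sketch of the voxelization entry point; the obj path is hypothetical, and the optional `madcad` and `pywavefront` dependencies mentioned in the docstring are required:

```python
from elf.mesh.mesh_to_segmentation import mesh_to_segmentation

# voxelize a mesh; the output shape is derived from the mesh extent if not given
seg = mesh_to_segmentation("my_mesh.obj", resolution=(1.0, 1.0, 1.0), verbose=True)
print(seg.shape, seg.dtype)
```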