Commit 7216d8d: Update more doc strings

constantinpape committed Dec 30, 2024
1 parent 4527ab4 commit 7216d8d
Showing 6 changed files with 235 additions and 142 deletions.
81 changes: 48 additions & 33 deletions elf/label_multiset/create.py
@@ -1,15 +1,19 @@
from math import ceil
from typing import Sequence, Tuple

import numpy as np
import nifty.tools as nt
from .label_multiset import LabelMultiset
from ..util import downscale_shape


def create_multiset_from_labels(labels: np.ndarray) -> LabelMultiset:
    """Create label multiset from a regular label array.

    Args:
        labels: Label array to summarize in the label multiset.

    Returns:
        The label multiset.
    """
    # argmaxs per block = labels in our case
    argmax = labels.flatten()
@@ -18,53 +22,66 @@ def create_multiset_from_labels(labels):
    ids, offsets = np.unique(labels, return_inverse=True)

    # counts (1 by definition)
    counts = np.ones(len(ids), dtype="int32")

    return LabelMultiset(argmax, offsets, ids, counts, labels.shape)
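For reference, a minimal usage sketch (the random toy array and its shape are illustrative, not part of the commit):

    import numpy as np
    from elf.label_multiset.create import create_multiset_from_labels

    labels = np.random.randint(0, 100, size=(64, 64)).astype("uint64")
    multiset = create_multiset_from_labels(labels)
    # the multiset summarizes the input array; its shape matches the labels
    assert multiset.shape == labels.shape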


def downsample_multiset(
    multiset: LabelMultiset, scale_factor: Tuple[int, ...], restrict_set: int = -1
) -> LabelMultiset:
    """Downsample a label multiset.

    Args:
        multiset: The input label multiset.
        scale_factor: The scale factor for downsampling.
        restrict_set: The maximum entry length of the downsampled multiset.
            The default value (-1) means that the entry length is not restricted.

    Returns:
        The downsampled label multiset.
    """
    if not isinstance(multiset, LabelMultiset):
        raise ValueError("Expected a LabelMultiset, got %s" % type(multiset))

    shape = multiset.shape
    blocking = nt.blocking([0] * len(shape), shape, scale_factor)

    argmax, offsets, ids, counts = nt.downsampleMultiset(
        blocking, multiset.offsets, multiset.entry_sizes, multiset.entry_offsets,
        multiset.ids, multiset.counts, restrict_set
    )
    new_shape = downscale_shape(shape, scale_factor)
    return LabelMultiset(argmax, offsets, ids, counts, new_shape)
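A corresponding downsampling sketch (scale factor and shapes are assumed for illustration):

    from elf.label_multiset.create import create_multiset_from_labels, downsample_multiset

    multiset = create_multiset_from_labels(labels)  # labels as in the sketch above
    # downsample by a factor of 2 along both axes, without restricting the entry length
    downsampled = downsample_multiset(multiset, scale_factor=[2, 2])
    assert downsampled.shape == (labels.shape[0] // 2, labels.shape[1] // 2)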


def merge_multisets(
    multisets: Sequence[LabelMultiset],
    grid_positions: Sequence[Tuple[int, ...]],
    shape: Tuple[int, ...],
    chunks: Tuple[int, ...],
) -> LabelMultiset:
    """Merge label multisets arranged in a grid.

    Args:
        multisets: List of label multisets arranged in a grid that will be merged.
        grid_positions: Grid coordinates of the input multisets.
        shape: Shape of the resulting multiset / grid.
        chunks: Chunk shape = default shape of input multiset.

    Returns:
        The merged label multiset.
    """
    if not isinstance(multisets, (tuple, list)) or not all(isinstance(ms, LabelMultiset) for ms in multisets):
        raise ValueError("Expect list or tuple of LabelMultiset")

    # arrange multisets according to the grid
    multisets, blocking = _compute_multiset_vector(multisets, grid_positions, shape, chunks)

    new_size = int(np.prod(shape))
    argmax = np.zeros(new_size, dtype="uint64")
    offsets = np.zeros(new_size, dtype="uint64")

    def get_indices(block_id):
        block = blocking.getBlock(block_id)
@@ -98,7 +115,7 @@ def get_indices(block_id):


def _compute_multiset_vector(multisets, grid_positions, shape, chunks):
    """Arrange the multisets in C order."""
    n_sets = len(multisets)
    ndim = len(shape)
@@ -110,8 +127,7 @@ def _compute_multiset_vector(multisets, grid_positions, shape, chunks):
raise ValueError("Invalid grid: %i, %i" % (n_blocks, n_sets))

    # get the c-order positions
    positions = np.array([[gp[i] for gp in grid_positions] for i in range(ndim)], dtype="int")
    grid_shape = tuple(blocking.blocksPerAxis)
    positions = np.ravel_multi_index(positions, grid_shape)
    if any(pos >= n_sets for pos in positions):
@@ -122,8 +138,7 @@ def _compute_multiset_vector(multisets, grid_positions, shape, chunks):
        mset = multisets[pos]
        block_shape = tuple(blocking.getBlock(pos).shape)
        if mset.shape != block_shape:
            raise ValueError("Invalid multiset shape: %s, %s" % (str(mset.shape), str(block_shape)))
        multiset_vector[pos] = mset

    if any(ms is None for ms in multiset_vector):
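To illustrate the grid semantics, a sketch that merges four multisets on a 2x2 grid (all shapes are toy values):

    import numpy as np
    from elf.label_multiset.create import create_multiset_from_labels, merge_multisets

    blocks = [np.random.randint(0, 10, size=(32, 32)).astype("uint64") for _ in range(4)]
    multisets = [create_multiset_from_labels(block) for block in blocks]
    grid_positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
    # chunks is the default block shape; blocks at the end of a row / column may be smaller
    merged = merge_multisets(multisets, grid_positions, shape=(64, 64), chunks=(32, 32))
    assert merged.shape == (64, 64)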
35 changes: 16 additions & 19 deletions elf/label_multiset/label_multiset.py
@@ -1,36 +1,36 @@
from typing import Tuple

import numpy as np
import nifty.tools as nt
from ..util import normalize_index


class LabelMultiset:
    """Implements the label multiset defined in https://github.com/saalfeldlab/imglib2-label-multisets.

    Label multisets summarize the ids and counts of label arrays.
    This implementation uses flat arrays to store `ids` and `counts`.
    The member variables `shape` and `size` refer to the shape of the summarized label array
    and its number of pixels. Further, `n_elements` refers to the number of elements
    (= len(ids) = len(counts)) and `n_entries` refers to the number of unique multiset entries.

    Args:
        argmax: Flat array of length `size` holding max labels per set.
        offsets: Flat array of length `size` holding offsets into `ids`/`counts` for each set.
        ids: Flat array holding the summarized label ids.
        counts: Flat array holding the summarized label counts.
        shape: Shape of the label multiset.
    """

    def __init__(
        self, argmax: np.ndarray, offsets: np.ndarray, ids: np.ndarray, counts: np.ndarray, shape: Tuple[int, ...]
    ):
        self._shape = tuple(shape)
        self._size = int(np.prod(list(shape)))

        if not len(argmax) == len(offsets) == self.size:
            raise ValueError("Shape, argmax and offset do not match: %i %i %i" % (len(argmax), len(offsets), self.size))
        self.argmax = argmax
        self.offsets = offsets.astype("uint64")

        if len(ids) != len(counts):
            raise ValueError("Ids and counts do not match: %i, %i" % (len(ids), len(counts)))
@@ -42,16 +42,13 @@ def __init__(self, argmax, offsets, ids, counts, shape):
        # w.r.t entries instead of elements
        unique_offsets, self.entry_offsets = np.unique(self.offsets, return_inverse=True)
        if unique_offsets[-1] >= self.n_elements:
            raise ValueError("Elements and offsets do not match: %i, %i" % (self.n_elements, unique_offsets[-1]))
        self.n_entries = len(unique_offsets)
        # compute size of the entries from unique offsets
        unique_offsets = np.concatenate([unique_offsets, np.array([self.n_elements])]).astype("uint64")
        self.entry_sizes = np.diff(unique_offsets)

    def __getitem__(self, key):
        # get the flattened entry indices
        index = normalize_index(key, self._shape)[0]
        index = np.array([ax.flatten() for ax in np.mgrid[index]])
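The flat storage can be inspected directly. A sketch, reusing the toy labels from the earlier sketches; since create_multiset_from_labels puts one element per pixel, argmax simply recovers the input:

    multiset = create_multiset_from_labels(labels)
    assert np.array_equal(multiset.argmax.reshape(multiset.shape), labels)
    # the entry of the first pixel starts at offsets[0] in the flat ids / counts arrays
    off = int(multiset.offsets[0])
    print(multiset.ids[off], multiset.counts[off])  # first pixel's label, with count 1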
86 changes: 48 additions & 38 deletions elf/label_multiset/serialize.py
@@ -1,54 +1,62 @@
import struct
from typing import Tuple

import numpy as np
from .label_multiset import LabelMultiset


def deserialize_labels(serialization: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
    """Deserialize summarized label array from multiset serialization.

    Args:
        serialization: Flat byte array with multiset serialization.
        shape: Shape of the multiset.

    Returns:
        The labels that were summarized by the multiset.
    """
    # number of sets is encoded as integer in the first 4 bytes
    pos = 0
    next_pos = 4
    size = struct.unpack(">i", serialization[pos:next_pos].tobytes())[0]

    # the argmax vector is encoded as long in the next 8 * size bytes
    pos = next_pos
    next_pos += 8 * size
    argmax = serialization[pos:next_pos]
    argmax = np.frombuffer(argmax.tobytes(), dtype=">q")

    return argmax.reshape(shape)
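The 4-byte header can be checked by hand against a fresh serialization (a sketch; serialize_multiset is defined further down in this file, and the toy labels come from the earlier sketches):

    import struct
    from elf.label_multiset.create import create_multiset_from_labels
    from elf.label_multiset.serialize import serialize_multiset

    ser = serialize_multiset(create_multiset_from_labels(labels))
    n_sets = struct.unpack(">i", ser[:4].tobytes())[0]
    assert n_sets == labels.size  # one set per pixel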


def deserialize_multiset(serialization: np.ndarray, shape: Tuple[int, ...]) -> LabelMultiset:
    """Deserialize label multiset.

    Args:
        serialization: Flat byte array with multiset serialization.
        shape: Shape of the multiset.

    Returns:
        The deserialized label multiset.
    """
    # number of sets is encoded as integer in the first 4 bytes
    pos = 0
    next_pos = 4
    size = struct.unpack(">i", serialization[pos:next_pos].tobytes())[0]

    # the argmax vector is encoded as long in the next 8 * size bytes
    pos = next_pos
    next_pos += 8 * size
    argmax = serialization[pos:next_pos]
    argmax = np.frombuffer(argmax.tobytes(), dtype=">q")

    # the byte offset vector is encoded as int in the next 4 * size bytes
    pos = next_pos
    next_pos += 4 * size
    offsets = serialization[pos:next_pos]
    offsets = np.frombuffer(offsets.tobytes(), dtype=">i")

# compute the unique byte offsets and the inverse mapping
byte_offsets, inverse_offsets = np.unique(offsets, return_inverse=True)
@@ -68,13 +76,12 @@ def deserialize_entry(entry):
        # extract the ids and counts
        entry = entry.reshape((n_elements, 12))
        ids = entry[:, :8].flatten()
        ids = np.frombuffer(ids.tobytes(), dtype="<q")
        counts = entry[:, 8:].flatten()
        counts = np.frombuffer(counts.tobytes(), dtype="<i")
        return ids, counts

    data_offsets = np.concatenate([byte_offsets, np.array([len(data)], dtype=byte_offsets.dtype)])
    # TODO vectorize
    ids, counts = [], []
    entry_offsets = []
@@ -84,8 +91,8 @@ def deserialize_entry(entry):
        counts.extend(mcounts)
        entry_offsets.append(len(mids))

    ids = np.array(ids, dtype="uint64")
    counts = np.array(counts, dtype="int32")

    # compute the set offsets from byte offsets and entry offsets
    entry_offsets = np.concatenate([np.array([0]), entry_offsets[:-1]])
@@ -97,33 +104,36 @@ def deserialize_entry(entry):


# apparently, we do not need to switch to fortran order for the
# serialization, but that should be duoble checked.
def serialize_multiset(multiset):
""" Compute multiset serialization in imglib format.
# serialization, but that should be double checked.
def serialize_multiset(multiset: LabelMultiset) -> np.ndarray:
"""Serialize label multiset serialization in imglib format.
The multiset is serialized as follows:
1.) number of sets / cells encoded as integer (4 bytes)
2.) max label id for each set encoded as long (8 bytes * num_cells)
3.) offset in bytes into the data array for each set encoded as int (4 bytes * num cells)
4.) the data storing label ids / counts encoded as long / int (datalen in bytes)
cf. https://github.com/saalfeldlab/imglib2-label-multisets/blob/master/src/main/java/net
/imglib2/type/label/LabelMultisetTypeDownscaler.java#L176
See also:
https://github.com/saalfeldlab/imglib2-label-multisets/blob/master/src/main/java/net/imglib2/type/label/LabelMultisetTypeDownscaler.java#L176
Args:
multiset: The label multiset to serialze.
Arguments:
multiset [LabelMultiset] - the label multiset to serialze.
Returns:
The serialized label multiset as flat binary array.
"""
    size, n_entries, n_elements = multiset.size, multiset.n_entries, multiset.n_elements
    argmax, offsets, ids, counts = (multiset.argmax, multiset.offsets,
                                    multiset.ids, multiset.counts)
    # encode the argmax vector
    argmax = np.array(argmax, dtype=">q").tobytes()

    # merge and encode ids and counts.
    # the ids are stored as long, the counts as int (both little endian).
    ids = [struct.pack("<q", i) for i in ids]
    counts = [struct.pack("<i", c) for c in counts]
    # get list of the unique offsets to delineate entries
    offset_list = np.concatenate([np.unique(offsets), np.array([n_elements])]).astype("uint64")
    assert offset_list[-2] < n_elements, "%i, %i" % (offset_list[-2], n_elements)

    # zip entry_sizes, ids and counts into one list.
@@ -137,25 +147,25 @@ def serialize_multiset(multiset):
            for beg, end in zip(offset_list[:-1], offset_list[1:])]

    # encode the data. we also prepend the entry size encoded as int for java
    entry_sizes = [struct.pack("<i", es) for es in multiset.entry_sizes]
    data = [b"".join([es] + elem) for es, elem in zip(entry_sizes, data)]
    assert len(data) == n_entries

    # compute the byte offsets for each entry in data
    data_offsets = np.cumsum([0] + [len(entry) for entry in data[:-1]])
    assert len(data_offsets) == n_entries

    # encode the data
    data = b"".join(data)

    # get the offsets in bytes and encode
    offsets = [data_offsets[off] for off in multiset.entry_offsets]
    offsets = np.array(offsets, dtype=">i").tobytes()

    # encode the number of sets
    size = struct.pack(">i", size)

    # combine to byte buffer for serialization
    serialization = size + argmax + offsets + data
    serialization = np.frombuffer(serialization, dtype="uint8")
    return serialization
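Finally, a round-trip sketch over this file (toy labels assumed; for a multiset built directly from labels, argmax equals the input, so deserialize_labels restores it exactly):

    import numpy as np
    from elf.label_multiset.create import create_multiset_from_labels
    from elf.label_multiset.serialize import serialize_multiset, deserialize_labels, deserialize_multiset

    labels = np.random.randint(0, 100, size=(32, 32)).astype("uint64")
    ser = serialize_multiset(create_multiset_from_labels(labels))
    assert np.array_equal(deserialize_labels(ser, labels.shape), labels)
    # the full multiset can be restored as well
    restored = deserialize_multiset(ser, labels.shape)
    assert restored.shape == labels.shape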
