Skip to content

Commit

Permalink
Merge pull request #107 from constantinpape/tracking-updates
Browse files Browse the repository at this point in the history
Fix issues and add doc strings in tracking functionality
  • Loading branch information
constantinpape authored Dec 28, 2024
2 parents ae85756 + 57482cb commit 2eda443
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 88 deletions.
37 changes: 23 additions & 14 deletions elf/tracking/mamut.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Tuple

import xml.etree.ElementTree as ET
import numpy as np

#
# adapted from code shared by @wolny
# Adapted from code shared by @wolny.
#


Expand All @@ -29,31 +31,38 @@ def _extract_tracks(root, flatten_spots=True):
return tracks


def extract_tracks_as_volume(path, timepoint, shape, voxel_size, binary=False):
def extract_tracks_as_volume(
path: str,
timepoint: int,
shape: Tuple[int, int, int],
voxel_size: Tuple[float, float, float],
binary: bool = False
) -> np.ndarray:
"""Extract tracks as volume from MaMuT xml.
Arguments:
path [str] - path to the xml file with tracks
timepoint [int] - timepoint for which to extract the tracks
shape [tuple[int]] - shape of the output volume
voxel_size [tuple[float]] - voxel size
path: Path to the xml file with tracks stored in MaMuT format.
timepoint: Timepoint for which to extract the tracks.
shape: Shape of the output volume.
voxel_size: Voxel size of the volume.
binary: Whether to return the volume as binary labels and not instance ids.
Returns:
np.ndarray -
The volume with instance ids or binary ids.
"""
# get root XML element
root = ET.parse(path).getroot()
# retrieve all of the spots
all_spots = root.find('Model').find('AllSpots')
all_spots = root.find("Model").find("AllSpots")

# get all spots for a given time frame
spots = next((s for s in all_spots if int(s.attrib['frame']) == timepoint), None)
spots = next((s for s in all_spots if int(s.attrib["frame"]) == timepoint), None)

if spots is None:
raise RuntimeError('Could not find spots for time frame:', timepoint)
raise RuntimeError("Could not find spots for time frame:", timepoint)

# get pixel coordinates
pixel_coordinates = np.array([_to_zyx_coordinates(
[spot.attrib['POSITION_Z'], spot.attrib['POSITION_Y'], spot.attrib['POSITION_X']],
[spot.attrib["POSITION_Z"], spot.attrib["POSITION_Y"], spot.attrib["POSITION_X"]],
np.array([vsize for vsize in voxel_size])
) for spot in spots])
z = pixel_coordinates[:, 0]
Expand All @@ -62,13 +71,13 @@ def extract_tracks_as_volume(path, timepoint, shape, voxel_size, binary=False):

# extract the volume as binary
if binary:
spot_mask = np.zeros(shape, dtype='bool')
spot_mask = np.zeros(shape, dtype="bool")
spot_mask[z, y, x] = 1
return spot_mask

# extract volume with track ids
spot_ids = [int(spot.attrib['ID']) for spot in spots]
track_volume = np.zeros(shape, dtype='uint32')
spot_ids = [int(spot.attrib["ID"]) for spot in spots]
track_volume = np.zeros(shape, dtype="uint32")

tracks = _extract_tracks(root)
spots_to_tracks = {spot: track for track, spots in tracks.items() for spot in spots}
Expand Down
Loading

0 comments on commit 2eda443

Please sign in to comment.