Skip to content

Commit

Permalink
Update nifti wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 1, 2023
1 parent ac83255 commit 54954f9
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 13 deletions.
1 change: 1 addition & 0 deletions .github/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- imageio
- intern
- mrcfile
- nibabel
- nifty >=1.1
- numba
- pandas
Expand Down
2 changes: 1 addition & 1 deletion elf/io/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def register_filetype(constructor, extensions=(), groups=(), datasets=()):
# add nifti extensions if we have nibabel
try:
import nibabel
register_filetype(NiftiFile, [".ni.gz"], NiftiFile, NiftiDataset)
register_filetype(NiftiFile, [".nii.gz", ".nii"], NiftiFile, NiftiDataset)
except ImportError:
nibabel = None

Expand Down
15 changes: 13 additions & 2 deletions elf/io/files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path

from .extensions import (
FILE_CONSTRUCTORS, GROUP_LIKE, DATASET_LIKE,
h5py, z5py, pyn5, zarr,
Expand Down Expand Up @@ -27,12 +28,21 @@ def open_file(path, mode="a", ext=None, **kwargs):
ext [str] - file extension. This can be used to force an extension
if it cannot be inferred from the filename. (default: None)
"""

# Before checking the extension suffix, check for "protocol-style"
# cloud provider prefixes.
if "://" in path:
ext = path.split("://")[0] + "://"

ext = os.path.splitext(path.rstrip("/"))[1] if ext is None else ext
elif ext is None:
path_ = Path(path.rstrip("/"))
suffixes = path_.suffixes
# We need to treat .nii.gz differently
if len(suffixes) == 2 and "".join(suffixes) == ".nii.gz":
ext = ".nii.gz"
else:
ext = suffixes[-1]

try:
constructor = FILE_CONSTRUCTORS[ext.lower()]
except KeyError:
Expand All @@ -42,6 +52,7 @@ def open_file(path, mode="a", ext=None, **kwargs):
f"{' '.join(supported_extensions())}. "
f"You may need to install additional dependencies (h5py, z5py, zarr, intern)."
)

return constructor(path, mode=mode, **kwargs)


Expand Down
61 changes: 51 additions & 10 deletions elf/io/nifti_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Mapping

import numpy as np
try:
import nibabel
except ImportError:
Expand All @@ -8,28 +9,68 @@

class NiftiFile(Mapping):
def __init__(self, path, mode="r"):
if nibabel is None:
raise AttributeError("nibabel is required for nifti images, but is not installed.")
self.path = path
self.mode = mode
if nibabel is None:
raise AttributeError("nibabel is required for nifti images, but is not installed")
self.nifti = nibabel.load(self.path)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self._f.close()
pass

# dummy attrs to be compatible with h5py/z5py/zarr API
# alternatively we could also map the header to attributes
@property
def attrs(self):
return {}

def __getitem__(self, key):
if key != "data":
raise KeyError(f"Could not find key {key}")
return NiftiDataset(self.nifti)

def __iter__(self):
yield "data"

def __len__(self):
return 1

def __contains__(self, name):
return name == "data"


# Go to https://nipy.org/nibabel/nifti_images.html for implementation.
# To be aware of when implementing slicing:
# (Pdb) x = vol[:]
# *** TypeError: Cannot slice image objects; consider using `img.slicer[slice]` to generate a sliced image
# (see documentation for caveats) or slicing image array data with `img.dataobj[slice]` or `img.get_fdata()[slice]`
class NiftiDataset:
def __init__(self, data_object):
pass
def __init__(self, data):
self._data = data

@property
def dtype(self):
return self.data.get_data_dtype()

@property
def ndim(self):
return self._data.ndim

@property
def chunks(self):
return None

@property
def shape(self):
return self._data.shape

def __getitem__(self, key):
return self._data.dataobj[key]

@property
def size(self):
return np.prod(self._data.shape)

# dummy attrs to be compatible with h5py/z5py/zarr API
# alternatively we could also map the header to attributes
@property
def attrs(self):
return {}
61 changes: 61 additions & 0 deletions test/io_tests/test_nifti_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import unittest
from glob import glob

import numpy as np

try:
import nibabel
except ImportError:
nibabel = None


@unittest.skipIf(nibabel is None, "Needs nibabel")
class TestNiftiWrapper(unittest.TestCase):

def _check_data(self, expected_data, f):
dset = f["data"]

self.assertEqual(expected_data.shape, dset.shape)
shape = dset.shape

# bounding boxes for testing sub-sampling
bbs = [np.s_[:]]
for i in range(dset.ndim):
bbs.extend([
tuple(slice(0, shape[i] // 2) if d == i else slice(None) for d in range(dset.ndim)),
tuple(slice(shape[i] // 2, None) if d == i else slice(None) for d in range(dset.ndim))
])
bbs.append(
tuple(slice(shape[i] // 4, 3 * shape[i] // 4) for i in range(dset.ndim))
)

for bb in bbs:
self.assertTrue(np.allclose(dset[bb], expected_data[bb]))

def test_read_nifti(self):
from elf.io import open_file
from nibabel.testing import data_path

paths = glob(os.path.join(data_path, "*.nii"))
for path in paths:
expected_data = np.asarray(nibabel.load(path).dataobj)
# the resampled image causes errors
if os.path.basename(path).startswith("resampled"):
continue
with open_file(path, "r") as f:
self._check_data(expected_data, f)

def test_read_nifti_compressed(self):
from elf.io import open_file
from nibabel.testing import data_path

paths = glob(os.path.join(data_path, "*.nii.gz"))
for path in paths:
expected_data = np.asarray(nibabel.load(path).dataobj)
with open_file(path, "r") as f:
self._check_data(expected_data, f)


if __name__ == "__main__":
unittest.main()

0 comments on commit 54954f9

Please sign in to comment.