diff --git a/gunpowder/nodes/__init__.py b/gunpowder/nodes/__init__.py index 3c2a410f..30dbf848 100644 --- a/gunpowder/nodes/__init__.py +++ b/gunpowder/nodes/__init__.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +from .array_source import ArraySource from .add_affinities import AddAffinities from .astype import AsType from .balance_labels import BalanceLabels diff --git a/gunpowder/nodes/array_source.py b/gunpowder/nodes/array_source.py index 4ebbbb26..e6f05507 100644 --- a/gunpowder/nodes/array_source.py +++ b/gunpowder/nodes/array_source.py @@ -1,5 +1,8 @@ from funlib.persistence.arrays import Array as PersistenceArray -from gunpowder import Array, ArrayKey, Batch, BatchProvider, ArraySpec +from gunpowder.array import Array, ArrayKey +from gunpowder.array_spec import ArraySpec +from gunpowder.batch import Batch +from .batch_provider import BatchProvider class ArraySource(BatchProvider): @@ -33,20 +36,17 @@ def __init__( self.array_spec = ArraySpec( self.array.roi, self.array.voxel_size, - self.interpolatable, - self.nonspatial, + interpolatable, + nonspatial, self.array.dtype, ) - self.interpolatable = interpolatable - self.nonspatial = nonspatial - def setup(self): self.provides(self.key, self.array_spec) def provide(self, request): outputs = Batch() - if self.nonspatial: + if self.array_spec.nonspatial: outputs[self.key] = Array(self.array[:], self.array_spec.copy()) else: out_spec = self.array_spec.copy() diff --git a/tests/cases/array_source.py b/tests/cases/array_source.py new file mode 100644 index 00000000..f7cb666b --- /dev/null +++ b/tests/cases/array_source.py @@ -0,0 +1,29 @@ +from funlib.persistence import prepare_ds +from funlib.geometry import Roi +from gunpowder.nodes import ArraySource +from gunpowder import ArrayKey, build, BatchRequest, ArraySpec + +import numpy as np + + +def test_array_source(tmpdir): + array = prepare_ds( + tmpdir / "data.zarr", + shape=(100, 102, 108), + offset=(100, 50, 0), + voxel_size=(1, 2, 3), + dtype="uint8", + ) + array[:] = np.arange(100 * 102 * 108).reshape((100, 102, 108)) % 255 + + key = ArrayKey("TEST") + + source = ArraySource(key=key, array=array) + + with build(source): + request = BatchRequest() + + roi = Roi((100, 100, 102), (30, 30, 30)) + request[key] = ArraySpec(roi) + + assert np.array_equal(source.request_batch(request)[key].data, array[roi])