diff --git a/src/python/module/z5py/dataset.py b/src/python/module/z5py/dataset.py index d1290ee..20005ff 100644 --- a/src/python/module/z5py/dataset.py +++ b/src/python/module/z5py/dataset.py @@ -51,8 +51,16 @@ def __init__(self, dset_impl, handle, parent, name, n_threads=1): self._parent = parent self._name = name - def __array__(self): - return self[...] + def __array__(self, dtype=None): + """ Create a numpy array containing the whole dataset. + + NOTE: Datasets are not interchangeble with arrays! + Every time this method is called the whole dataset is loaded into memory! + """ + arr = self[...] + if dtype is not None: + arr = arr.astype(dtype, copy=False) + return arr @staticmethod def _to_zarr_compression_options(compression, compression_options): diff --git a/src/python/test/test_asarray.py b/src/python/test/test_asarray.py index feb0db1..34d4675 100644 --- a/src/python/test/test_asarray.py +++ b/src/python/test/test_asarray.py @@ -1,5 +1,4 @@ import unittest -import pickle import sys from shutil import rmtree @@ -15,13 +14,12 @@ class TestAsarray(unittest.TestCase): def setUp(self): - sample_data = np.array([ [3, 4], [7, 8]]) + self.sample_data = np.random.randint(0, 10000, size=(10, 10)) self.test_fn = 'test.n5' self.test_ds = 'test' with z5py.File(self.test_fn, 'w') as zfh: - zfh.create_dataset(self.test_ds, data=sample_data) + zfh.create_dataset(self.test_ds, data=self.sample_data) self.ff = z5py.File(self.test_fn, 'r') - def tearDown(self): try: @@ -31,7 +29,9 @@ def tearDown(self): def test_asarray(self): uniques = np.unique(self.ff[self.test_ds]) - self.assertEqual(len(uniques), 4) + expected = np.unique(self.sample_data) + self.assertTrue(np.array_equal(uniques, expected)) + if __name__ == '__main__': unittest.main()