Skip to content

Commit

Permalink
Initial test for tensoflow model
Browse files Browse the repository at this point in the history
  • Loading branch information
m-novikov committed Oct 19, 2020
1 parent 614426b commit 3ed0d8c
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 5 deletions.
29 changes: 28 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TEST_DATA = "data"
TEST_PYBIO_ZIPFOLDER = "unet2d"
TEST_PYBIO_DUMMY = "dummy"
TEST_PYBIO_TENSORFLOW_DUMMY = "dummy_tensorflow"

NNModel = namedtuple("NNModel", ["model", "state"])

Expand Down Expand Up @@ -92,7 +93,6 @@ def pybio_model_bytes(data_path):

return data


@pytest.fixture
def pybio_model_zipfile(pybio_model_bytes):
with ZipFile(pybio_model_bytes, mode="r") as zf:
Expand All @@ -114,6 +114,33 @@ def pybio_dummy_model_bytes(data_path):
return data


def archive(directory):
result = io.BytesIO()

with ZipFile(result, mode="w") as zip_model:
def _archive(path_to_archive):
for path in path_to_archive.iterdir():
if str(path.name).startswith("__"):
continue

if path.is_dir():
_archive(path)

else:
with path.open(mode="rb") as f:
zip_model.writestr(str(path).replace(str(directory), ""), f.read())

_archive(directory)

return result


@pytest.fixture
def pybio_dummy_tensorflow_model_bytes(data_path):
pybio_net_dir = Path(data_path) / TEST_PYBIO_TENSORFLOW_DUMMY
return archive(pybio_net_dir)


@pytest.fixture
def cache_path(tmp_path):
return Path(getenv("PYBIO_CACHE_PATH", tmp_path))
43 changes: 43 additions & 0 deletions tests/data/dummy_tensorflow/Dummy.model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: DummyTFModel
description: A dummy tensorflow model for testing
authors:
- ilastik team
cite:
- text: "Ilastik"
doi: https://doi.org
documentation: dummy.md
tags: [tensorflow]
license: MIT

format_version: 0.1.0
language: python
framework: tensorflow

source: dummy.py::TensorflowModelWrapper

test_input: null # ../test_input.npy
test_output: null # ../test_output.npy

# TODO double check inputs/outputs
inputs:
- name: input
axes: bcyx
data_type: float32
data_range: [-inf, inf]
shape: [1, 1, 128, 128]
outputs:
- name: output
axes: bcyx
data_type: float32
data_range: [0, 1]
shape:
reference_input: input # FIXME(m-novikov) ignoring for now
scale: [1, 1, 1, 1]
offset: [0, 0, 0, 0]
#halo: [0, 0, 32, 32] # Should be moved to outputs

prediction:
weights:
source: ./model
hash: {md5: TODO}
dependencies: conda:./environment.yaml
12 changes: 12 additions & 0 deletions tests/data/dummy_tensorflow/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class TensorflowModelWrapper:
def __init__(self):
self._model = None

def set_model(self, model):
self._model = model

def forward(self, input_):
return self._model.predict(input_)

def __call__(self, *args, **kwargs):
return self._model.predict(*args, **kwargs)
Binary file added tests/data/dummy_tensorflow/model/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 6 additions & 0 deletions tests/test_server/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ def test_eval_model_zip(pybio_model_bytes, cache_path):
with ZipFile(pybio_model_bytes) as zf:
exemplum = eval_model_zip(zf, devices=["cpu"], cache_path=cache_path)
assert isinstance(exemplum, Exemplum)

@pytest.mark.xfail
def test_eval_tensorflow_model_zip(pybio_dummy_tensorflow_model_bytes, cache_path):
with ZipFile(pybio_dummy_tensorflow_model_bytes) as zf:
exemplum = eval_model_zip(zf, devices=["cpu"], cache_path=cache_path)
assert isinstance(exemplum, Exemplum)
11 changes: 8 additions & 3 deletions tiktorch/server/exemplum.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ def __init__(
pybio_model: nodes.Model,
batch_size: int = 1,
num_iterations_per_update: int = 2,
_devices=Sequence[torch.device],
_devices=Sequence[str],
):
self.max_num_iterations = 0
self.iteration_count = 0
self.devices = _devices
spec = pybio_model.spec
self.name = spec.name

Expand Down Expand Up @@ -89,12 +88,18 @@ def __init__(
self.halo = list(zip(self.output_axes, _halo))

self.model = get_instance(pybio_model)
self.model.to(self.devices[0])
if spec.framework == "pytorch":
self.devices = [torch.device(d) for d in _devices]
self.model.to(self.devices[0])
assert isinstance(self.model, torch.nn.Module)
if spec.prediction.weights is not None:
state = torch.load(spec.prediction.weights.source, map_location=self.devices[0])
self.model.load_state_dict(state)
# elif spec.framework == "tensorflow":
# import tensorflow as tf
# self.devices = []
# tf_model = tf.keras.models.load_model(spec.prediction.weights.source)
# self.model.set_model(tf_model)
else:
raise NotImplementedError

Expand Down
1 change: 0 additions & 1 deletion tiktorch/server/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def eval_model_zip(model_zip: ZipFile, devices: Sequence[str], cache_path: Optio

pybio_model = spec.utils.load_model(spec_file_str, root_path=temp_path, cache_path=cache_path)

devices = [torch.device(d) for d in devices]
if pybio_model.spec.training is None:
return Exemplum(pybio_model=pybio_model, _devices=devices)
else:
Expand Down

0 comments on commit 3ed0d8c

Please sign in to comment.