Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Client files (#508)
Browse files Browse the repository at this point in the history
Allow to pass filenames for client invocations if endpoint has binary
serializer

Need some tests
  • Loading branch information
mike0sv authored Dec 5, 2022
1 parent de77417 commit 8092b17
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
7 changes: 7 additions & 0 deletions mlem/core/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def get_model(self, prefix: str = "") -> Union[Type[BaseModel], type]:
raise NotImplementedError
return self.serializer.data.get_model(self.data_type, prefix)

@property
def support_files(self):
return self.serializer.binary.support_files


class DefaultDataTypeSerializer(DataTypeSerializer):
@validator("data_type")
Expand Down Expand Up @@ -217,6 +221,8 @@ def get_model(
class BinarySerializer(Serializer[DT], Generic[DT], ABC):
"""Base class for serializers from/to raw binary data"""

support_files: ClassVar[bool] = False

@abstractmethod
def serialize(self, data_type: DT, instance: Any) -> bytes:
raise NotImplementedError
Expand Down Expand Up @@ -1080,6 +1086,7 @@ class FileSerializer(BinarySerializer):
"""BinarySerialzier for arbitrary data using reader and writer"""

type: ClassVar = "file"
support_files: ClassVar = True

@staticmethod
def _get_artifact(data_type: DataType, instance: Any) -> InMemoryArtifact:
Expand Down
26 changes: 22 additions & 4 deletions mlem/runtime/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, BinaryIO, Callable, ClassVar, Optional

Expand Down Expand Up @@ -108,10 +109,27 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError(
"Multiple file requests are not supported yet"
)
with serializer.dump(obj) as f:
return self.method.returns.get_serializer().deserialize(
self.call_method_binary(self.name, f, return_raw)
)
if (
isinstance(obj, (str, os.PathLike))
and serializer.support_files
):
with open(obj, "rb") as f:
return (
self.method.returns.get_serializer().deserialize(
self.call_method_binary(
self.name, f, return_raw
)
)
)
else:
with serializer.dump(obj) as f:
return (
self.method.returns.get_serializer().deserialize(
self.call_method_binary(
self.name, f, return_raw
)
)
)

data[arg.name] = serializer.serialize(obj)

Expand Down
14 changes: 11 additions & 3 deletions tests/contrib/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,11 @@ def test_nested_objects_in_schema(data):
],
)
def test_file_endpoint(
create_mlem_client, create_client, data, eq_assert: Callable
create_mlem_client, create_client, data, eq_assert: Callable, tmp_path
):
model = MlemModel.from_obj(lambda x: x, sample_data=data)
model_interface = ModelInterface.from_model(model)
model_interface = ModelInterface.from_model(
MlemModel.from_obj(lambda x: x, sample_data=data)
)

server = FastAPIServer(
standardize=False,
Expand All @@ -218,3 +219,10 @@ def test_file_endpoint(
eq_assert(resp_array, data)

eq_assert(mlem_client(data), data)

path = tmp_path / "data"
with open(path, "wb") as fout, ser.dump(dt, data) as fin:
fout.write(fin.read())

eq_assert(mlem_client(str(path)), data)
eq_assert(mlem_client(path), data)

0 comments on commit 8092b17

Please sign in to comment.