From 8092b1758572d8c6895bc427a1054c7ae1c192d0 Mon Sep 17 00:00:00 2001 From: Mikhail Sveshnikov Date: Mon, 5 Dec 2022 12:21:39 +0200 Subject: [PATCH] Client files (#508) Allow to pass filenames for client invocations if endpoint has binary serializer Need some tests --- mlem/core/data_type.py | 7 +++++++ mlem/runtime/client.py | 26 ++++++++++++++++++++++---- tests/contrib/test_fastapi.py | 14 +++++++++++--- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 7641e002..7fa7f9e8 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -182,6 +182,10 @@ def get_model(self, prefix: str = "") -> Union[Type[BaseModel], type]: raise NotImplementedError return self.serializer.data.get_model(self.data_type, prefix) + @property + def support_files(self): + return self.serializer.binary.support_files + class DefaultDataTypeSerializer(DataTypeSerializer): @validator("data_type") @@ -217,6 +221,8 @@ def get_model( class BinarySerializer(Serializer[DT], Generic[DT], ABC): """Base class for serializers from/to raw binary data""" + support_files: ClassVar[bool] = False + @abstractmethod def serialize(self, data_type: DT, instance: Any) -> bytes: raise NotImplementedError @@ -1080,6 +1086,7 @@ class FileSerializer(BinarySerializer): """BinarySerialzier for arbitrary data using reader and writer""" type: ClassVar = "file" + support_files: ClassVar = True @staticmethod def _get_artifact(data_type: DataType, instance: Any) -> InMemoryArtifact: diff --git a/mlem/runtime/client.py b/mlem/runtime/client.py index dc7fc8db..5bff6c4c 100644 --- a/mlem/runtime/client.py +++ b/mlem/runtime/client.py @@ -1,4 +1,5 @@ import logging +import os from abc import ABC, abstractmethod from typing import Any, BinaryIO, Callable, ClassVar, Optional @@ -108,10 +109,27 @@ def __call__(self, *args, **kwargs): raise NotImplementedError( "Multiple file requests are not supported yet" ) - with serializer.dump(obj) as f: - return self.method.returns.get_serializer().deserialize( - self.call_method_binary(self.name, f, return_raw) - ) + if ( + isinstance(obj, (str, os.PathLike)) + and serializer.support_files + ): + with open(obj, "rb") as f: + return ( + self.method.returns.get_serializer().deserialize( + self.call_method_binary( + self.name, f, return_raw + ) + ) + ) + else: + with serializer.dump(obj) as f: + return ( + self.method.returns.get_serializer().deserialize( + self.call_method_binary( + self.name, f, return_raw + ) + ) + ) data[arg.name] = serializer.serialize(obj) diff --git a/tests/contrib/test_fastapi.py b/tests/contrib/test_fastapi.py index 7bc23833..aa714847 100644 --- a/tests/contrib/test_fastapi.py +++ b/tests/contrib/test_fastapi.py @@ -191,10 +191,11 @@ def test_nested_objects_in_schema(data): ], ) def test_file_endpoint( - create_mlem_client, create_client, data, eq_assert: Callable + create_mlem_client, create_client, data, eq_assert: Callable, tmp_path ): - model = MlemModel.from_obj(lambda x: x, sample_data=data) - model_interface = ModelInterface.from_model(model) + model_interface = ModelInterface.from_model( + MlemModel.from_obj(lambda x: x, sample_data=data) + ) server = FastAPIServer( standardize=False, @@ -218,3 +219,10 @@ def test_file_endpoint( eq_assert(resp_array, data) eq_assert(mlem_client(data), data) + + path = tmp_path / "data" + with open(path, "wb") as fout, ser.dump(dt, data) as fin: + fout.write(fin.read()) + + eq_assert(mlem_client(str(path)), data) + eq_assert(mlem_client(path), data)