diff --git a/openhexa/sdk/datasets/dataset.py b/openhexa/sdk/datasets/dataset.py index f22b87c..6fddbdb 100644 --- a/openhexa/sdk/datasets/dataset.py +++ b/openhexa/sdk/datasets/dataset.py @@ -203,14 +203,14 @@ def get_file(self, filename: str) -> DatasetFile: def add_file( self, - source: typing.Union[str, PathLike[str], typing.IO], + source: typing.Union[str, PathLike[str], typing.IO, bytes], filename: typing.Optional[str] = None, ) -> DatasetFile: """Create a new dataset file and add it to the dataset version.""" mime_type = None if isinstance(source, (str, PathLike)): path = Path(source) - filename = path.name + filename = path.name if filename is None else filename mime_type, _ = mimetypes.guess_type(path) else: if filename is None: @@ -249,19 +249,18 @@ def add_file( raise ValueError("A file with this name already exists in this dataset version") else: raise Exception(errors) - - upload_url = data["createDatasetVersionFile"]["uploadUrl"] - content = read_content(source) - - response = requests.put(upload_url, data=content, headers={"Content-Type": mime_type}) + result = data["createDatasetVersionFile"] + upload_url = result["uploadUrl"] + with read_content(source) as content: + response = requests.put(upload_url, data=content, headers={"Content-Type": mime_type}) response.raise_for_status() return DatasetFile( version=self, - id=data["createDatasetVersionFile"]["file"]["id"], - filename=data["createDatasetVersionFile"]["file"]["filename"], - content_type=data["createDatasetVersionFile"]["file"]["contentType"], - uri=data["createDatasetVersionFile"]["file"]["uri"], - created_at=data["createDatasetVersionFile"]["file"]["createdAt"], + id=result["file"]["id"], + filename=result["file"]["filename"], + content_type=result["file"]["contentType"], + uri=result["file"]["uri"], + created_at=result["file"]["createdAt"], ) diff --git a/openhexa/sdk/utils.py b/openhexa/sdk/utils.py index b569397..e651607 100644 --- a/openhexa/sdk/utils.py +++ b/openhexa/sdk/utils.py @@ -1,6 +1,7 @@ """Miscellaneous utility functions.""" import abc +import contextlib import enum import os import typing @@ -173,20 +174,21 @@ def __next__(self): return result -def read_content(source: typing.Union[str, os.PathLike[str], typing.IO], encoding: str = "utf-8") -> bytes: +@contextlib.contextmanager +def read_content(source: typing.Union[str, os.PathLike[str], typing.IO, bytes]): """Read file content and return it as bytes.""" - # If source is a string or PathLike object - if isinstance(source, (str, os.PathLike)): - with open(os.fspath(source), "rb") as f: - return f.read() - - # If source is a buffer - elif hasattr(source, "read"): - content = source.read() + try: + if isinstance(source, bytes): + yield source + elif hasattr(source, "read"): + yield source + # If source is a string or PathLike object + elif isinstance(source, (str, os.PathLike)): + with open(os.fspath(source), "rb") as f: + yield f - if not isinstance(content, bytes): - return content.encode(encoding) else: - return content - - raise ValueError("Unsupported type for source") + raise ValueError("Unsupported type for source") + finally: + if hasattr(source, "close"): + source.close()