Skip to content

Commit

Permalink
fix(dataset): Support bytes directly in read_content
Browse files (browse the repository at this point in the history)
  • Loading branch information
qgerome committed Jan 12, 2024
1 parent cba3e7e commit 8c886ae
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
23 changes: 11 additions & 12 deletions openhexa/sdk/datasets/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -203,14 +203,14 @@ def get_file(self, filename: str) -> DatasetFile:

def add_file(
self,
source: typing.Union[str, PathLike[str], typing.IO],
source: typing.Union[str, PathLike[str], typing.IO, bytes],
filename: typing.Optional[str] = None,
) -> DatasetFile:
"""Create a new dataset file and add it to the dataset version."""
mime_type = None
if isinstance(source, (str, PathLike)):
path = Path(source)
filename = path.name
filename = path.name if filename is None else filename
mime_type, _ = mimetypes.guess_type(path)
else:
if filename is None:
Expand Down Expand Up @@ -249,19 +249,18 @@ def add_file(
raise ValueError("A file with this name already exists in this dataset version")
else:
raise Exception(errors)

upload_url = data["createDatasetVersionFile"]["uploadUrl"]
content = read_content(source)

response = requests.put(upload_url, data=content, headers={"Content-Type": mime_type})
result = data["createDatasetVersionFile"]
upload_url = result["uploadUrl"]
with read_content(source) as content:
response = requests.put(upload_url, data=content, headers={"Content-Type": mime_type})
response.raise_for_status()
return DatasetFile(
version=self,
id=data["createDatasetVersionFile"]["file"]["id"],
filename=data["createDatasetVersionFile"]["file"]["filename"],
content_type=data["createDatasetVersionFile"]["file"]["contentType"],
uri=data["createDatasetVersionFile"]["file"]["uri"],
created_at=data["createDatasetVersionFile"]["file"]["createdAt"],
id=result["file"]["id"],
filename=result["file"]["filename"],
content_type=result["file"]["contentType"],
uri=result["file"]["uri"],
created_at=result["file"]["createdAt"],
)


Expand Down
30 changes: 16 additions & 14 deletions openhexa/sdk/utils.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
"""Miscellaneous utility functions."""

import abc
import contextlib
import enum
import os
import typing
Expand Down Expand Up @@ -173,20 +174,21 @@ def __next__(self):
return result


def read_content(source: typing.Union[str, os.PathLike[str], typing.IO], encoding: str = "utf-8") -> bytes:
@contextlib.contextmanager
def read_content(source: typing.Union[str, os.PathLike[str], typing.IO, bytes]):
"""Read file content and return it as bytes."""
# If source is a string or PathLike object
if isinstance(source, (str, os.PathLike)):
with open(os.fspath(source), "rb") as f:
return f.read()

# If source is a buffer
elif hasattr(source, "read"):
content = source.read()
try:
if isinstance(source, bytes):
yield source
elif hasattr(source, "read"):
yield source
# If source is a string or PathLike object
elif isinstance(source, (str, os.PathLike)):
with open(os.fspath(source), "rb") as f:
yield f

if not isinstance(content, bytes):
return content.encode(encoding)
else:
return content

raise ValueError("Unsupported type for source")
raise ValueError("Unsupported type for source")
finally:
if hasattr(source, "close"):
source.close()

0 comments on commit 8c886ae

Please sign in to comment.