-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The package contains code for uploading files to Neptune.
- Loading branch information
1 parent
022ce35
commit 53b8004
Showing
4 changed files
with
220 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""This subpackage contains code for syncing files with Neptune. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import multiprocessing | ||
import pathlib | ||
from typing import ( | ||
NamedTuple, | ||
Optional, | ||
) | ||
|
||
from neptune_scale.util import SharedInt | ||
from neptune_scale.util.abstract import Resource | ||
|
||
|
||
class UploadMessage(NamedTuple): | ||
attribute_path: str | ||
local_path: pathlib.Path | ||
target_path: Optional[str] | ||
target_basename: Optional[str] | ||
|
||
|
||
class FileUploadQueue(Resource): | ||
"""Queue for submitting file upload requests. Shared between the main process and FiledUploadWorkerThread, which | ||
is spawned in the worker process.""" | ||
|
||
def __init__(self) -> None: | ||
self._queue: multiprocessing.Queue[UploadMessage] = multiprocessing.Queue(maxsize=4096) | ||
self._active_uploads = SharedInt(0) | ||
|
||
@property | ||
def active_uploads(self) -> int: | ||
"""Returns the number of currently active uploads.""" | ||
with self._active_uploads: | ||
return self._active_uploads.value | ||
|
||
# Main process API | ||
def submit( | ||
self, | ||
*, | ||
attribute_path: str, | ||
local_path: pathlib.Path, | ||
target_path: Optional[str], | ||
target_basename: Optional[str], | ||
) -> None: | ||
with self._active_uploads: | ||
self._queue.put(UploadMessage(attribute_path, local_path, target_path, target_basename)) | ||
self._active_uploads.value += 1 | ||
|
||
def wait_for_completion(self, timeout: Optional[float] = None) -> None: | ||
with self._active_uploads: | ||
self._active_uploads.wait_for(lambda: self._active_uploads.value == 0, timeout=timeout) | ||
|
||
def close(self) -> None: | ||
self._queue.close() | ||
self._queue.cancel_join_thread() | ||
|
||
# Worker process API | ||
def decrement_active(self) -> None: | ||
with self._active_uploads: | ||
self._active_uploads.value -= 1 | ||
assert self._active_uploads.value >= 0 | ||
self._active_uploads.notify_all() | ||
|
||
def get(self, timeout: float) -> UploadMessage: | ||
return self._queue.get(timeout=timeout) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import time | ||
import uuid | ||
from concurrent import futures | ||
from pathlib import Path | ||
from queue import Empty | ||
from typing import ( | ||
Callable, | ||
Optional, | ||
) | ||
|
||
from neptune_scale.exceptions import NeptuneScaleError | ||
from neptune_scale.sync.errors_tracking import ErrorsQueue | ||
from neptune_scale.sync.files.queue import ( | ||
FileUploadQueue, | ||
UploadMessage, | ||
) | ||
from neptune_scale.util import ( | ||
Daemon, | ||
get_logger, | ||
) | ||
from neptune_scale.util.abstract import Resource | ||
|
||
logger = get_logger() | ||
|
||
|
||
class FileUploadWorkerThread(Daemon, Resource): | ||
"""Consumes messages from the provided FileUploadQueue and performs the upload operation | ||
in a pool of worker threads. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
project: str, | ||
run_id: str, | ||
api_token: str, | ||
family: str, | ||
input_queue: FileUploadQueue, | ||
errors_queue: ErrorsQueue, | ||
) -> None: | ||
super().__init__(sleep_time=0.5, name="FileUploader") | ||
|
||
self._project = project | ||
self._run_id = run_id | ||
self._api_token = api_token | ||
self._family = family | ||
self._input_queue = input_queue | ||
self._errors_queue = errors_queue | ||
self._executor = futures.ThreadPoolExecutor() | ||
|
||
def work(self) -> None: | ||
while True: | ||
try: | ||
msg = self._input_queue.get(timeout=0.5) | ||
except Empty: | ||
return | ||
|
||
try: | ||
future = self._executor.submit( | ||
self._do_upload, msg.attribute_path, msg.local_path, msg.target_path, msg.target_basename | ||
) | ||
future.add_done_callback(self._make_done_callback(msg)) | ||
except Exception as e: | ||
logger.error(f"Failed to submit file upload task for `{msg.local_path}` as `{msg.attribute_path}`: {e}") | ||
self._errors_queue.put(e) | ||
|
||
def close(self) -> None: | ||
self._executor.shutdown() | ||
|
||
def _do_upload( | ||
self, | ||
attribute_path: str, | ||
local_path: Path, | ||
target_path: Optional[str], | ||
target_basename: Optional[str], | ||
) -> None: | ||
path = determine_path(self._run_id, attribute_path, local_path, target_path, target_basename) | ||
|
||
try: | ||
url = self._request_upload_url(attribute_path, path) | ||
upload_file(local_path, url) | ||
self._finalize_upload(path) | ||
except Exception as e: | ||
self._finalize_upload(path, e) | ||
raise e | ||
|
||
def _request_upload_url(self, attribute_path: str, file_path: str) -> str: | ||
assert self._api_token | ||
# TODO: Make this retryable | ||
return ".".join(["http://localhost:8012/", attribute_path, file_path]) | ||
|
||
def _finalize_upload(self, attribute_path: str, error: Optional[Exception] = None) -> None: | ||
# TODO: hit the backend | ||
# TODO: needs to be retryable | ||
print(f"finalizing file {attribute_path}") | ||
time.sleep(1) | ||
print(f"finalized file {attribute_path}") | ||
|
||
def _make_done_callback(self, message: UploadMessage) -> Callable[[futures.Future], None]: | ||
"""Returns a callback function suitable for use with Future.add_done_callback(). Decreases the active upload | ||
count and propagates any exception to the errors queue. | ||
""" | ||
|
||
def _on_task_completed(future: futures.Future) -> None: | ||
self._input_queue.decrement_active() | ||
|
||
exc = future.exception() | ||
if future.cancelled() and exc is None: | ||
exc = NeptuneScaleError("Operation cancelled") | ||
|
||
if exc: | ||
logger.error(f"Failed to upload file `{message.local_path}` as `{message.attribute_path}`: {exc}") | ||
self._errors_queue.put(exc) | ||
|
||
return _on_task_completed | ||
|
||
|
||
def determine_path( | ||
run_id: str, attribute_path: str, local_path: Path, target_path: Optional[str], target_basename: Optional[str] | ||
) -> str: | ||
# Target path always takes precedence as-is | ||
if target_path: | ||
return target_path | ||
|
||
if target_basename: | ||
parts: tuple[str, ...] = (run_id, attribute_path, target_basename) | ||
else: | ||
parts = (run_id, attribute_path, str(uuid.uuid4()), local_path.name) | ||
|
||
return "/".join(parts) | ||
|
||
|
||
def upload_file(local_path: Path, url: str) -> None: | ||
# TODO: do the actual work :) | ||
assert local_path and url | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from pathlib import Path | ||
from unittest.mock import patch | ||
|
||
from pytest import mark | ||
|
||
from neptune_scale.sync.files.worker import determine_path | ||
|
||
|
||
@mark.parametrize( | ||
"local, full, basename, expected", | ||
( | ||
("some/file.py", None, None, "RUN/ATTR/UUID4/file.py"), | ||
("some/file.py", None, "file.txt", "RUN/ATTR/file.txt"), | ||
("some/file.py", "full/path.txt", None, "full/path.txt"), | ||
("some/file.py", "full/path.txt", "basename", "full/path.txt"), | ||
), | ||
) | ||
def test_determine_path(local, full, basename, expected): | ||
with patch("uuid.uuid4", return_value="UUID4"): | ||
assert determine_path("RUN", "ATTR", Path(local), full, basename) == expected |