From 53b80045ade008f9067e4261959496318f644aef Mon Sep 17 00:00:00 2001
From: Krzysztof Godlewski <krzysztof.godlewski@neptune.ai>
Date: Thu, 12 Dec 2024 11:01:14 +0100
Subject: [PATCH] Add `.sync.files` subpackage

The package contains code for uploading files to Neptune.
---
 src/neptune_scale/sync/files/__init__.py |   2 +
 src/neptune_scale/sync/files/queue.py    |  62 +++++++++++
 src/neptune_scale/sync/files/worker.py   | 136 +++++++++++++++++++++++
 tests/unit/test_file_upload.py           |  20 ++++
 4 files changed, 220 insertions(+)
 create mode 100644 src/neptune_scale/sync/files/__init__.py
 create mode 100644 src/neptune_scale/sync/files/queue.py
 create mode 100644 src/neptune_scale/sync/files/worker.py
 create mode 100644 tests/unit/test_file_upload.py

diff --git a/src/neptune_scale/sync/files/__init__.py b/src/neptune_scale/sync/files/__init__.py
new file mode 100644
index 0000000..4052948
--- /dev/null
+++ b/src/neptune_scale/sync/files/__init__.py
@@ -0,0 +1,2 @@
+"""This subpackage contains code for syncing files with Neptune.
+"""
diff --git a/src/neptune_scale/sync/files/queue.py b/src/neptune_scale/sync/files/queue.py
new file mode 100644
index 0000000..3854ee7
--- /dev/null
+++ b/src/neptune_scale/sync/files/queue.py
@@ -0,0 +1,62 @@
+import multiprocessing
+import pathlib
+from typing import (
+    NamedTuple,
+    Optional,
+)
+
+from neptune_scale.util import SharedInt
+from neptune_scale.util.abstract import Resource
+
+
+class UploadMessage(NamedTuple):  # one file upload request, as put on the queue by FileUploadQueue.submit()
+    attribute_path: str
+    local_path: pathlib.Path  # presumably the file on local disk that is to be uploaded -- TODO confirm
+    target_path: Optional[str]  # if truthy, determine_path() in worker.py returns it unchanged
+    target_basename: Optional[str]  # if truthy (and no target_path), becomes the last path component
+
+
+class FileUploadQueue(Resource):
+    """Queue for submitting file upload requests. Shared between the main process and FileUploadWorkerThread, which
+    is spawned in the worker process."""
+
+    def __init__(self) -> None:
+        self._queue: multiprocessing.Queue[UploadMessage] = multiprocessing.Queue(maxsize=4096)
+        self._active_uploads = SharedInt(0)  # incremented by submit(), decremented by decrement_active()
+
+    @property
+    def active_uploads(self) -> int:
+        """Returns the number of currently active uploads."""
+        with self._active_uploads:
+            return self._active_uploads.value
+
+    # Main process API
+    def submit(
+        self,
+        *,
+        attribute_path: str,
+        local_path: pathlib.Path,
+        target_path: Optional[str],
+        target_basename: Optional[str],
+    ) -> None:
+        with self._active_uploads:  # NOTE(review): put() can block on a full queue inside this `with` -- confirm OK
+            self._queue.put(UploadMessage(attribute_path, local_path, target_path, target_basename))
+            self._active_uploads.value += 1
+
+    def wait_for_completion(self, timeout: Optional[float] = None) -> None:
+        with self._active_uploads:  # presumably blocks until the counter is 0 or `timeout` elapses
+            self._active_uploads.wait_for(lambda: self._active_uploads.value == 0, timeout=timeout)
+
+    def close(self) -> None:
+        self._queue.close()
+        self._queue.cancel_join_thread()  # process exit will not wait for the queue's feeder thread to flush
+
+    # Worker process API
+    def decrement_active(self) -> None:
+        with self._active_uploads:
+            self._active_uploads.value -= 1
+            assert self._active_uploads.value >= 0
+            self._active_uploads.notify_all()  # lets a pending wait_for_completion() re-check the counter
+
+    def get(self, timeout: float) -> UploadMessage:
+        return self._queue.get(timeout=timeout)  # raises queue.Empty if nothing arrives within `timeout`
diff --git a/src/neptune_scale/sync/files/worker.py b/src/neptune_scale/sync/files/worker.py
new file mode 100644
index 0000000..e79d16c
--- /dev/null
+++ b/src/neptune_scale/sync/files/worker.py
@@ -0,0 +1,136 @@
+import time
+import uuid
+from concurrent import futures
+from pathlib import Path
+from queue import Empty
+from typing import (
+    Callable,
+    Optional,
+)
+
+from neptune_scale.exceptions import NeptuneScaleError
+from neptune_scale.sync.errors_tracking import ErrorsQueue
+from neptune_scale.sync.files.queue import (
+    FileUploadQueue,
+    UploadMessage,
+)
+from neptune_scale.util import (
+    Daemon,
+    get_logger,
+)
+from neptune_scale.util.abstract import Resource
+
+logger = get_logger()
+
+
+class FileUploadWorkerThread(Daemon, Resource):
+    """Consumes messages from the provided FileUploadQueue and performs the upload operation
+    in a pool of worker threads. Failures are logged and forwarded to the provided ErrorsQueue.
+    """
+
+    def __init__(
+        self,
+        *,
+        project: str,
+        run_id: str,
+        api_token: str,
+        family: str,
+        input_queue: FileUploadQueue,
+        errors_queue: ErrorsQueue,
+    ) -> None:
+        super().__init__(sleep_time=0.5, name="FileUploader")
+
+        self._project = project
+        self._run_id = run_id
+        self._api_token = api_token
+        self._family = family
+        self._input_queue = input_queue
+        self._errors_queue = errors_queue
+        self._executor = futures.ThreadPoolExecutor()  # pool size left at the library default
+
+    def work(self) -> None:  # drains the input queue; returns once get() times out with nothing to do
+        while True:
+            try:
+                msg = self._input_queue.get(timeout=0.5)
+            except Empty:
+                return
+
+            try:
+                future = self._executor.submit(
+                    self._do_upload, msg.attribute_path, msg.local_path, msg.target_path, msg.target_basename
+                )
+                future.add_done_callback(self._make_done_callback(msg))
+            except Exception as e:
+                logger.error(f"Failed to submit file upload task for `{msg.local_path}` as `{msg.attribute_path}`: {e}")
+                self._errors_queue.put(e)
+                self._input_queue.decrement_active()  # no done callback will ever run for this upload
+
+    def close(self) -> None:
+        self._executor.shutdown()  # wait=True by default: blocks until already submitted uploads finish
+
+    def _do_upload(
+        self,
+        attribute_path: str,
+        local_path: Path,
+        target_path: Optional[str],
+        target_basename: Optional[str],
+    ) -> None:
+        path = determine_path(self._run_id, attribute_path, local_path, target_path, target_basename)
+        # Finalize on both outcomes; on failure the error is passed along and re-raised for the done callback
+        try:
+            url = self._request_upload_url(attribute_path, path)
+            upload_file(local_path, url)
+            self._finalize_upload(path)
+        except Exception as e:
+            self._finalize_upload(path, e)
+            raise
+
+    def _request_upload_url(self, attribute_path: str, file_path: str) -> str:  # placeholder: no request is made
+        assert self._api_token
+        # TODO: Make this retryable
+        return ".".join(["http://localhost:8012/", attribute_path, file_path])
+
+    def _finalize_upload(self, attribute_path: str, error: Optional[Exception] = None) -> None:
+        # TODO: hit the backend
+        # TODO: needs to be retryable
+        print(f"finalizing file {attribute_path}")
+        time.sleep(1)  # stand-in for the backend call; `error` is ignored for now
+        print(f"finalized file {attribute_path}")
+
+    def _make_done_callback(self, message: UploadMessage) -> Callable[[futures.Future], None]:
+        """Returns a callback function suitable for use with Future.add_done_callback(). Decreases the active upload
+        count and propagates any exception (or a cancellation) to the errors queue.
+        """
+
+        def _on_task_completed(future: futures.Future) -> None:
+            self._input_queue.decrement_active()
+
+            # Future.exception() raises CancelledError for a cancelled future, so check cancelled() first
+            exc = NeptuneScaleError("Operation cancelled") if future.cancelled() else future.exception()
+
+            if exc:
+                logger.error(f"Failed to upload file `{message.local_path}` as `{message.attribute_path}`: {exc}")
+                self._errors_queue.put(exc)
+
+        return _on_task_completed
+
+
+def determine_path(
+    run_id: str, attribute_path: str, local_path: Path, target_path: Optional[str], target_basename: Optional[str]
+) -> str:
+    """Build the destination path string for an upload.
+
+    A non-empty `target_path` is returned unchanged. Otherwise the result is `run_id/attribute_path/` followed
+    by `target_basename` if given, or by a freshly generated `<uuid4>/<local file name>` pair.
+    """
+    if target_path:
+        return target_path
+
+    suffix = target_basename or f"{uuid.uuid4()}/{local_path.name}"
+    return "/".join((run_id, attribute_path, suffix))
+
+
+def upload_file(local_path: Path, url: str) -> None:
+    """Placeholder for the actual transfer; for now it only asserts that both arguments are truthy."""
+    # TODO: do the actual work :)
+    assert local_path and url
diff --git a/tests/unit/test_file_upload.py b/tests/unit/test_file_upload.py
new file mode 100644
index 0000000..a2cc1c6
--- /dev/null
+++ b/tests/unit/test_file_upload.py
@@ -0,0 +1,20 @@
+from pathlib import Path
+from unittest.mock import patch
+
+from pytest import mark
+
+from neptune_scale.sync.files.worker import determine_path
+
+
+@mark.parametrize(
+    "local, full, basename, expected",
+    (
+        ("some/file.py", None, None, "RUN/ATTR/UUID4/file.py"),  # no overrides: uuid + local file name
+        ("some/file.py", None, "file.txt", "RUN/ATTR/file.txt"),  # basename replaces the uuid/file-name pair
+        ("some/file.py", "full/path.txt", None, "full/path.txt"),  # explicit target path is returned as-is
+        ("some/file.py", "full/path.txt", "basename", "full/path.txt"),  # target path wins over basename
+    ),
+)
+def test_determine_path(local, full, basename, expected):
+    with patch("uuid.uuid4", return_value="UUID4"):  # pin the generated component so the path is predictable
+        assert determine_path("RUN", "ATTR", Path(local), full, basename) == expected