From 6cf67159109e055c14be3b5868a27cf6dbc9b9db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Mon, 29 Jul 2024 10:10:59 +0200 Subject: [PATCH 01/10] Added minimal Run classes (#6) --- .github/actions/install-package/action.yml | 28 ++++++++ .github/actions/test-unit/action.yml | 47 +++++++++++++ .github/workflows/unit-in-pull-request.yml | 30 ++++++++ .github/workflows/unit.yml | 35 +++++++++ CHANGELOG.md | 12 +++- dev_requirements.txt | 2 + src/neptune_scale/__init__.py | 73 +++++++++++++++++++ src/neptune_scale/core/__init__.py | 0 src/neptune_scale/core/validation.py | 48 +++++++++++++ src/neptune_scale/parameters.py | 2 + tests/unit/test_run.py | 82 ++++++++++++++++++++++ 11 files changed, 358 insertions(+), 1 deletion(-) create mode 100644 .github/actions/install-package/action.yml create mode 100644 .github/actions/test-unit/action.yml create mode 100644 .github/workflows/unit-in-pull-request.yml create mode 100644 .github/workflows/unit.yml create mode 100644 src/neptune_scale/core/__init__.py create mode 100644 src/neptune_scale/core/validation.py create mode 100644 src/neptune_scale/parameters.py create mode 100644 tests/unit/test_run.py diff --git a/.github/actions/install-package/action.yml b/.github/actions/install-package/action.yml new file mode 100644 index 00000000..050c8b59 --- /dev/null +++ b/.github/actions/install-package/action.yml @@ -0,0 +1,28 @@ +--- +name: Package +description: Install python and package +inputs: + python-version: + description: "Python version" + required: true + os: + description: "Operating system" + required: true + +runs: + using: "composite" + steps: + - name: Install Python ${{ inputs.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ inputs.python-version }} + + - name: Install dependencies + run: | + pip install -r dev_requirements.txt + shell: bash + + - name: List dependencies + run: | + pip list + shell: bash diff --git a/.github/actions/test-unit/action.yml b/.github/actions/test-unit/action.yml new file mode 100644 index 00000000..b3c5c79c --- /dev/null +++ b/.github/actions/test-unit/action.yml @@ -0,0 +1,47 @@ +--- +name: Test Unit +description: Check unit tests +inputs: + python-version: + description: "Python version" + required: true + os: + description: "Operating system" + required: true + report_job: + description: "Job name to update by JUnit report" + required: true + +runs: + using: "composite" + steps: + - name: Install package + uses: ./.github/actions/install-package + with: + python-version: ${{ inputs.python-version }} + os: ${{ inputs.os }}-latest + + - name: Test + run: | + pytest -v ./tests/unit/ \ + --timeout=120 --timeout_method=thread \ + --color=yes \ + --junitxml="./test-results/test-unit-new-${{ inputs.os }}-${{ inputs.python-version }}.xml" + shell: bash + + - name: Upload test reports + uses: actions/upload-artifact@v3 + if: always() + with: + name: test-artifacts + path: ./test-results + + - name: Report + uses: mikepenz/action-junit-report@v3.6.2 + if: always() + with: + report_paths: './test-results/test-unit-*.xml' + update_check: true + include_passed: true + annotate_notice: true + job_name: ${{ inputs.report_job }} diff --git a/.github/workflows/unit-in-pull-request.yml b/.github/workflows/unit-in-pull-request.yml new file mode 100644 index 00000000..efa0a6a1 --- /dev/null +++ b/.github/workflows/unit-in-pull-request.yml @@ -0,0 +1,30 @@ +name: Unittests + +on: + workflow_dispatch: + push: + branches-ignore: + - main + +jobs: + test: + 
timeout-minutes: 75 + strategy: + fail-fast: false + matrix: + os: [ubuntu, windows, macos] + python-version: ["3.8"] + name: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' + runs-on: ${{ matrix.os }}-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Run tests + uses: ./.github/actions/test-unit + with: + python-version: ${{ matrix.python-version }} + os: ${{ matrix.os }} + report_job: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml new file mode 100644 index 00000000..efff71a4 --- /dev/null +++ b/.github/workflows/unit.yml @@ -0,0 +1,35 @@ +name: unit + +on: + workflow_call: + workflow_dispatch: + schedule: + - cron: "0 4 * * *" # Run every day at arbitrary time (4:00 AM UTC) + push: + branches: + - main + +jobs: + test: + timeout-minutes: 75 + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + os: [ubuntu, windows, macos] + + name: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' + runs-on: ${{ matrix.os }}-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.client_payload.pull_request.head.ref }} + + - name: Run tests + uses: ./.github/actions/test-unit + with: + python-version: ${{ matrix.python-version }} + os: ${{ matrix.os }} + report_job: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' diff --git a/CHANGELOG.md b/CHANGELOG.md index 801acbe6..c2a541fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,11 @@ -## [UNRELEASED] neptune-client-scale 0.1.0 +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6)) diff --git a/dev_requirements.txt b/dev_requirements.txt index a19f8a59..7b8b416f 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,3 +2,5 @@ # dev pre-commit +pytest +pytest-timeout diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index e69de29b..ed4b0929 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -0,0 +1,73 @@ +""" +Python package +""" + +from __future__ import annotations + +__all__ = ["Run"] + +from contextlib import AbstractContextManager +from types import TracebackType + +from neptune_scale.core.validation import ( + verify_max_length, + verify_non_empty, + verify_project_qualified_name, + verify_type, +) +from neptune_scale.parameters import ( + MAX_FAMILY_LENGTH, + MAX_RUN_ID_LENGTH, +) + + +class Run(AbstractContextManager): + """ + Representation of tracked metadata. + """ + + def __init__(self, *, project: str, api_token: str, family: str, run_id: str) -> None: + """ + Initializes a run that logs the model-building metadata to Neptune. + + Args: + project: Name of the project where the metadata is logged, in the form `workspace-name/project-name`. + api_token: Your Neptune API token. + family: Identifies related runs. For example, the same value must apply to all runs within a run hierarchy. + Max length: 128 characters. + run_id: Unique identifier of a run. Must be unique within the project. Max length: 128 characters. 
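+
+        Example:
+            A minimal usage sketch; all values below are placeholders.
+
+            ```
+            >>> with Run(
+            ...     project="workspace-name/project-name",
+            ...     api_token="YOUR_API_TOKEN",
+            ...     family="some-family-identifier",
+            ...     run_id="some-unique-run-id",
+            ... ) as run:
+            ...     pass
+            ```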
+ """ + verify_type("api_token", api_token, str) + verify_type("family", family, str) + verify_type("run_id", run_id, str) + + verify_non_empty("api_token", api_token) + verify_non_empty("family", family) + verify_non_empty("run_id", run_id) + + verify_project_qualified_name("project", project) + + verify_max_length("family", family, MAX_FAMILY_LENGTH) + verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH) + + self._project: str = project + self._api_token: str = api_token + self._family: str = family + self._run_id: str = run_id + + def __enter__(self) -> Run: + return self + + def close(self) -> None: + """ + Stops the connection to Neptune and synchronizes all data. + """ + pass + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() diff --git a/src/neptune_scale/core/__init__.py b/src/neptune_scale/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/core/validation.py b/src/neptune_scale/core/validation.py new file mode 100644 index 00000000..efd8716a --- /dev/null +++ b/src/neptune_scale/core/validation.py @@ -0,0 +1,48 @@ +__all__ = ( + "verify_type", + "verify_non_empty", + "verify_max_length", + "verify_project_qualified_name", +) + +from typing import ( + Any, + Union, +) + + +def get_type_name(var_type: Union[type, tuple]) -> str: + return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type) + + +def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None: + try: + if isinstance(expected_type, tuple): + type_name = " or ".join(get_type_name(t) for t in expected_type) + else: + type_name = get_type_name(expected_type) + except Exception as e: + # Just to be sure that nothing weird will be raised here + raise TypeError(f"Incorrect type of {var_name}") from e + + if not isinstance(var, expected_type): + raise TypeError(f"{var_name} must be a {type_name} (was {type(var)})") + + +def verify_non_empty(var_name: str, var: Any) -> None: + if not var: + raise ValueError(f"{var_name} must not be empty") + + +def verify_max_length(var_name: str, var: Any, max_length: int) -> None: + if len(var) > max_length: + raise ValueError(f"{var_name} must not exceed {max_length} characters") + + +def verify_project_qualified_name(var_name: str, var: Any) -> None: + verify_type(var_name, var, str) + verify_non_empty(var_name, var) + + project_parts = var.split("/") + if len(project_parts) != 2: + raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name") diff --git a/src/neptune_scale/parameters.py b/src/neptune_scale/parameters.py new file mode 100644 index 00000000..44444d86 --- /dev/null +++ b/src/neptune_scale/parameters.py @@ -0,0 +1,2 @@ +MAX_RUN_ID_LENGTH = 128 +MAX_FAMILY_LENGTH = 128 diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py new file mode 100644 index 00000000..f7d8310c --- /dev/null +++ b/tests/unit/test_run.py @@ -0,0 +1,82 @@ +import uuid + +import pytest + +from neptune_scale import Run + + +def test_context_manager(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... 
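+    # Leaving the `with` block invokes Run.__exit__, which calls close().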
+ + # then + assert True + + +def test_close(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # and + run = Run(project=project, api_token=api_token, family=family, run_id=run_id) + + # when + run.close() + + # then + assert True + + +def test_family_too_long(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + + # and + family = "a" * 1000 + + # when + with pytest.raises(ValueError): + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... + + +def test_run_id_too_long(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + family = str(uuid.uuid4()) + + # and + run_id = "a" * 1000 + + # then + with pytest.raises(ValueError): + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... + + +def test_invalid_project_name(): + # given + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # and + project = "just-project" + + # then + with pytest.raises(ValueError): + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... From b85f1c59784b494e052b7812f3eeee474856ab26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Mon, 29 Jul 2024 11:49:44 +0200 Subject: [PATCH 02/10] Added `OperationsQueue` component (#7) --- CHANGELOG.md | 1 + src/neptune_scale/__init__.py | 48 +++++++++---- src/neptune_scale/core/components/__init__.py | 0 src/neptune_scale/core/components/abstract.py | 52 ++++++++++++++ .../core/components/operations_queue.py | 69 +++++++++++++++++++ src/neptune_scale/parameters.py | 2 + tests/unit/test_operations_queue.py | 63 +++++++++++++++++ 7 files changed, 223 insertions(+), 12 deletions(-) create mode 100644 src/neptune_scale/core/components/__init__.py create mode 100644 src/neptune_scale/core/components/abstract.py create mode 100644 src/neptune_scale/core/components/operations_queue.py create mode 100644 tests/unit/test_operations_queue.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c2a541fd..6e600d17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,3 +9,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6)) +- Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7)) diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index ed4b0929..9ceb82e2 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -6,9 +6,15 @@ __all__ = ["Run"] +import threading from contextlib import AbstractContextManager -from types import TracebackType +from typing import Callable +from neptune_scale.core.components.abstract import ( + Resource, + WithResources, +) +from neptune_scale.core.components.operations_queue import OperationsQueue from neptune_scale.core.validation import ( verify_max_length, verify_non_empty, @@ -17,16 +23,26 @@ ) from neptune_scale.parameters import ( MAX_FAMILY_LENGTH, + MAX_QUEUE_SIZE, MAX_RUN_ID_LENGTH, ) -class Run(AbstractContextManager): +class Run(WithResources, AbstractContextManager): """ Representation of tracked metadata. 
""" - def __init__(self, *, project: str, api_token: str, family: str, run_id: str) -> None: + def __init__( + self, + *, + project: str, + api_token: str, + family: str, + run_id: str, + max_queue_size: int = MAX_QUEUE_SIZE, + max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, + ) -> None: """ Initializes a run that logs the model-building metadata to Neptune. @@ -36,10 +52,17 @@ def __init__(self, *, project: str, api_token: str, family: str, run_id: str) -> family: Identifies related runs. For example, the same value must apply to all runs within a run hierarchy. Max length: 128 characters. run_id: Unique identifier of a run. Must be unique within the project. Max length: 128 characters. + max_queue_size: Maximum number of operations in a queue. + max_queue_size_exceeded_callback: Callback function triggered when a queue is full. + Accepts two arguments: + - Maximum size of the queue. + - Exception that made the queue full. """ verify_type("api_token", api_token, str) verify_type("family", family, str) verify_type("run_id", run_id, str) + verify_type("max_queue_size", max_queue_size, int) + verify_type("max_queue_size_exceeded_callback", max_queue_size_exceeded_callback, (Callable, type(None))) verify_non_empty("api_token", api_token) verify_non_empty("family", family) @@ -55,19 +78,20 @@ def __init__(self, *, project: str, api_token: str, family: str, run_id: str) -> self._family: str = family self._run_id: str = run_id + self._lock = threading.RLock() + self._operations_queue: OperationsQueue = OperationsQueue( + lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback + ) + def __enter__(self) -> Run: return self + @property + def resources(self) -> tuple[Resource, ...]: + return (self._operations_queue,) + def close(self) -> None: """ Stops the connection to Neptune and synchronizes all data. """ - pass - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - self.close() + super().close() diff --git a/src/neptune_scale/core/components/__init__.py b/src/neptune_scale/core/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/core/components/abstract.py b/src/neptune_scale/core/components/abstract.py new file mode 100644 index 00000000..00242fa5 --- /dev/null +++ b/src/neptune_scale/core/components/abstract.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from abc import ( + ABC, + abstractmethod, +) +from types import TracebackType + + +class AutoCloseable(ABC): + def __enter__(self) -> AutoCloseable: + return self + + @abstractmethod + def close(self) -> None: ... + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + +class Resource(AutoCloseable): + @abstractmethod + def cleanup(self) -> None: ... + + def flush(self) -> None: + pass + + def close(self) -> None: + self.flush() + + +class WithResources(Resource): + @property + @abstractmethod + def resources(self) -> tuple[Resource, ...]: ... 
+ + def flush(self) -> None: + for resource in self.resources: + resource.flush() + + def close(self) -> None: + for resource in self.resources: + resource.close() + + def cleanup(self) -> None: + for resource in self.resources: + resource.cleanup() diff --git a/src/neptune_scale/core/components/operations_queue.py b/src/neptune_scale/core/components/operations_queue.py new file mode 100644 index 00000000..a76f0f15 --- /dev/null +++ b/src/neptune_scale/core/components/operations_queue.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +__all__ = ("OperationsQueue",) + +from multiprocessing import Queue +from time import monotonic +from typing import ( + TYPE_CHECKING, + Callable, + NamedTuple, +) + +from neptune_scale.core.components.abstract import Resource +from neptune_scale.core.validation import verify_type +from neptune_scale.parameters import MAX_QUEUE_ELEMENT_SIZE + +if TYPE_CHECKING: + from threading import RLock + + from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + + +class QueueElement(NamedTuple): + sequence_id: int + occured_at: float + operation: bytes + + +def default_max_size_exceeded_callback(max_size: int, e: BaseException) -> None: + raise ValueError(f"Queue is full (max size: {max_size})") from e + + +class OperationsQueue(Resource): + def __init__( + self, + *, + lock: RLock, + max_size: int = 0, + max_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, + ) -> None: + verify_type("max_size", max_size, int) + + self._lock: RLock = lock + self._max_size: int = max_size + self._max_size_exceeded_callback: Callable[[int, BaseException], None] = ( + max_size_exceeded_callback if max_size_exceeded_callback is not None else default_max_size_exceeded_callback + ) + + self._sequence_id: int = 0 + self._queue: Queue[QueueElement] = Queue(maxsize=max_size) + + def enqueue(self, *, operation: RunOperation) -> None: + try: + with self._lock: + serialized_operation = operation.SerializeToString() + + if len(serialized_operation) > MAX_QUEUE_ELEMENT_SIZE: + raise ValueError(f"Operation size exceeds the maximum allowed size ({MAX_QUEUE_ELEMENT_SIZE})") + + self._queue.put_nowait(QueueElement(self._sequence_id, monotonic(), serialized_operation)) + self._sequence_id += 1 + except Exception as e: + self._max_size_exceeded_callback(self._max_size, e) + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self._queue.close() diff --git a/src/neptune_scale/parameters.py b/src/neptune_scale/parameters.py index 44444d86..44112374 100644 --- a/src/neptune_scale/parameters.py +++ b/src/neptune_scale/parameters.py @@ -1,2 +1,4 @@ MAX_RUN_ID_LENGTH = 128 MAX_FAMILY_LENGTH = 128 +MAX_QUEUE_SIZE = 32767 +MAX_QUEUE_ELEMENT_SIZE = 1024 * 1024 # 1MB diff --git a/tests/unit/test_operations_queue.py b/tests/unit/test_operations_queue.py new file mode 100644 index 00000000..f7c4d59a --- /dev/null +++ b/tests/unit/test_operations_queue.py @@ -0,0 +1,63 @@ +import threading +from unittest.mock import MagicMock + +import pytest +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.components.operations_queue import OperationsQueue + + +def test__enqueue(): + # given + lock = threading.RLock() + queue = OperationsQueue(lock=lock, max_size=0) + + # and + operation = RunOperation() + + # when + queue.enqueue(operation=operation) + + # then + assert queue._sequence_id == 1 + + # when + 
queue.enqueue(operation=operation) + + # then + assert queue._sequence_id == 2 + + +def test__max_queue_size_exceeded(): + # given + lock = threading.RLock() + callback = MagicMock() + queue = OperationsQueue(lock=lock, max_size=1, max_size_exceeded_callback=callback) + + # and + operation = RunOperation() + + # when + queue.enqueue(operation=operation) + queue.enqueue(operation=operation) + + # then + callback.assert_called_once() + + +def test__max_element_size_exceeded(): + # given + lock = threading.RLock() + queue = OperationsQueue(lock=lock, max_size=1) + + # and + snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string=("a" * 1024)) for i in range(1024)}) + operation = RunOperation(update=snapshot) + + # then + with pytest.raises(ValueError): + queue.enqueue(operation=operation) From 2faf3a51ebe190209484ee2f57fa74b3394a74ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Wed, 31 Jul 2024 12:49:56 +0200 Subject: [PATCH 03/10] Logging metadata (#8) --- CHANGELOG.md | 1 + dev_requirements.txt | 1 + src/neptune_scale/__init__.py | 67 +++++++ .../core/components/operations_queue.py | 1 + src/neptune_scale/core/message_builder.py | 120 +++++++++++++ src/neptune_scale/core/validation.py | 19 +- tests/unit/test_message_builder.py | 169 ++++++++++++++++++ tests/unit/test_run.py | 84 +++++++++ 8 files changed, 456 insertions(+), 6 deletions(-) create mode 100644 src/neptune_scale/core/message_builder.py create mode 100644 tests/unit/test_message_builder.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e600d17..dc42d3d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,3 +10,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6)) - Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7)) +- Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8)) diff --git a/dev_requirements.txt b/dev_requirements.txt index 7b8b416f..73ee4cb9 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,3 +4,4 @@ pre-commit pytest pytest-timeout +freezegun diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 9ceb82e2..3e5551d7 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -8,6 +8,7 @@ import threading from contextlib import AbstractContextManager +from datetime import datetime from typing import Callable from neptune_scale.core.components.abstract import ( @@ -15,7 +16,9 @@ WithResources, ) from neptune_scale.core.components.operations_queue import OperationsQueue +from neptune_scale.core.message_builder import MessageBuilder from neptune_scale.core.validation import ( + verify_collection_type, verify_max_length, verify_non_empty, verify_project_qualified_name, @@ -95,3 +98,67 @@ def close(self) -> None: Stops the connection to Neptune and synchronizes all data. """ super().close() + + def log( + self, + step: float | int | None = None, + timestamp: datetime | None = None, + fields: dict[str, float | bool | int | str | datetime | list | set] | None = None, + metrics: dict[str, float] | None = None, + add_tags: dict[str, list[str] | set[str]] | None = None, + remove_tags: dict[str, list[str] | set[str]] | None = None, + ) -> None: + """ + Logs the specified metadata to Neptune. 
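+        Fields are assigned to the run, metrics are appended as series values,
+        and the tag dictionaries add or remove string-set entries.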
+ + Args: + step: Index of the log entry, must be increasing. If None, the highest of the already logged indexes is used. + timestamp: Time of logging the metadata. + fields: Dictionary of fields to log. + metrics: Dictionary of metrics to log. + add_tags: Dictionary of tags to add to the run. + remove_tags: Dictionary of tags to remove from the run. + + Examples: + ``` + >>> with Run(...) as run: + ... run.log(step=1, fields={"parameters/learning_rate": 0.001}) + ... run.log(step=2, add_tags={"sys/group_tags": ["group1", "group2"]}) + ... run.log(step=3, metrics={"metrics/loss": 0.1}) + ``` + + """ + verify_type("step", step, (float, int, type(None))) + verify_type("timestamp", timestamp, (datetime, type(None))) + verify_type("fields", fields, (dict, type(None))) + verify_type("metrics", metrics, (dict, type(None))) + verify_type("add_tags", add_tags, (dict, type(None))) + verify_type("remove_tags", remove_tags, (dict, type(None))) + + timestamp = datetime.now() if timestamp is None else timestamp + fields = {} if fields is None else fields + metrics = {} if metrics is None else metrics + add_tags = {} if add_tags is None else add_tags + remove_tags = {} if remove_tags is None else remove_tags + + verify_collection_type("`fields` keys", list(fields.keys()), str) + verify_collection_type("`metrics` keys", list(metrics.keys()), str) + verify_collection_type("`add_tags` keys", list(add_tags.keys()), str) + verify_collection_type("`remove_tags` keys", list(remove_tags.keys()), str) + + verify_collection_type("`fields` values", list(fields.values()), (float, bool, int, str, datetime, list, set)) + verify_collection_type("`metrics` values", list(metrics.values()), float) + verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set)) + verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set)) + + for operation in MessageBuilder( + project=self._project, + run_id=self._run_id, + step=step, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + ): + self._operations_queue.enqueue(operation=operation) diff --git a/src/neptune_scale/core/components/operations_queue.py b/src/neptune_scale/core/components/operations_queue.py index a76f0f15..bed30265 100644 --- a/src/neptune_scale/core/components/operations_queue.py +++ b/src/neptune_scale/core/components/operations_queue.py @@ -51,6 +51,7 @@ def __init__( def enqueue(self, *, operation: RunOperation) -> None: try: + # TODO: This lock could be moved to the Run class with self._lock: serialized_operation = operation.SerializeToString() diff --git a/src/neptune_scale/core/message_builder.py b/src/neptune_scale/core/message_builder.py new file mode 100644 index 00000000..75f21751 --- /dev/null +++ b/src/neptune_scale/core/message_builder.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +__all__ = ("MessageBuilder",) + +from datetime import datetime + +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + ModifySet, + Step, + StringSet, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + + +class MessageBuilder: + def __init__( + self, + *, + project: str, + run_id: str, + step: int | float | None, + timestamp: datetime, + fields: dict[str, float | bool | int | str | datetime | list | set], + metrics: dict[str, float], + add_tags: dict[str, list[str] | set[str]], + remove_tags: dict[str, list[str] | 
set[str]], + ): + self._step = None if step is None else make_step(number=step) + self._timestamp = datetime_to_proto(timestamp) + self._fields = fields + self._metrics = metrics + self._add_tags = add_tags + self._remove_tags = remove_tags + self._project = project + self._run_id = run_id + + self._was_produced: bool = False + + def __iter__(self) -> MessageBuilder: + return self + + def __next__(self) -> RunOperation: + if not self._was_produced: + self._was_produced = True + + modify_sets = {key: mod_tags(add=add) for key, add in self._add_tags.items()} + modify_sets.update({key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()}) + update = UpdateRunSnapshot( + step=self._step, + timestamp=self._timestamp, + assign={key: make_value(value) for key, value in self._fields.items()}, + append={key: make_value(value) for key, value in self._metrics.items()}, + modify_sets=modify_sets, + ) + + return RunOperation(project=self._project, run_id=self._run_id, update=update) + + raise StopIteration + + +def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet: + mod_set = ModifySet() + if add is not None: + for tag in add: + mod_set.string.values[tag] = SET_OPERATION.ADD + if remove is not None: + for tag in remove: + mod_set.string.values[tag] = SET_OPERATION.REMOVE + return mod_set + + +def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value: + if isinstance(value, Value): + return value + if isinstance(value, float): + return Value(float64=value) + elif isinstance(value, bool): + return Value(bool=value) + elif isinstance(value, int): + return Value(int64=value) + elif isinstance(value, str): + return Value(string=value) + elif isinstance(value, datetime): + return Value(timestamp=datetime_to_proto(value)) + elif isinstance(value, (list, set)): + fv = Value(string_set=StringSet(values=value)) + return fv + else: + raise ValueError(f"Unsupported ingest field value type: {type(value)}") + + +def datetime_to_proto(dt: datetime) -> Timestamp: + dt_ts = dt.timestamp() + return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9)) + + +def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step: + """ + Converts a number to protobuf Step value. Example: + >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321) + Args: + number: step expressed as number + raise_on_step_precision_loss: inform converter whether it should silently drop precision and + round down to 6 decimal places or raise an error. + + Returns: Step protobuf used in Neptune API. 
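+
+    More examples:
+    >>> assert make_step(2) == Step(whole=2, micro=0)
+    >>> assert make_step(0.5) == Step(whole=0, micro=500000)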
+ """ + m = int(1e6) + micro: int = int(number * m) + if raise_on_step_precision_loss and number * m - micro != 0: + raise ValueError(f"step must not use more than 6-decimal points, got: {number}") + + whole = micro // m + micro = micro % m + + return Step(whole=whole, micro=micro) diff --git a/src/neptune_scale/core/validation.py b/src/neptune_scale/core/validation.py index efd8716a..95a732e1 100644 --- a/src/neptune_scale/core/validation.py +++ b/src/neptune_scale/core/validation.py @@ -1,21 +1,21 @@ +from __future__ import annotations + __all__ = ( "verify_type", "verify_non_empty", "verify_max_length", "verify_project_qualified_name", + "verify_collection_type", ) -from typing import ( - Any, - Union, -) +from typing import Any -def get_type_name(var_type: Union[type, tuple]) -> str: +def get_type_name(var_type: type | tuple) -> str: return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type) -def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None: +def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None: try: if isinstance(expected_type, tuple): type_name = " or ".join(get_type_name(t) for t in expected_type) @@ -46,3 +46,10 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None: project_parts = var.split("/") if len(project_parts) != 2: raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name") + + +def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None: + verify_type(var_name, var, (list, set, tuple)) + + for value in var: + verify_type(f"elements of collection '{var_name}'", value, expected_type) diff --git a/tests/unit/test_message_builder.py b/tests/unit/test_message_builder.py new file mode 100644 index 00000000..5d78ed6d --- /dev/null +++ b/tests/unit/test_message_builder.py @@ -0,0 +1,169 @@ +from datetime import datetime + +from freezegun import freeze_time +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + ModifySet, + ModifyStringSet, + Step, + StringSet, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.message_builder import MessageBuilder + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_empty(): + # given + builder = MessageBuilder( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot(step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934)), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_fields(): + # given + builder = MessageBuilder( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={ + "some/string": "value", + "some/int": 2501, + "some/float": 3.14, + "some/bool": True, + "some/datetime": datetime.now(), + "some/tags": {"tag1", "tag2"}, + }, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, 
nanos=21934), + assign={ + "some/string": Value(string="value"), + "some/int": Value(int64=2501), + "some/float": Value(float64=3.14), + "some/bool": Value(bool=True), + "some/datetime": Value(timestamp=Timestamp(seconds=1722341532, nanos=21934)), + "some/tags": Value(string_set=StringSet(values={"tag1", "tag2"})), + }, + ), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_metrics(): + # given + builder = MessageBuilder( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={ + "some/metric": 3.14, + }, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + append={ + "some/metric": Value(float64=3.14), + }, + ), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_tags(): + # given + builder = MessageBuilder( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={ + "some/tags": {"tag1", "tag2"}, + "some/other_tags2": {"tag2", "tag3"}, + }, + remove_tags={ + "some/group_tags": {"tag0", "tag1"}, + "some/other_tags": {"tag2", "tag3"}, + }, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + modify_sets={ + "some/tags": ModifySet( + string=ModifyStringSet(values={"tag1": SET_OPERATION.ADD, "tag2": SET_OPERATION.ADD}) + ), + "some/other_tags2": ModifySet( + string=ModifyStringSet(values={"tag2": SET_OPERATION.ADD, "tag3": SET_OPERATION.ADD}) + ), + "some/group_tags": ModifySet( + string=ModifyStringSet(values={"tag0": SET_OPERATION.REMOVE, "tag1": SET_OPERATION.REMOVE}) + ), + "some/other_tags": ModifySet( + string=ModifyStringSet(values={"tag2": SET_OPERATION.REMOVE, "tag3": SET_OPERATION.REMOVE}) + ), + }, + ), + ) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index f7d8310c..0c7fb614 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime import pytest @@ -80,3 +81,86 @@ def test_invalid_project_name(): with pytest.raises(ValueError): with Run(project=project, api_token=api_token, family=family, run_id=run_id): ... 
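+# The tests below exercise Run.log() with fields, metrics, and tag updates.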
+ + +def test_metadata(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + step=1, + timestamp=datetime.now(), + fields={ + "int": 1, + "string": "test", + "float": 3.14, + "bool": True, + "datetime": datetime.now(), + }, + metrics={ + "metric": 1.0, + }, + add_tags={ + "tags": ["tag1"], + }, + remove_tags={ + "group_tags": ["tag2"], + }, + ) + + +def test_log_without_step(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + timestamp=datetime.now(), + fields={ + "int": 1, + }, + ) + + +def test_log_step_float(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + step=3.14, + timestamp=datetime.now(), + fields={ + "int": 1, + }, + ) + + +def test_log_no_timestamp(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + step=3.14, + fields={ + "int": 1, + }, + ) From 4c91b15d4146ea140a289cce62cbf29e3b4bc660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Wed, 31 Jul 2024 16:56:46 +0200 Subject: [PATCH 04/10] Run creation and basic data synchronization (#9) --- CHANGELOG.md | 4 + src/neptune_scale/__init__.py | 82 ++++++++++++++- src/neptune_scale/api/__init__.py | 0 src/neptune_scale/api/api_client.py | 87 +++++++++++++++ src/neptune_scale/core/message_builder.py | 34 +----- src/neptune_scale/core/proto_utils.py | 35 +++++++ tests/unit/test_run.py | 122 ++++++++++++++++++---- 7 files changed, 314 insertions(+), 50 deletions(-) create mode 100644 src/neptune_scale/api/__init__.py create mode 100644 src/neptune_scale/api/api_client.py create mode 100644 src/neptune_scale/core/proto_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index dc42d3d5..7f46ad72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,3 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6)) - Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7)) - Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8)) +- Added support for `creation_time` ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for Forking ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for Experiments ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for Run resume ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 3e5551d7..4411c8bb 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -11,12 +11,21 @@ from datetime import datetime from typing import Callable +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint +from 
neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.api.api_client import ApiClient from neptune_scale.core.components.abstract import ( Resource, WithResources, ) from neptune_scale.core.components.operations_queue import OperationsQueue from neptune_scale.core.message_builder import MessageBuilder +from neptune_scale.core.proto_utils import ( + datetime_to_proto, + make_step, +) from neptune_scale.core.validation import ( verify_collection_type, verify_max_length, @@ -43,6 +52,11 @@ def __init__( api_token: str, family: str, run_id: str, + resume: bool = False, + as_experiment: str | None = None, + creation_time: datetime | None = None, + from_run_id: str | None = None, + from_step: int | float | None = None, max_queue_size: int = MAX_QUEUE_SIZE, max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, ) -> None: @@ -55,6 +69,11 @@ def __init__( family: Identifies related runs. For example, the same value must apply to all runs within a run hierarchy. Max length: 128 characters. run_id: Unique identifier of a run. Must be unique within the project. Max length: 128 characters. + resume: Whether to resume an existing run. + as_experiment: If creating a run as an experiment, ID of an experiment to be associated with the run. + creation_time: Custom creation time of the run. + from_run_id: If forking from an existing run, ID of the run to fork from. + from_step: If forking from an existing run, step number to fork from. max_queue_size: Maximum number of operations in a queue. max_queue_size_exceeded_callback: Callback function triggered when a queue is full. Accepts two arguments: @@ -64,12 +83,32 @@ def __init__( verify_type("api_token", api_token, str) verify_type("family", family, str) verify_type("run_id", run_id, str) + verify_type("resume", resume, bool) + verify_type("as_experiment", as_experiment, (str, type(None))) + verify_type("creation_time", creation_time, (datetime, type(None))) + verify_type("from_run_id", from_run_id, (str, type(None))) + verify_type("from_step", from_step, (int, float, type(None))) verify_type("max_queue_size", max_queue_size, int) verify_type("max_queue_size_exceeded_callback", max_queue_size_exceeded_callback, (Callable, type(None))) + if resume and creation_time is not None: + raise ValueError("`resume` and `creation_time` cannot be used together.") + if resume and as_experiment is not None: + raise ValueError("`resume` and `as_experiment` cannot be used together.") + if (from_run_id is not None and from_step is None) or (from_run_id is None and from_step is not None): + raise ValueError("`from_run_id` and `from_step` must be used together.") + if resume and from_run_id is not None: + raise ValueError("`resume` and `from_run_id` cannot be used together.") + if resume and from_step is not None: + raise ValueError("`resume` and `from_step` cannot be used together.") + verify_non_empty("api_token", api_token) verify_non_empty("family", family) verify_non_empty("run_id", run_id) + if as_experiment is not None: + verify_non_empty("as_experiment", as_experiment) + if from_run_id is not None: + verify_non_empty("from_run_id", from_run_id) verify_project_qualified_name("project", project) @@ -77,7 +116,6 @@ def __init__( verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH) self._project: str = project - self._api_token: str = api_token self._family: str = family self._run_id: str = run_id @@ -85,13 +123,22 @@ 
def __init__( self._operations_queue: OperationsQueue = OperationsQueue( lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback ) + self._backend: ApiClient = ApiClient(api_token=api_token) + + if not resume: + self._create_run( + creation_time=datetime.now() if creation_time is None else creation_time, + as_experiment=as_experiment, + from_run_id=from_run_id, + from_step=from_step, + ) def __enter__(self) -> Run: return self @property def resources(self) -> tuple[Resource, ...]: - return (self._operations_queue,) + return self._operations_queue, self._backend def close(self) -> None: """ @@ -99,6 +146,33 @@ def close(self) -> None: """ super().close() + def _create_run( + self, + creation_time: datetime, + as_experiment: str | None, + from_run_id: str | None, + from_step: int | float | None, + ) -> None: + fork_point: ForkPoint | None = None + if from_run_id is not None and from_step is not None: + fork_point = ForkPoint( + parent_project=self._project, parent_run_id=from_run_id, step=make_step(number=from_step) + ) + + operation = RunOperation( + project=self._project, + run_id=self._run_id, + create=CreateRun( + family=self._family, + fork_point=fork_point, + experiment_id=as_experiment, + creation_time=None if creation_time is None else datetime_to_proto(creation_time), + ), + ) + self._backend.submit(operation=operation) + # TODO: Enqueue on the operations queue + # self._operations_queue.enqueue(operation=operation) + def log( self, step: float | int | None = None, @@ -161,4 +235,6 @@ def log( add_tags=add_tags, remove_tags=remove_tags, ): - self._operations_queue.enqueue(operation=operation) + self._backend.submit(operation=operation) + # TODO: Enqueue on the operations queue + # self._operations_queue.enqueue(operation=operation) diff --git a/src/neptune_scale/api/__init__.py b/src/neptune_scale/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/api/api_client.py b/src/neptune_scale/api/api_client.py new file mode 100644 index 00000000..b8b68369 --- /dev/null +++ b/src/neptune_scale/api/api_client.py @@ -0,0 +1,87 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +__all__ = ["ApiClient"] + + +from dataclasses import dataclass + +from neptune_api import ( + AuthenticatedClient, + Client, +) +from neptune_api.api.backend import get_client_config +from neptune_api.api.data_ingestion import submit_operation +from neptune_api.auth_helpers import exchange_api_key +from neptune_api.credentials import Credentials +from neptune_api.models import ( + ClientConfig, + Error, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.components.abstract import Resource + + +class ApiClient(Resource): + def __init__(self, api_token: str) -> None: + credentials = Credentials.from_api_key(api_key=api_token) + config, token_urls = get_config_and_token_urls(credentials=credentials) + self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls) + + def submit(self, operation: RunOperation) -> None: + _ = submit_operation.sync(client=self._backend, body=operation) + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self._backend.__exit__() + + +@dataclass +class TokenRefreshingURLs: + authorization_endpoint: str + token_endpoint: str + + @classmethod + def from_dict(cls, data: dict) -> TokenRefreshingURLs: + return TokenRefreshingURLs( + authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"] + ) + + +def get_config_and_token_urls(*, credentials: Credentials) -> tuple[ClientConfig, TokenRefreshingURLs]: + with Client(base_url=credentials.base_url) as client: + config = get_client_config.sync(client=client) + if config is None or isinstance(config, Error): + raise RuntimeError(f"Failed to get client config: {config}") + response = client.get_httpx_client().get(config.security.open_id_discovery) + token_urls = TokenRefreshingURLs.from_dict(response.json()) + return config, token_urls + + +def create_auth_api_client( + *, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs +) -> AuthenticatedClient: + return AuthenticatedClient( + base_url=credentials.base_url, + credentials=credentials, + client_id=config.security.client_id, + token_refreshing_endpoint=token_refreshing_urls.token_endpoint, + api_key_exchange_callback=exchange_api_key, + ) diff --git a/src/neptune_scale/core/message_builder.py b/src/neptune_scale/core/message_builder.py index 75f21751..5bb6f05c 100644 --- a/src/neptune_scale/core/message_builder.py +++ b/src/neptune_scale/core/message_builder.py @@ -4,17 +4,20 @@ from datetime import datetime -from google.protobuf.timestamp_pb2 import Timestamp from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( SET_OPERATION, ModifySet, - Step, StringSet, UpdateRunSnapshot, Value, ) from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation +from neptune_scale.core.proto_utils import ( + datetime_to_proto, + make_step, +) + class MessageBuilder: def __init__( @@ -91,30 +94,3 @@ def make_value(value: Value | float | str | int | bool | datetime | list[str] | return fv else: raise ValueError(f"Unsupported ingest field value type: {type(value)}") - - -def datetime_to_proto(dt: datetime) -> Timestamp: - dt_ts = dt.timestamp() - return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9)) - - -def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step: - """ - Converts a number to protobuf Step value. 
Example: - >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321) - Args: - number: step expressed as number - raise_on_step_precision_loss: inform converter whether it should silently drop precision and - round down to 6 decimal places or raise an error. - - Returns: Step protobuf used in Neptune API. - """ - m = int(1e6) - micro: int = int(number * m) - if raise_on_step_precision_loss and number * m - micro != 0: - raise ValueError(f"step must not use more than 6-decimal points, got: {number}") - - whole = micro // m - micro = micro % m - - return Step(whole=whole, micro=micro) diff --git a/src/neptune_scale/core/proto_utils.py b/src/neptune_scale/core/proto_utils.py new file mode 100644 index 00000000..5fa72d5a --- /dev/null +++ b/src/neptune_scale/core/proto_utils.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +__all__ = ("datetime_to_proto", "make_step") + +from datetime import datetime + +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Step + + +def datetime_to_proto(dt: datetime) -> Timestamp: + dt_ts = dt.timestamp() + return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9)) + + +def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step: + """ + Converts a number to protobuf Step value. Example: + >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321) + Args: + number: step expressed as number + raise_on_step_precision_loss: inform converter whether it should silently drop precision and + round down to 6 decimal places or raise an error. + + Returns: Step protobuf used in Neptune API. + """ + m = int(1e6) + micro: int = int(number * m) + if raise_on_step_precision_loss and number * m - micro != 0: + raise ValueError(f"step must not use more than 6-decimal points, got: {number}") + + whole = micro // m + micro = micro % m + + return Step(whole=whole, micro=micro) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index 0c7fb614..799424e1 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -1,15 +1,38 @@ +import base64 +import json import uuid from datetime import datetime +from unittest.mock import patch import pytest +from freezegun import freeze_time from neptune_scale import Run -def test_context_manager(): +@pytest.fixture(scope="session") +def api_token(): + return base64.b64encode(json.dumps({"api_address": "aa", "api_url": "bb"}).encode("utf-8")).decode("utf-8") + + +class MockedApiClient: + def __init__(self, *args, **kwargs) -> None: + pass + + def submit(self, operation) -> None: + pass + + def close(self) -> None: + pass + + def cleanup(self) -> None: + pass + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_context_manager(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -21,10 +44,10 @@ def test_context_manager(): assert True -def test_close(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_close(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -38,10 +61,10 @@ def test_close(): assert True -def test_family_too_long(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_family_too_long(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) # and @@ -53,10 +76,10 @@ def test_family_too_long(): ... 
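+    # Note: validation errors are raised in Run.__init__ itself, before the
+    # (mocked) backend client is ever constructed.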
-def test_run_id_too_long(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_run_id_too_long(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" family = str(uuid.uuid4()) # and @@ -68,9 +91,9 @@ def test_run_id_too_long(): ... -def test_invalid_project_name(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_invalid_project_name(api_token): # given - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -83,10 +106,10 @@ def test_invalid_project_name(): ... -def test_metadata(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_metadata(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -114,10 +137,10 @@ def test_metadata(): ) -def test_log_without_step(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_log_without_step(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -131,10 +154,10 @@ def test_log_without_step(): ) -def test_log_step_float(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_log_step_float(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -149,10 +172,10 @@ def test_log_step_float(): ) -def test_log_no_timestamp(): +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_log_no_timestamp(api_token): # given project = "workspace/project" - api_token = "API_TOKEN" run_id = str(uuid.uuid4()) family = run_id @@ -164,3 +187,66 @@ def test_log_no_timestamp(): "int": 1, }, ) + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_resume(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id, resume=True): + ... + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +@freeze_time("2024-07-30 12:12:12.000022") +def test_creation_time(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id, creation_time=datetime.now()): + ... + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_assign_experiment(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id, as_experiment="experiment_id"): + ... + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_forking(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run( + project=project, api_token=api_token, family=family, run_id=run_id, from_run_id="parent-run-id", from_step=3.14 + ): + ... 
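+    # The fork point (from_run_id + from_step) is sent in the CreateRun message.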
+ + # then + assert True From 67f63cbced79d5261cf5cb908c6008e7c92a1aa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Fri, 9 Aug 2024 19:55:25 +0200 Subject: [PATCH 05/10] Added support for env variables for project and api token (#11) --- CHANGELOG.md | 1 + src/neptune_scale/__init__.py | 31 ++++++++++++++++++++++++------- src/neptune_scale/envs.py | 3 +++ 3 files changed, 28 insertions(+), 7 deletions(-) create mode 100644 src/neptune_scale/envs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f46ad72..99ab9add 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,3 +15,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for Forking ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) - Added support for Experiments ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) - Added support for Run resume ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for env variables for project and api token ([#11](https://github.com/neptune-ai/neptune-client-scale/pull/11)) diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 4411c8bb..e88bc1a3 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -6,6 +6,7 @@ __all__ = ["Run"] +import os import threading from contextlib import AbstractContextManager from datetime import datetime @@ -33,6 +34,10 @@ verify_project_qualified_name, verify_type, ) +from neptune_scale.envs import ( + API_TOKEN_ENV_NAME, + PROJECT_ENV_NAME, +) from neptune_scale.parameters import ( MAX_FAMILY_LENGTH, MAX_QUEUE_SIZE, @@ -48,10 +53,10 @@ class Run(WithResources, AbstractContextManager): def __init__( self, *, - project: str, - api_token: str, family: str, run_id: str, + project: str | None = None, + api_token: str | None = None, resume: bool = False, as_experiment: str | None = None, creation_time: datetime | None = None, @@ -64,11 +69,13 @@ def __init__( Initializes a run that logs the model-building metadata to Neptune. Args: - project: Name of the project where the metadata is logged, in the form `workspace-name/project-name`. - api_token: Your Neptune API token. family: Identifies related runs. For example, the same value must apply to all runs within a run hierarchy. Max length: 128 characters. run_id: Unique identifier of a run. Must be unique within the project. Max length: 128 characters. + project: Name of the project where the metadata is logged, in the form `workspace-name/project-name`. + If not provided, the value of the `NEPTUNE_PROJECT` environment variable is used. + api_token: Your Neptune API token. If not provided, the value of the `NEPTUNE_API_TOKEN` environment + variable is used. resume: Whether to resume an existing run. as_experiment: If creating a run as an experiment, ID of an experiment to be associated with the run. creation_time: Custom creation time of the run. @@ -80,10 +87,11 @@ def __init__( - Maximum size of the queue. - Exception that made the queue full. 
""" - verify_type("api_token", api_token, str) verify_type("family", family, str) verify_type("run_id", run_id, str) verify_type("resume", resume, bool) + verify_type("project", project, (str, type(None))) + verify_type("api_token", api_token, (str, type(None))) verify_type("as_experiment", as_experiment, (str, type(None))) verify_type("creation_time", creation_time, (datetime, type(None))) verify_type("from_run_id", from_run_id, (str, type(None))) @@ -102,7 +110,16 @@ def __init__( if resume and from_step is not None: raise ValueError("`resume` and `from_step` cannot be used together.") + project = project or os.environ.get(PROJECT_ENV_NAME) + verify_non_empty("project", project) + assert project is not None # mypy + input_project: str = project + + api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) verify_non_empty("api_token", api_token) + assert api_token is not None # mypy + input_api_token: str = api_token + verify_non_empty("family", family) verify_non_empty("run_id", run_id) if as_experiment is not None: @@ -115,7 +132,7 @@ def __init__( verify_max_length("family", family, MAX_FAMILY_LENGTH) verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH) - self._project: str = project + self._project: str = input_project self._family: str = family self._run_id: str = run_id @@ -123,7 +140,7 @@ def __init__( self._operations_queue: OperationsQueue = OperationsQueue( lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback ) - self._backend: ApiClient = ApiClient(api_token=api_token) + self._backend: ApiClient = ApiClient(api_token=input_api_token) if not resume: self._create_run( diff --git a/src/neptune_scale/envs.py b/src/neptune_scale/envs.py new file mode 100644 index 00000000..02681d9b --- /dev/null +++ b/src/neptune_scale/envs.py @@ -0,0 +1,3 @@ +PROJECT_ENV_NAME = "NEPTUNE_PROJECT" + +API_TOKEN_ENV_NAME = "NEPTUNE_API_TOKEN" From 6e4ada2355877c1bab188228d0407ef648a3f1d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Fri, 9 Aug 2024 19:59:15 +0200 Subject: [PATCH 06/10] Splitting metadata into multiple messages on `log` (#12) --- .pre-commit-config.yaml | 1 + pyproject.toml | 2 +- src/neptune_scale/__init__.py | 10 +- src/neptune_scale/core/message_builder.py | 96 ----------- src/neptune_scale/core/metadata_splitter.py | 161 ++++++++++++++++++ .../core/{proto_utils.py => serialization.py} | 38 ++++- ...e_builder.py => test_metadata_splitter.py} | 108 +++++++++++- 7 files changed, 308 insertions(+), 108 deletions(-) delete mode 100644 src/neptune_scale/core/message_builder.py create mode 100644 src/neptune_scale/core/metadata_splitter.py rename src/neptune_scale/core/{proto_utils.py => serialization.py} (50%) rename tests/unit/{test_message_builder.py => test_metadata_splitter.py} (56%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86b15790..204078f2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,5 +28,6 @@ repos: pass_filenames: false additional_dependencies: - neptune-api==0.3.0 + - more-itertools default_language_version: python: python3 diff --git a/pyproject.toml b/pyproject.toml index 0ac2ca72..a9831192 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,8 @@ pattern = "default-unprefixed" [tool.poetry.dependencies] python = "^3.8" -# Networking neptune-api = "0.3.0" +more-itertools = "^10.0.0" [tool.poetry] name = "neptune-client-scale" diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index e88bc1a3..bc028634 100644 
--- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -22,8 +22,8 @@ WithResources, ) from neptune_scale.core.components.operations_queue import OperationsQueue -from neptune_scale.core.message_builder import MessageBuilder -from neptune_scale.core.proto_utils import ( +from neptune_scale.core.metadata_splitter import MetadataSplitter +from neptune_scale.core.serialization import ( datetime_to_proto, make_step, ) @@ -242,7 +242,7 @@ def log( verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set)) verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set)) - for operation in MessageBuilder( + splitter: MetadataSplitter = MetadataSplitter( project=self._project, run_id=self._run_id, step=step, @@ -251,7 +251,9 @@ def log( metrics=metrics, add_tags=add_tags, remove_tags=remove_tags, - ): + ) + + for operation in splitter: self._backend.submit(operation=operation) # TODO: Enqueue on the operations queue # self._operations_queue.enqueue(operation=operation) diff --git a/src/neptune_scale/core/message_builder.py b/src/neptune_scale/core/message_builder.py deleted file mode 100644 index 5bb6f05c..00000000 --- a/src/neptune_scale/core/message_builder.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -__all__ = ("MessageBuilder",) - -from datetime import datetime - -from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( - SET_OPERATION, - ModifySet, - StringSet, - UpdateRunSnapshot, - Value, -) -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation - -from neptune_scale.core.proto_utils import ( - datetime_to_proto, - make_step, -) - - -class MessageBuilder: - def __init__( - self, - *, - project: str, - run_id: str, - step: int | float | None, - timestamp: datetime, - fields: dict[str, float | bool | int | str | datetime | list | set], - metrics: dict[str, float], - add_tags: dict[str, list[str] | set[str]], - remove_tags: dict[str, list[str] | set[str]], - ): - self._step = None if step is None else make_step(number=step) - self._timestamp = datetime_to_proto(timestamp) - self._fields = fields - self._metrics = metrics - self._add_tags = add_tags - self._remove_tags = remove_tags - self._project = project - self._run_id = run_id - - self._was_produced: bool = False - - def __iter__(self) -> MessageBuilder: - return self - - def __next__(self) -> RunOperation: - if not self._was_produced: - self._was_produced = True - - modify_sets = {key: mod_tags(add=add) for key, add in self._add_tags.items()} - modify_sets.update({key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()}) - update = UpdateRunSnapshot( - step=self._step, - timestamp=self._timestamp, - assign={key: make_value(value) for key, value in self._fields.items()}, - append={key: make_value(value) for key, value in self._metrics.items()}, - modify_sets=modify_sets, - ) - - return RunOperation(project=self._project, run_id=self._run_id, update=update) - - raise StopIteration - - -def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet: - mod_set = ModifySet() - if add is not None: - for tag in add: - mod_set.string.values[tag] = SET_OPERATION.ADD - if remove is not None: - for tag in remove: - mod_set.string.values[tag] = SET_OPERATION.REMOVE - return mod_set - - -def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value: - if isinstance(value, Value): - return value - if isinstance(value, float): - 
return Value(float64=value) - elif isinstance(value, bool): - return Value(bool=value) - elif isinstance(value, int): - return Value(int64=value) - elif isinstance(value, str): - return Value(string=value) - elif isinstance(value, datetime): - return Value(timestamp=datetime_to_proto(value)) - elif isinstance(value, (list, set)): - fv = Value(string_set=StringSet(values=value)) - return fv - else: - raise ValueError(f"Unsupported ingest field value type: {type(value)}") diff --git a/src/neptune_scale/core/metadata_splitter.py b/src/neptune_scale/core/metadata_splitter.py new file mode 100644 index 00000000..1aba2656 --- /dev/null +++ b/src/neptune_scale/core/metadata_splitter.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +__all__ = ("MetadataSplitter",) + +from datetime import datetime +from typing import ( + Any, + Callable, + Iterator, + TypeVar, +) + +from more_itertools import peekable +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.serialization import ( + datetime_to_proto, + make_step, + make_value, + pb_key_size, +) + +T = TypeVar("T", bound=Any) + + +class MetadataSplitter(Iterator[RunOperation]): + def __init__( + self, + *, + project: str, + run_id: str, + step: int | float | None, + timestamp: datetime, + fields: dict[str, float | bool | int | str | datetime | list | set], + metrics: dict[str, float], + add_tags: dict[str, list[str] | set[str]], + remove_tags: dict[str, list[str] | set[str]], + max_message_bytes_size: int = 1024 * 1024, + ): + self._step = None if step is None else make_step(number=step) + self._timestamp = datetime_to_proto(timestamp) + self._project = project + self._run_id = run_id + self._fields = peekable(fields.items()) + self._metrics = peekable(metrics.items()) + self._add_tags = peekable(add_tags.items()) + self._remove_tags = peekable(remove_tags.items()) + + self._max_update_bytes_size = ( + max_message_bytes_size + - RunOperation( + project=self._project, + run_id=self._run_id, + update=UpdateRunSnapshot(step=self._step, timestamp=self._timestamp), + ).ByteSize() + ) + + self._has_returned = False + + def __iter__(self) -> MetadataSplitter: + self._has_returned = False + return self + + def __next__(self) -> RunOperation: + size = 0 + update = UpdateRunSnapshot( + step=self._step, + timestamp=self._timestamp, + assign={}, + append={}, + modify_sets={}, + ) + + size = self.populate( + assets=self._fields, + update_producer=lambda key, value: update.assign[key].MergeFrom(value), + size=size, + ) + size = self.populate( + assets=self._metrics, + update_producer=lambda key, value: update.append[key].MergeFrom(value), + size=size, + ) + size = self.populate_tags( + update=update, + assets=self._add_tags, + operation=SET_OPERATION.ADD, + size=size, + ) + _ = self.populate_tags( + update=update, + assets=self._remove_tags, + operation=SET_OPERATION.REMOVE, + size=size, + ) + + if not self._has_returned or update.assign or update.append or update.modify_sets: + self._has_returned = True + return RunOperation(project=self._project, run_id=self._run_id, update=update) + else: + raise StopIteration + + def populate( + self, + assets: peekable[Any], + update_producer: Callable[[str, Value], None], + size: int, + ) -> int: + while size < self._max_update_bytes_size: + try: + key, value = assets.peek() + except StopIteration: + break + + proto_value = make_value(value) + new_size 
= size + pb_key_size(key) + proto_value.ByteSize() + 6 + + if new_size > self._max_update_bytes_size: + break + + update_producer(key, proto_value) + size, _ = new_size, next(assets) + + return size + + def populate_tags( + self, update: UpdateRunSnapshot, assets: peekable[Any], operation: SET_OPERATION.ValueType, size: int + ) -> int: + while size < self._max_update_bytes_size: + try: + key, values = assets.peek() + except StopIteration: + break + + if not isinstance(values, peekable): + values = peekable(values) + + is_full = False + new_size = size + pb_key_size(key) + 6 + for value in values: + tag_size = pb_key_size(value) + 6 + if new_size + tag_size > self._max_update_bytes_size: + values.prepend(value) + is_full = True + break + + update.modify_sets[key].string.values[value] = operation + new_size += tag_size + + size, _ = new_size, next(assets) + + if is_full: + assets.prepend((key, list(values))) + break + + return size diff --git a/src/neptune_scale/core/proto_utils.py b/src/neptune_scale/core/serialization.py similarity index 50% rename from src/neptune_scale/core/proto_utils.py rename to src/neptune_scale/core/serialization.py index 5fa72d5a..5a3b1f23 100644 --- a/src/neptune_scale/core/proto_utils.py +++ b/src/neptune_scale/core/serialization.py @@ -1,11 +1,40 @@ from __future__ import annotations -__all__ = ("datetime_to_proto", "make_step") +__all__ = ( + "make_value", + "make_step", + "datetime_to_proto", + "pb_key_size", +) from datetime import datetime from google.protobuf.timestamp_pb2 import Timestamp -from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Step +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + Step, + StringSet, + Value, +) + + +def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value: + if isinstance(value, Value): + return value + if isinstance(value, float): + return Value(float64=value) + elif isinstance(value, bool): + return Value(bool=value) + elif isinstance(value, int): + return Value(int64=value) + elif isinstance(value, str): + return Value(string=value) + elif isinstance(value, datetime): + return Value(timestamp=datetime_to_proto(value)) + elif isinstance(value, (list, set)): + fv = Value(string_set=StringSet(values=value)) + return fv + else: + raise ValueError(f"Unsupported ingest field value type: {type(value)}") def datetime_to_proto(dt: datetime) -> Timestamp: @@ -33,3 +62,8 @@ def make_step(number: float | int, raise_on_step_precision_loss: bool = False) - micro = micro % m return Step(whole=whole, micro=micro) + + +def pb_key_size(key: str) -> int: + key_bin = bytes(key, "utf-8") + return len(key_bin) + 2 + (1 if len(key_bin) > 127 else 0) diff --git a/tests/unit/test_message_builder.py b/tests/unit/test_metadata_splitter.py similarity index 56% rename from tests/unit/test_message_builder.py rename to tests/unit/test_metadata_splitter.py index 5d78ed6d..ba8c4e37 100644 --- a/tests/unit/test_message_builder.py +++ b/tests/unit/test_metadata_splitter.py @@ -13,13 +13,13 @@ ) from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation -from neptune_scale.core.message_builder import MessageBuilder +from neptune_scale.core.metadata_splitter import MetadataSplitter @freeze_time("2024-07-30 12:12:12.000022") def test_empty(): # given - builder = MessageBuilder( + builder = MetadataSplitter( project="workspace/project", run_id="run_id", step=1, @@ -45,7 +45,7 @@ def test_empty(): @freeze_time("2024-07-30 12:12:12.000022") def test_fields(): # given - 
builder = MessageBuilder( + builder = MetadataSplitter( project="workspace/project", run_id="run_id", step=1, @@ -89,7 +89,7 @@ def test_fields(): @freeze_time("2024-07-30 12:12:12.000022") def test_metrics(): # given - builder = MessageBuilder( + builder = MetadataSplitter( project="workspace/project", run_id="run_id", step=1, @@ -123,7 +123,7 @@ def test_metrics(): @freeze_time("2024-07-30 12:12:12.000022") def test_tags(): # given - builder = MessageBuilder( + builder = MetadataSplitter( project="workspace/project", run_id="run_id", step=1, @@ -167,3 +167,101 @@ def test_tags(): }, ), ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_splitting(): + # given + max_size = 1024 + timestamp = datetime.now() + metrics = {f"metric{v}": 7 / 9.0 * v for v in range(1000)} + fields = {f"field{v}": v for v in range(1000)} + add_tags = {f"add/tag{v}": {f"value{v}"} for v in range(1000)} + remove_tags = {f"remove/tag{v}": {f"value{v}"} for v in range(1000)} + + # and + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + max_message_bytes_size=max_size, + ) + + # when + result = list(builder) + + # then + assert len(result) > 0 + + # Every message should be smaller than max_size + assert all(len(op.SerializeToString()) <= max_size for op in result) + + # Common metadata + assert all(op.project == "workspace/project" for op in result) + assert all(op.run_id == "run_id" for op in result) + assert all(op.update.step.whole == 1 for op in result) + assert all(op.update.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op in result) + + # Check if all metrics, fields and tags are present in the result + assert sorted([key for op in result for key in op.update.append.keys()]) == sorted(list(metrics.keys())) + assert sorted([key for op in result for key in op.update.assign.keys()]) == sorted(list(fields.keys())) + assert sorted([key for op in result for key in op.update.modify_sets.keys()]) == sorted( + list(add_tags.keys()) + list(remove_tags.keys()) + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_split_large_tags(): + # given + max_size = 1024 + timestamp = datetime.now() + metrics = {} + fields = {} + add_tags = {"add/tag": {f"value{v}" for v in range(1000)}} + remove_tags = {"remove/tag": {f"value{v}" for v in range(1000)}} + + # and + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + max_message_bytes_size=max_size, + ) + + # when + result = list(builder) + + # then + assert len(result) > 0 + + # Every message should be smaller than max_size + assert all(len(op.SerializeToString()) <= max_size for op in result) + + # Common metadata + assert all(op.project == "workspace/project" for op in result) + assert all(op.run_id == "run_id" for op in result) + assert all(op.update.step.whole == 1 for op in result) + assert all(op.update.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op in result) + + # Check if all StringSet values are split correctly + assert set([key for op in result for key in op.update.modify_sets.keys()]) == set( + list(add_tags.keys()) + list(remove_tags.keys()) + ) + + # Check if all tags are present in the result + assert {tag for op in result for tag in op.update.modify_sets["add/tag"].string.values.keys()} == add_tags[ + "add/tag" + ] + assert {tag for op 
in result for tag in op.update.modify_sets["remove/tag"].string.values.keys()} == remove_tags[ + "remove/tag" + ] From d8098f9b183d67ffb9fa51eb9a2e3e5847505ca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Mon, 12 Aug 2024 11:18:35 +0200 Subject: [PATCH 07/10] Added ErrorsMonitor and ErrorsQueue (#13) --- src/neptune_scale/__init__.py | 13 ++- src/neptune_scale/core/components/daemon.py | 84 +++++++++++++++++++ .../core/components/errors_monitor.py | 46 ++++++++++ .../core/components/errors_queue.py | 24 ++++++ tests/unit/test_errors_monitor.py | 22 +++++ 5 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 src/neptune_scale/core/components/daemon.py create mode 100644 src/neptune_scale/core/components/errors_monitor.py create mode 100644 src/neptune_scale/core/components/errors_queue.py create mode 100644 tests/unit/test_errors_monitor.py diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index bc028634..ed80271f 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -21,6 +21,8 @@ Resource, WithResources, ) +from neptune_scale.core.components.errors_monitor import ErrorsMonitor +from neptune_scale.core.components.errors_queue import ErrorsQueue from neptune_scale.core.components.operations_queue import OperationsQueue from neptune_scale.core.metadata_splitter import MetadataSplitter from neptune_scale.core.serialization import ( @@ -140,8 +142,12 @@ def __init__( self._operations_queue: OperationsQueue = OperationsQueue( lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback ) + self._errors_queue: ErrorsQueue = ErrorsQueue() + self._errors_monitor = ErrorsMonitor(errors_queue=self._errors_queue) self._backend: ApiClient = ApiClient(api_token=input_api_token) + self._errors_monitor.start() + if not resume: self._create_run( creation_time=datetime.now() if creation_time is None else creation_time, @@ -155,7 +161,12 @@ def __enter__(self) -> Run: @property def resources(self) -> tuple[Resource, ...]: - return self._operations_queue, self._backend + return ( + self._operations_queue, + self._backend, + self._errors_monitor, + self._errors_queue, + ) def close(self) -> None: """ diff --git a/src/neptune_scale/core/components/daemon.py b/src/neptune_scale/core/components/daemon.py new file mode 100644 index 00000000..d2ef8713 --- /dev/null +++ b/src/neptune_scale/core/components/daemon.py @@ -0,0 +1,84 @@ +__all__ = ["Daemon"] + +import abc +import threading +from enum import Enum + + +class Daemon(threading.Thread): + class DaemonState(Enum): + INIT = 1 + WORKING = 2 + PAUSING = 3 + PAUSED = 4 + INTERRUPTED = 5 + STOPPED = 6 + + def __init__(self, sleep_time: float, name: str) -> None: + super().__init__(daemon=True, name=name) + self._sleep_time = sleep_time + self._state: Daemon.DaemonState = Daemon.DaemonState.INIT + self._wait_condition = threading.Condition() + + def interrupt(self) -> None: + with self._wait_condition: + self._state = Daemon.DaemonState.INTERRUPTED + self._wait_condition.notify_all() + + def pause(self) -> None: + with self._wait_condition: + if self._state != Daemon.DaemonState.PAUSED: + if not self._is_interrupted(): + self._state = Daemon.DaemonState.PAUSING + self._wait_condition.notify_all() + self._wait_condition.wait_for(lambda: self._state != Daemon.DaemonState.PAUSING) + + def resume(self) -> None: + with self._wait_condition: + if not self._is_interrupted(): + self._state = Daemon.DaemonState.WORKING + 
self._wait_condition.notify_all() + + def wake_up(self) -> None: + with self._wait_condition: + self._wait_condition.notify_all() + + def disable_sleep(self) -> None: + self._sleep_time = 0 + + def is_running(self) -> bool: + with self._wait_condition: + return self._state in ( + Daemon.DaemonState.WORKING, + Daemon.DaemonState.PAUSING, + Daemon.DaemonState.PAUSED, + ) + + def _is_interrupted(self) -> bool: + with self._wait_condition: + return self._state in (Daemon.DaemonState.INTERRUPTED, Daemon.DaemonState.STOPPED) + + def run(self) -> None: + with self._wait_condition: + if not self._is_interrupted(): + self._state = Daemon.DaemonState.WORKING + try: + while not self._is_interrupted(): + with self._wait_condition: + if self._state == Daemon.DaemonState.PAUSING: + self._state = Daemon.DaemonState.PAUSED + self._wait_condition.notify_all() + self._wait_condition.wait_for(lambda: self._state != Daemon.DaemonState.PAUSED) + + if self._state == Daemon.DaemonState.WORKING: + self.work() + with self._wait_condition: + if self._sleep_time > 0 and self._state == Daemon.DaemonState.WORKING: + self._wait_condition.wait(timeout=self._sleep_time) + finally: + with self._wait_condition: + self._state = Daemon.DaemonState.STOPPED + self._wait_condition.notify_all() + + @abc.abstractmethod + def work(self) -> None: ... diff --git a/src/neptune_scale/core/components/errors_monitor.py b/src/neptune_scale/core/components/errors_monitor.py new file mode 100644 index 00000000..dc9950be --- /dev/null +++ b/src/neptune_scale/core/components/errors_monitor.py @@ -0,0 +1,46 @@ +__all__ = ("ErrorsMonitor",) + +import logging +import queue +from typing import Callable + +from neptune_scale.core.components.abstract import Resource +from neptune_scale.core.components.daemon import Daemon +from neptune_scale.core.components.errors_queue import ErrorsQueue + +logger = logging.getLogger("neptune") +logger.setLevel(level=logging.INFO) + + +def on_error(error: BaseException) -> None: + logger.error(error) + + +class ErrorsMonitor(Daemon, Resource): + def __init__( + self, + errors_queue: ErrorsQueue, + on_error_callback: Callable[[BaseException], None] = on_error, + ): + super().__init__(name="ErrorsMonitor", sleep_time=2) + self._errors_queue = errors_queue + self._on_error_callback = on_error_callback + + def work(self) -> None: + try: + error = self._errors_queue.get(block=False) + if error is not None: + self._on_error_callback(error) + except KeyboardInterrupt: + with self._wait_condition: + self._wait_condition.notify_all() + raise + except queue.Empty: + pass + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self.interrupt() + self.join(timeout=10) diff --git a/src/neptune_scale/core/components/errors_queue.py b/src/neptune_scale/core/components/errors_queue.py new file mode 100644 index 00000000..33bdc38e --- /dev/null +++ b/src/neptune_scale/core/components/errors_queue.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +__all__ = ("ErrorsQueue",) + +from multiprocessing import Queue + +from neptune_scale.core.components.abstract import Resource + + +class ErrorsQueue(Resource): + def __init__(self) -> None: + self._errors_queue: Queue[BaseException] = Queue() + + def put(self, error: BaseException) -> None: + self._errors_queue.put(error) + + def get(self, block: bool = True, timeout: float | None = None) -> BaseException: + return self._errors_queue.get(block=block, timeout=timeout) + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self._errors_queue.close() diff 
--git a/tests/unit/test_errors_monitor.py b/tests/unit/test_errors_monitor.py new file mode 100644 index 00000000..e4352d7e --- /dev/null +++ b/tests/unit/test_errors_monitor.py @@ -0,0 +1,22 @@ +from unittest.mock import Mock + +from neptune_scale.core.components.errors_monitor import ErrorsMonitor +from neptune_scale.core.components.errors_queue import ErrorsQueue + + +def test_errors_monitor(): + # given + callback = Mock() + + # and + errors_queue = ErrorsQueue() + errors_monitor = ErrorsMonitor(errors_queue=errors_queue, on_error_callback=callback) + + # when + errors_queue.put(ValueError("error1")) + errors_monitor.start() + errors_monitor.interrupt() + errors_monitor.join(timeout=1) + + # then + callback.assert_called() From e35876cb753408478cc17fee7e0d12b4100875c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Tue, 13 Aug 2024 10:39:21 +0200 Subject: [PATCH 08/10] Added support for family parameter (#14) --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- src/neptune_scale/__init__.py | 4 ++-- src/neptune_scale/api/api_client.py | 4 ++-- tests/unit/test_run.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 204078f2..4ac066f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: args: [ --config-file, pyproject.toml ] pass_filenames: false additional_dependencies: - - neptune-api==0.3.0 + - neptune-api==0.4.0 - more-itertools default_language_version: python: python3 diff --git a/pyproject.toml b/pyproject.toml index a9831192..f5fab8f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ pattern = "default-unprefixed" [tool.poetry.dependencies] python = "^3.8" -neptune-api = "0.3.0" +neptune-api = "0.4.0" more-itertools = "^10.0.0" [tool.poetry] diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index ed80271f..ec1a4127 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -197,7 +197,7 @@ def _create_run( creation_time=None if creation_time is None else datetime_to_proto(creation_time), ), ) - self._backend.submit(operation=operation) + self._backend.submit(operation=operation, family=self._family) # TODO: Enqueue on the operations queue # self._operations_queue.enqueue(operation=operation) @@ -265,6 +265,6 @@ def log( ) for operation in splitter: - self._backend.submit(operation=operation) + self._backend.submit(operation=operation, family=self._family) # TODO: Enqueue on the operations queue # self._operations_queue.enqueue(operation=operation) diff --git a/src/neptune_scale/api/api_client.py b/src/neptune_scale/api/api_client.py index b8b68369..80a15d31 100644 --- a/src/neptune_scale/api/api_client.py +++ b/src/neptune_scale/api/api_client.py @@ -43,8 +43,8 @@ def __init__(self, api_token: str) -> None: config, token_urls = get_config_and_token_urls(credentials=credentials) self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls) - def submit(self, operation: RunOperation) -> None: - _ = submit_operation.sync(client=self._backend, body=operation) + def submit(self, operation: RunOperation, family: str) -> None: + _ = submit_operation.sync(client=self._backend, family=family, body=operation) def cleanup(self) -> None: pass diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index 799424e1..a3cb8dc9 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -19,7 +19,7 @@ class MockedApiClient: def 
__init__(self, *args, **kwargs) -> None:
         pass
 
-    def submit(self, operation) -> None:
+    def submit(self, operation, family) -> None:
         pass
 
     def close(self) -> None:

From cceddab343db7f417d0823c223c750e2beb3b4e1 Mon Sep 17 00:00:00 2001
From: Rafal Jankowski
Date: Wed, 21 Aug 2024 09:23:44 +0200
Subject: [PATCH 09/10] Code review

---
 src/neptune_scale/__init__.py        | 3 +++
 tests/unit/test_metadata_splitter.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py
index ec1a4127..2356e5ac 100644
--- a/src/neptune_scale/__init__.py
+++ b/src/neptune_scale/__init__.py
@@ -112,6 +112,9 @@ def __init__(
         if resume and from_step is not None:
             raise ValueError("`resume` and `from_step` cannot be used together.")
 
+        if max_queue_size < 1:
+            raise ValueError("`max_queue_size` must be greater than 0.")
+
         project = project or os.environ.get(PROJECT_ENV_NAME)
         verify_non_empty("project", project)
         assert project is not None  # mypy
diff --git a/tests/unit/test_metadata_splitter.py b/tests/unit/test_metadata_splitter.py
index ba8c4e37..4d842506 100644
--- a/tests/unit/test_metadata_splitter.py
+++ b/tests/unit/test_metadata_splitter.py
@@ -196,7 +196,7 @@ def test_splitting():
     result = list(builder)
 
     # then
-    assert len(result) > 0
+    assert len(result) > 1
 
     # Every message should be smaller than max_size
     assert all(len(op.SerializeToString()) <= max_size for op in result)
@@ -242,7 +242,7 @@ def test_split_large_tags():
     result = list(builder)
 
     # then
-    assert len(result) > 0
+    assert len(result) > 1
 
     # Every message should be smaller than max_size
     assert all(len(op.SerializeToString()) <= max_size for op in result)

From 04a7dcef7f4a81d8f01d21640d9642ff351edb50 Mon Sep 17 00:00:00 2001
From: Rafal Jankowski
Date: Wed, 21 Aug 2024 09:35:01 +0200
Subject: [PATCH 10/10] Code review 2

---
 src/neptune_scale/core/serialization.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/neptune_scale/core/serialization.py b/src/neptune_scale/core/serialization.py
index 5a3b1f23..0858d8bc 100644
--- a/src/neptune_scale/core/serialization.py
+++ b/src/neptune_scale/core/serialization.py
@@ -65,5 +65,9 @@ def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -
 
 
 def pb_key_size(key: str) -> int:
+    """
+    Calculates the size of the string in the protobuf message, including the overhead
+    of the varint length prefix, under an assumption about the maximal string length.
+    """
     key_bin = bytes(key, "utf-8")
     return len(key_bin) + 2 + (1 if len(key_bin) > 127 else 0)
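
Usage notes

The sketches below are illustrative only: they use the public names introduced by the patches in this series, and any literal values (project names, tokens, size budgets) are placeholders, not defaults shipped by the library.

PATCH 05 makes `project` and `api_token` optional by falling back to the `NEPTUNE_PROJECT` and `NEPTUNE_API_TOKEN` environment variables. A minimal sketch of the two equivalent call styles, assuming the `Run.log` keywords (`step`, `metrics`) introduced earlier in the series:

import os

from neptune_scale import Run

# Explicit arguments:
with Run(
    project="workspace/project",  # placeholder project name
    api_token="YOUR_API_TOKEN",   # placeholder token
    family="run-1",
    run_id="run-1",
) as run:
    run.log(step=1, metrics={"loss": 0.25})

# Equivalent after PATCH 05 -- both values resolved from the environment:
os.environ["NEPTUNE_PROJECT"] = "workspace/project"
os.environ["NEPTUNE_API_TOKEN"] = "YOUR_API_TOKEN"

with Run(family="run-2", run_id="run-2") as run:
    run.log(step=1, metrics={"loss": 0.25})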
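PATCH 06 turns a single oversized `log` call into a stream of `RunOperation` messages, each fitting within `max_message_bytes_size`. A sketch mirroring `test_splitting` (the 1 KiB budget is arbitrary, chosen only to force several messages):

from datetime import datetime

from neptune_scale.core.metadata_splitter import MetadataSplitter

splitter = MetadataSplitter(
    project="workspace/project",
    run_id="run_id",
    step=1,
    timestamp=datetime.now(),
    fields={f"field{i}": i for i in range(1000)},
    metrics={f"metric{i}": float(i) for i in range(1000)},
    add_tags={},
    remove_tags={},
    max_message_bytes_size=1024,
)

operations = list(splitter)

# The payload is spread over many messages, each within the budget,
# and every key ends up in exactly one of them.
assert len(operations) > 1
assert all(len(op.SerializeToString()) <= 1024 for op in operations)

Because the splitter is an iterator over `peekable` views of the input dictionaries, `Run.log` can submit operations one at a time without ever materializing the whole batch.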
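PATCH 07's `ErrorsMonitor` is a `Daemon` thread that drains an `ErrorsQueue` and hands each error to a callback (the default merely logs it). A sketch along the lines of `test_errors_monitor`; the callback body and the sleep are illustrative:

import time

from neptune_scale.core.components.errors_monitor import ErrorsMonitor
from neptune_scale.core.components.errors_queue import ErrorsQueue


def handle_error(error: BaseException) -> None:
    # Stand-in for a real handler; the default callback logs the error.
    print(f"captured: {error!r}")


errors_queue = ErrorsQueue()
monitor = ErrorsMonitor(errors_queue=errors_queue, on_error_callback=handle_error)

errors_queue.put(ValueError("error1"))
monitor.start()

time.sleep(0.1)      # give the worker loop a chance to drain the queue
monitor.close()      # interrupts the daemon and joins it (10 s timeout)
errors_queue.close()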
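Finally, the arithmetic in `pb_key_size` (PATCH 10) follows the protobuf wire format: the fixed `+ 2` presumably covers the field tag plus a one-byte varint length prefix, and one extra byte is added once the UTF-8 payload exceeds 127 bytes, because the length varint then needs two bytes. A quick sanity check:

from neptune_scale.core.serialization import pb_key_size

assert pb_key_size("a") == 1 + 2              # 1 payload byte + 2 bytes of overhead
assert pb_key_size("a" * 127) == 127 + 2      # still a one-byte length prefix
assert pb_key_size("a" * 128) == 128 + 2 + 1  # length 128 needs a two-byte varint
assert pb_key_size("zażółć") == len("zażółć".encode("utf-8")) + 2  # sizes count UTF-8 bytes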