From 8013147c68f9c7ca70e540bfcd79a58a3bcfa0fb Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 30 Jul 2024 12:26:38 +0200 Subject: [PATCH] Minimal working implementation --- dev_requirements.txt | 1 + src/neptune_scale/__init__.py | 60 +++++++ src/neptune_scale/api/__init__.py | 0 src/neptune_scale/api/helpers.py | 97 +++++++++++ .../core/components/operations_queue.py | 1 + src/neptune_scale/core/message_builder.py | 62 +++++++ src/neptune_scale/core/validation.py | 19 ++- tests/unit/test_message_builder.py | 157 ++++++++++++++++++ tests/unit/test_run.py | 13 ++ 9 files changed, 404 insertions(+), 6 deletions(-) create mode 100644 src/neptune_scale/api/__init__.py create mode 100644 src/neptune_scale/api/helpers.py create mode 100644 src/neptune_scale/core/message_builder.py create mode 100644 tests/unit/test_message_builder.py diff --git a/dev_requirements.txt b/dev_requirements.txt index 7b8b416f..73ee4cb9 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,3 +4,4 @@ pre-commit pytest pytest-timeout +freezegun diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 9ceb82e2..9633318e 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -8,14 +8,19 @@ import threading from contextlib import AbstractContextManager +from datetime import datetime from typing import Callable +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + from neptune_scale.core.components.abstract import ( Resource, WithResources, ) from neptune_scale.core.components.operations_queue import OperationsQueue +from neptune_scale.core.message_builder import MessageBuilder from neptune_scale.core.validation import ( + verify_collection_type, verify_max_length, verify_non_empty, verify_project_qualified_name, @@ -90,6 +95,61 @@ def __enter__(self) -> Run: def resources(self) -> tuple[Resource, ...]: return (self._operations_queue,) + def _prepare_common_message(self) -> RunOperation: + return RunOperation( + project=self._project, + run_id=self._run_id, + create_missing_project=False, + api_key=b"", + ) + + def log( + self, + step: float | int | None = None, + timestamp: datetime | None = None, + fields: dict[str, float | int | str] | None = None, + metrics: dict[str, float] | None = None, + add_tags: dict[str, list[str] | set[str]] | None = None, + remove_tags: dict[str, list[str] | set[str]] | None = None, + ) -> None: + """ + TODO: Add description + """ + verify_type("step", step, (float, int, type(None))) + verify_type("timestamp", timestamp, (datetime, type(None))) + verify_type("fields", fields, (dict, type(None))) + verify_type("metrics", metrics, (dict, type(None))) + verify_type("add_tags", add_tags, (dict, type(None))) + verify_type("remove_tags", remove_tags, (dict, type(None))) + + timestamp = datetime.now() if timestamp is None else timestamp + fields = {} if fields is None else fields + metrics = {} if metrics is None else metrics + add_tags = {} if add_tags is None else add_tags + remove_tags = {} if remove_tags is None else remove_tags + + verify_collection_type("`fields` keys", list(fields.keys()), str) + verify_collection_type("`metrics` keys", list(metrics.keys()), str) + verify_collection_type("`add_tags` keys", list(add_tags.keys()), str) + verify_collection_type("`remove_tags` keys", list(remove_tags.keys()), str) + + # TODO: More types for values + verify_collection_type("`fields` values", list(fields.values()), (float, int, str)) + verify_collection_type("`metrics` values", list(metrics.values()), float) + verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set)) + verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set)) + + for operation in MessageBuilder( + common_message=self._prepare_common_message(), + step=step, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + ): + self._operations_queue.enqueue(operation=operation) + def close(self) -> None: """ Stops the connection to Neptune and synchronizes all data. diff --git a/src/neptune_scale/api/__init__.py b/src/neptune_scale/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/api/helpers.py b/src/neptune_scale/api/helpers.py new file mode 100644 index 00000000..4ca93ab4 --- /dev/null +++ b/src/neptune_scale/api/helpers.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +__all__ = ("datetime_to_proto", "make_step", "mod_tags") + +from datetime import datetime + +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + ModifySet, + Step, + StringSet, + Value, +) + + +def datetime_to_proto(dt: datetime) -> Timestamp: + """ + Converts datetime object to google.protobuf.Timestamp used in Neptune Ingest API + Args: + dt: datetime - timezone is irrelevant and will be lost to unix time + Returns: protobuf Timestamp + """ + dt_ts = dt.timestamp() + return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9)) + + +def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step: + """ + Converts a number to protobuf Step value. Example: + >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321) + Args: + number: step expressed as number + raise_on_step_precision_loss: inform converter whether it should silently drop precision and + round down to 6 decimal places or raise an error. + + Returns: Step protobuf used in Neptune Ingest API. + """ + m = int(1e6) + micro: int = int(number * m) + if raise_on_step_precision_loss and number * m - micro != 0: + raise ValueError(f"step must not use more than 6-decimal points, got: {number}") + + whole = micro // m + micro = micro % m + + return Step(whole=whole, micro=micro) + + +def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value: + """ + Converts python value to Ingest API Value protobuf message. + Make sure not to mix float and int types in the same field, especially literals - this may + cause data points to be dropped. + """ + if isinstance(value, Value): + return value + if isinstance(value, float): + return Value(float64=value) + elif isinstance(value, bool): + return Value(bool=value) + elif isinstance(value, int): + return Value(int64=value) + elif isinstance(value, str): + return Value(string=value) + elif isinstance(value, datetime): + return Value(timestamp=datetime_to_proto(value)) + elif isinstance(value, (list, set)): + fv = Value(string_set=StringSet(values=value)) + return fv + else: + raise ValueError(f"Unsupported ingest field value type: {type(value)}") + + +def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet: + """ + Shorthand to apply string set modifications. Example: + >>> from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import * + ... + ... update = UpdateRunSnapshot( + ... modify_sets={ + ... "sys/tags": mod_tags(add=["tag1", "tag2"], remove=["tag0"]) + ... }) + Args: + add: list of tags to be added + remove: list tags to be removed + + Returns: ModifySet with changes + """ + mod_set = ModifySet() + if add is not None: + for tag in add: + mod_set.string.values[tag] = SET_OPERATION.ADD + if remove is not None: + for tag in remove: + mod_set.string.values[tag] = SET_OPERATION.REMOVE + return mod_set diff --git a/src/neptune_scale/core/components/operations_queue.py b/src/neptune_scale/core/components/operations_queue.py index a76f0f15..bed30265 100644 --- a/src/neptune_scale/core/components/operations_queue.py +++ b/src/neptune_scale/core/components/operations_queue.py @@ -51,6 +51,7 @@ def __init__( def enqueue(self, *, operation: RunOperation) -> None: try: + # TODO: This lock could be moved to the Run class with self._lock: serialized_operation = operation.SerializeToString() diff --git a/src/neptune_scale/core/message_builder.py b/src/neptune_scale/core/message_builder.py new file mode 100644 index 00000000..2dded63e --- /dev/null +++ b/src/neptune_scale/core/message_builder.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +__all__ = ("MessageBuilder",) + +from datetime import datetime + +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import UpdateRunSnapshot +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.api.helpers import ( + datetime_to_proto, + make_step, + make_value, + mod_tags, +) + + +class MessageBuilder: + def __init__( + self, + *, + common_message: RunOperation, + step: int | float | None, + timestamp: datetime, + fields: dict[str, float | int | str], + metrics: dict[str, float], + add_tags: dict[str, list[str] | set[str]], + remove_tags: dict[str, list[str] | set[str]], + ): + # TODO: Warning instead of raise on step precision loss + self._step = None if step is None else make_step(number=step, raise_on_step_precision_loss=True) + self._timestamp = datetime_to_proto(timestamp) + self._fields = fields + self._metrics = metrics + self._add_tags = add_tags + self._remove_tags = remove_tags + self._common_message = common_message + + self._ith: int = 0 + + def __iter__(self) -> MessageBuilder: + return self + + def __next__(self) -> RunOperation: + if self._ith == 0: + self._ith += 1 + + operation = self._common_message + update = UpdateRunSnapshot( + step=self._step, + timestamp=self._timestamp, + assign={key: make_value(value) for key, value in self._fields.items()}, + append={key: make_value(value) for key, value in self._metrics.items()}, + modify_sets={key: mod_tags(add=add) for key, add in self._add_tags.items()}, + ) + remove_tags = UpdateRunSnapshot( + modify_sets={key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()} + ) + update.MergeFrom(remove_tags) + operation.MergeFrom(RunOperation(update=update)) + return operation + raise StopIteration diff --git a/src/neptune_scale/core/validation.py b/src/neptune_scale/core/validation.py index efd8716a..95a732e1 100644 --- a/src/neptune_scale/core/validation.py +++ b/src/neptune_scale/core/validation.py @@ -1,21 +1,21 @@ +from __future__ import annotations + __all__ = ( "verify_type", "verify_non_empty", "verify_max_length", "verify_project_qualified_name", + "verify_collection_type", ) -from typing import ( - Any, - Union, -) +from typing import Any -def get_type_name(var_type: Union[type, tuple]) -> str: +def get_type_name(var_type: type | tuple) -> str: return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type) -def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None: +def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None: try: if isinstance(expected_type, tuple): type_name = " or ".join(get_type_name(t) for t in expected_type) @@ -46,3 +46,10 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None: project_parts = var.split("/") if len(project_parts) != 2: raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name") + + +def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None: + verify_type(var_name, var, (list, set, tuple)) + + for value in var: + verify_type(f"elements of collection '{var_name}'", value, expected_type) diff --git a/tests/unit/test_message_builder.py b/tests/unit/test_message_builder.py new file mode 100644 index 00000000..6e0f43bd --- /dev/null +++ b/tests/unit/test_message_builder.py @@ -0,0 +1,157 @@ +from datetime import datetime + +from freezegun import freeze_time +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + ModifySet, + ModifyStringSet, + Step, + StringSet, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.message_builder import MessageBuilder + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_empty(): + # given + builder = MessageBuilder( + common_message=RunOperation(), + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + update=UpdateRunSnapshot(step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934)) + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_fields(): + # given + builder = MessageBuilder( + common_message=RunOperation(), + step=1, + timestamp=datetime.now(), + fields={ + "some/string": "value", + "some/int": 2501, + "some/float": 3.14, + "some/bool": True, + "some/datetime": datetime.now(), + "some/tags": {"tag1", "tag2"}, + }, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + assign={ + "some/string": Value(string="value"), + "some/int": Value(int64=2501), + "some/float": Value(float64=3.14), + "some/bool": Value(bool=True), + "some/datetime": Value(timestamp=Timestamp(seconds=1722341532, nanos=21934)), + "some/tags": Value(string_set=StringSet(values={"tag1", "tag2"})), + }, + ) + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_metrics(): + # given + builder = MessageBuilder( + common_message=RunOperation(), + step=1, + timestamp=datetime.now(), + fields={}, + metrics={ + "some/metric": 3.14, + }, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + append={ + "some/metric": Value(float64=3.14), + }, + ) + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_tags(): + # given + builder = MessageBuilder( + common_message=RunOperation(), + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={ + "some/tags": {"tag1", "tag2"}, + "some/other_tags2": {"tag2", "tag3"}, + }, + remove_tags={ + "some/group_tags": {"tag0", "tag1"}, + "some/other_tags": {"tag2", "tag3"}, + }, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + modify_sets={ + "some/tags": ModifySet( + string=ModifyStringSet(values={"tag1": SET_OPERATION.ADD, "tag2": SET_OPERATION.ADD}) + ), + "some/other_tags2": ModifySet( + string=ModifyStringSet(values={"tag2": SET_OPERATION.ADD, "tag3": SET_OPERATION.ADD}) + ), + "some/group_tags": ModifySet( + string=ModifyStringSet(values={"tag0": SET_OPERATION.REMOVE, "tag1": SET_OPERATION.REMOVE}) + ), + "some/other_tags": ModifySet( + string=ModifyStringSet(values={"tag2": SET_OPERATION.REMOVE, "tag3": SET_OPERATION.REMOVE}) + ), + }, + ) + ) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index f7d8310c..780a5ac3 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime import pytest @@ -80,3 +81,15 @@ def test_invalid_project_name(): with pytest.raises(ValueError): with Run(project=project, api_token=api_token, family=family, run_id=run_id): ... + + +def test__metadata(): + # given + project = "workspace/project" + api_token = "API_TOKEN" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log(step=1, timestamp=datetime.now(), fields={"test": 1})