diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e600d17..dc42d3d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,3 +10,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6))
 - Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7))
+- Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8))
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 7b8b416f..73ee4cb9 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -4,3 +4,4 @@
 pre-commit
 pytest
 pytest-timeout
+freezegun
diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py
index 9ceb82e2..3e5551d7 100644
--- a/src/neptune_scale/__init__.py
+++ b/src/neptune_scale/__init__.py
@@ -8,6 +8,7 @@
 import threading
 from contextlib import AbstractContextManager
+from datetime import datetime
 from typing import Callable
 
 from neptune_scale.core.components.abstract import (
@@ -15,7 +16,9 @@
     WithResources,
 )
 from neptune_scale.core.components.operations_queue import OperationsQueue
+from neptune_scale.core.message_builder import MessageBuilder
 from neptune_scale.core.validation import (
+    verify_collection_type,
     verify_max_length,
     verify_non_empty,
     verify_project_qualified_name,
@@ -95,3 +98,67 @@ def close(self) -> None:
         """
         Stops the connection to Neptune and synchronizes all data.
         """
         super().close()
+
+    def log(
+        self,
+        step: float | int | None = None,
+        timestamp: datetime | None = None,
+        fields: dict[str, float | bool | int | str | datetime | list | set] | None = None,
+        metrics: dict[str, float] | None = None,
+        add_tags: dict[str, list[str] | set[str]] | None = None,
+        remove_tags: dict[str, list[str] | set[str]] | None = None,
+    ) -> None:
+        """
+        Logs the specified metadata to Neptune.
+
+        Args:
+            step: Index of the log entry; must be increasing. If None, the highest of the already logged indexes is used.
+            timestamp: Time of logging the metadata.
+            fields: Dictionary of fields to log.
+            metrics: Dictionary of metrics to log.
+            add_tags: Dictionary of tags to add to the run.
+            remove_tags: Dictionary of tags to remove from the run.
+
+        Examples:
+            ```
+            >>> with Run(...) as run:
+            ...     run.log(step=1, fields={"parameters/learning_rate": 0.001})
+            ...     run.log(step=2, add_tags={"sys/group_tags": ["group1", "group2"]})
+            ...     run.log(step=3, metrics={"metrics/loss": 0.1})
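+            ...     run.log(step=4, remove_tags={"sys/group_tags": ["group2"]})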
+            ```
+
+        """
+        verify_type("step", step, (float, int, type(None)))
+        verify_type("timestamp", timestamp, (datetime, type(None)))
+        verify_type("fields", fields, (dict, type(None)))
+        verify_type("metrics", metrics, (dict, type(None)))
+        verify_type("add_tags", add_tags, (dict, type(None)))
+        verify_type("remove_tags", remove_tags, (dict, type(None)))
+
+        timestamp = datetime.now() if timestamp is None else timestamp
+        fields = {} if fields is None else fields
+        metrics = {} if metrics is None else metrics
+        add_tags = {} if add_tags is None else add_tags
+        remove_tags = {} if remove_tags is None else remove_tags
+
+        verify_collection_type("`fields` keys", list(fields.keys()), str)
+        verify_collection_type("`metrics` keys", list(metrics.keys()), str)
+        verify_collection_type("`add_tags` keys", list(add_tags.keys()), str)
+        verify_collection_type("`remove_tags` keys", list(remove_tags.keys()), str)
+
+        verify_collection_type("`fields` values", list(fields.values()), (float, bool, int, str, datetime, list, set))
+        verify_collection_type("`metrics` values", list(metrics.values()), float)
+        verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set))
+        verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set))
+
+        for operation in MessageBuilder(
+            project=self._project,
+            run_id=self._run_id,
+            step=step,
+            timestamp=timestamp,
+            fields=fields,
+            metrics=metrics,
+            add_tags=add_tags,
+            remove_tags=remove_tags,
+        ):
+            self._operations_queue.enqueue(operation=operation)
diff --git a/src/neptune_scale/core/components/operations_queue.py b/src/neptune_scale/core/components/operations_queue.py
index a76f0f15..bed30265 100644
--- a/src/neptune_scale/core/components/operations_queue.py
+++ b/src/neptune_scale/core/components/operations_queue.py
@@ -51,6 +51,7 @@ def __init__(
 
     def enqueue(self, *, operation: RunOperation) -> None:
         try:
+            # TODO: This lock could be moved to the Run class
             with self._lock:
                 serialized_operation = operation.SerializeToString()
diff --git a/src/neptune_scale/core/message_builder.py b/src/neptune_scale/core/message_builder.py
new file mode 100644
index 00000000..75f21751
--- /dev/null
+++ b/src/neptune_scale/core/message_builder.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+__all__ = ("MessageBuilder",)
+
+from datetime import datetime
+
+from google.protobuf.timestamp_pb2 import Timestamp
+from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
+    SET_OPERATION,
+    ModifySet,
+    Step,
+    StringSet,
+    UpdateRunSnapshot,
+    Value,
+)
+from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
+
+
+class MessageBuilder:
+    def __init__(
+        self,
+        *,
+        project: str,
+        run_id: str,
+        step: int | float | None,
+        timestamp: datetime,
+        fields: dict[str, float | bool | int | str | datetime | list | set],
+        metrics: dict[str, float],
+        add_tags: dict[str, list[str] | set[str]],
+        remove_tags: dict[str, list[str] | set[str]],
+    ):
+        self._step = None if step is None else make_step(number=step)
+        self._timestamp = datetime_to_proto(timestamp)
+        self._fields = fields
+        self._metrics = metrics
+        self._add_tags = add_tags
+        self._remove_tags = remove_tags
+        self._project = project
+        self._run_id = run_id
+
+        self._was_produced: bool = False
+
+    def __iter__(self) -> MessageBuilder:
+        return self
+
+    def __next__(self) -> RunOperation:
+        if not self._was_produced:
+            self._was_produced = True
+
+            modify_sets = {key: mod_tags(add=add) for key, add in self._add_tags.items()}
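+            # note: for a key present in both add_tags and remove_tags, the remove entry replaces the add entry below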
+            modify_sets.update({key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()})
+            update = UpdateRunSnapshot(
+                step=self._step,
+                timestamp=self._timestamp,
+                assign={key: make_value(value) for key, value in self._fields.items()},
+                append={key: make_value(value) for key, value in self._metrics.items()},
+                modify_sets=modify_sets,
+            )
+
+            return RunOperation(project=self._project, run_id=self._run_id, update=update)
+
+        raise StopIteration
+
+
+def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet:
+    mod_set = ModifySet()
+    if add is not None:
+        for tag in add:
+            mod_set.string.values[tag] = SET_OPERATION.ADD
+    if remove is not None:
+        for tag in remove:
+            mod_set.string.values[tag] = SET_OPERATION.REMOVE
+    return mod_set
+
+
+def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value:
+    if isinstance(value, Value):
+        return value
+    if isinstance(value, float):
+        return Value(float64=value)
+    # bool is checked before int because bool is a subclass of int
+    elif isinstance(value, bool):
+        return Value(bool=value)
+    elif isinstance(value, int):
+        return Value(int64=value)
+    elif isinstance(value, str):
+        return Value(string=value)
+    elif isinstance(value, datetime):
+        return Value(timestamp=datetime_to_proto(value))
+    elif isinstance(value, (list, set)):
+        return Value(string_set=StringSet(values=value))
+    else:
+        raise ValueError(f"Unsupported ingest field value type: {type(value)}")
+
+
+def datetime_to_proto(dt: datetime) -> Timestamp:
+    dt_ts = dt.timestamp()
+    return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9))
+
+
+def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
+    """
+    Converts a number to a protobuf Step value. Example:
+    >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)
+
+    Args:
+        number: step expressed as a number
+        raise_on_step_precision_loss: whether to raise an error instead of silently dropping
+            precision and rounding down to 6 decimal places.
+
+    Returns: Step protobuf used in the Neptune API.
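+
+    Raises:
+        ValueError: if raise_on_step_precision_loss is True and the number has more than 6 decimal places.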
+ """ + m = int(1e6) + micro: int = int(number * m) + if raise_on_step_precision_loss and number * m - micro != 0: + raise ValueError(f"step must not use more than 6-decimal points, got: {number}") + + whole = micro // m + micro = micro % m + + return Step(whole=whole, micro=micro) diff --git a/src/neptune_scale/core/validation.py b/src/neptune_scale/core/validation.py index efd8716a..95a732e1 100644 --- a/src/neptune_scale/core/validation.py +++ b/src/neptune_scale/core/validation.py @@ -1,21 +1,21 @@ +from __future__ import annotations + __all__ = ( "verify_type", "verify_non_empty", "verify_max_length", "verify_project_qualified_name", + "verify_collection_type", ) -from typing import ( - Any, - Union, -) +from typing import Any -def get_type_name(var_type: Union[type, tuple]) -> str: +def get_type_name(var_type: type | tuple) -> str: return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type) -def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None: +def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None: try: if isinstance(expected_type, tuple): type_name = " or ".join(get_type_name(t) for t in expected_type) @@ -46,3 +46,10 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None: project_parts = var.split("/") if len(project_parts) != 2: raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name") + + +def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None: + verify_type(var_name, var, (list, set, tuple)) + + for value in var: + verify_type(f"elements of collection '{var_name}'", value, expected_type) diff --git a/tests/unit/test_message_builder.py b/tests/unit/test_message_builder.py new file mode 100644 index 00000000..5d78ed6d --- /dev/null +++ b/tests/unit/test_message_builder.py @@ -0,0 +1,169 @@ +from datetime import datetime + +from freezegun import freeze_time +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + ModifySet, + ModifyStringSet, + Step, + StringSet, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.message_builder import MessageBuilder + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_empty(): + # given + builder = MessageBuilder( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot(step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934)), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_fields(): + # given + builder = MessageBuilder( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={ + "some/string": "value", + "some/int": 2501, + "some/float": 3.14, + "some/bool": True, + "some/datetime": datetime.now(), + "some/tags": {"tag1", "tag2"}, + }, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, 
+            assign={
+                "some/string": Value(string="value"),
+                "some/int": Value(int64=2501),
+                "some/float": Value(float64=3.14),
+                "some/bool": Value(bool=True),
+                "some/datetime": Value(timestamp=Timestamp(seconds=1722341532, nanos=21934)),
+                "some/tags": Value(string_set=StringSet(values={"tag1", "tag2"})),
+            },
+        ),
+    )
+
+
+@freeze_time("2024-07-30 12:12:12.000022")
+def test_metrics():
+    # given
+    builder = MessageBuilder(
+        project="workspace/project",
+        run_id="run_id",
+        step=1,
+        timestamp=datetime.now(),
+        fields={},
+        metrics={
+            "some/metric": 3.14,
+        },
+        add_tags={},
+        remove_tags={},
+    )
+
+    # when
+    result = list(builder)
+
+    # then
+    assert len(result) == 1
+    assert result[0] == RunOperation(
+        project="workspace/project",
+        run_id="run_id",
+        update=UpdateRunSnapshot(
+            step=Step(whole=1, micro=0),
+            timestamp=Timestamp(seconds=1722341532, nanos=21934),
+            append={
+                "some/metric": Value(float64=3.14),
+            },
+        ),
+    )
+
+
+@freeze_time("2024-07-30 12:12:12.000022")
+def test_tags():
+    # given
+    builder = MessageBuilder(
+        project="workspace/project",
+        run_id="run_id",
+        step=1,
+        timestamp=datetime.now(),
+        fields={},
+        metrics={},
+        add_tags={
+            "some/tags": {"tag1", "tag2"},
+            "some/other_tags2": {"tag2", "tag3"},
+        },
+        remove_tags={
+            "some/group_tags": {"tag0", "tag1"},
+            "some/other_tags": {"tag2", "tag3"},
+        },
+    )
+
+    # when
+    result = list(builder)
+
+    # then
+    assert len(result) == 1
+    assert result[0] == RunOperation(
+        project="workspace/project",
+        run_id="run_id",
+        update=UpdateRunSnapshot(
+            step=Step(whole=1, micro=0),
+            timestamp=Timestamp(seconds=1722341532, nanos=21934),
+            modify_sets={
+                "some/tags": ModifySet(
+                    string=ModifyStringSet(values={"tag1": SET_OPERATION.ADD, "tag2": SET_OPERATION.ADD})
+                ),
+                "some/other_tags2": ModifySet(
+                    string=ModifyStringSet(values={"tag2": SET_OPERATION.ADD, "tag3": SET_OPERATION.ADD})
+                ),
+                "some/group_tags": ModifySet(
+                    string=ModifyStringSet(values={"tag0": SET_OPERATION.REMOVE, "tag1": SET_OPERATION.REMOVE})
+                ),
+                "some/other_tags": ModifySet(
+                    string=ModifyStringSet(values={"tag2": SET_OPERATION.REMOVE, "tag3": SET_OPERATION.REMOVE})
+                ),
+            },
+        ),
+    )
diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py
index f7d8310c..0c7fb614 100644
--- a/tests/unit/test_run.py
+++ b/tests/unit/test_run.py
@@ -1,4 +1,5 @@
 import uuid
+from datetime import datetime
 
 import pytest
 
@@ -80,3 +81,86 @@ def test_invalid_project_name():
     with pytest.raises(ValueError):
         with Run(project=project, api_token=api_token, family=family, run_id=run_id):
             ...
+
+
+def test_metadata():
+    # given
+    project = "workspace/project"
+    api_token = "API_TOKEN"
+    run_id = str(uuid.uuid4())
+    family = run_id
+
+    # then
+    with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run:
+        run.log(
+            step=1,
+            timestamp=datetime.now(),
+            fields={
+                "int": 1,
+                "string": "test",
+                "float": 3.14,
+                "bool": True,
+                "datetime": datetime.now(),
+            },
+            metrics={
+                "metric": 1.0,
+            },
+            add_tags={
+                "tags": ["tag1"],
+            },
+            remove_tags={
+                "group_tags": ["tag2"],
+            },
+        )
+
+
+def test_log_without_step():
+    # given
+    project = "workspace/project"
+    api_token = "API_TOKEN"
+    run_id = str(uuid.uuid4())
+    family = run_id
+
+    # then
+    with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run:
+        run.log(
+            timestamp=datetime.now(),
+            fields={
+                "int": 1,
+            },
+        )
+
+
+def test_log_step_float():
+    # given
+    project = "workspace/project"
+    api_token = "API_TOKEN"
+    run_id = str(uuid.uuid4())
+    family = run_id
+
+    # then
+    with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run:
+        run.log(
+            step=3.14,
+            timestamp=datetime.now(),
+            fields={
+                "int": 1,
+            },
+        )
+
+
+def test_log_no_timestamp():
+    # given
+    project = "workspace/project"
+    api_token = "API_TOKEN"
+    run_id = str(uuid.uuid4())
+    family = run_id
+
+    # then
+    with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run:
+        run.log(
+            step=3.14,
+            fields={
+                "int": 1,
+            },
+        )