Skip to content

Commit

Permalink
Logging metadata (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Jul 31, 2024
1 parent b85f1c5 commit 2faf3a5
Show file tree
Hide file tree
Showing 8 changed files with 456 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6))
- Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7))
- Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8))
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
pre-commit
pytest
pytest-timeout
freezegun
67 changes: 67 additions & 0 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

import threading
from contextlib import AbstractContextManager
from datetime import datetime
from typing import Callable

from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.message_builder import MessageBuilder
from neptune_scale.core.validation import (
verify_collection_type,
verify_max_length,
verify_non_empty,
verify_project_qualified_name,
Expand Down Expand Up @@ -95,3 +98,67 @@ def close(self) -> None:
Stops the connection to Neptune and synchronizes all data.
"""
super().close()

def log(
self,
step: float | int | None = None,
timestamp: datetime | None = None,
fields: dict[str, float | bool | int | str | datetime | list | set] | None = None,
metrics: dict[str, float] | None = None,
add_tags: dict[str, list[str] | set[str]] | None = None,
remove_tags: dict[str, list[str] | set[str]] | None = None,
) -> None:
"""
Logs the specified metadata to Neptune.
Args:
step: Index of the log entry, must be increasing. If None, the highest of the already logged indexes is used.
timestamp: Time of logging the metadata.
fields: Dictionary of fields to log.
metrics: Dictionary of metrics to log.
add_tags: Dictionary of tags to add to the run.
remove_tags: Dictionary of tags to remove from the run.
Examples:
```
>>> with Run(...) as run:
... run.log(step=1, fields={"parameters/learning_rate": 0.001})
... run.log(step=2, add_tags={"sys/group_tags": ["group1", "group2"]})
... run.log(step=3, metrics={"metrics/loss": 0.1})
```
"""
verify_type("step", step, (float, int, type(None)))
verify_type("timestamp", timestamp, (datetime, type(None)))
verify_type("fields", fields, (dict, type(None)))
verify_type("metrics", metrics, (dict, type(None)))
verify_type("add_tags", add_tags, (dict, type(None)))
verify_type("remove_tags", remove_tags, (dict, type(None)))

timestamp = datetime.now() if timestamp is None else timestamp
fields = {} if fields is None else fields
metrics = {} if metrics is None else metrics
add_tags = {} if add_tags is None else add_tags
remove_tags = {} if remove_tags is None else remove_tags

verify_collection_type("`fields` keys", list(fields.keys()), str)
verify_collection_type("`metrics` keys", list(metrics.keys()), str)
verify_collection_type("`add_tags` keys", list(add_tags.keys()), str)
verify_collection_type("`remove_tags` keys", list(remove_tags.keys()), str)

verify_collection_type("`fields` values", list(fields.values()), (float, bool, int, str, datetime, list, set))
verify_collection_type("`metrics` values", list(metrics.values()), float)
verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set))
verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set))

for operation in MessageBuilder(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
fields=fields,
metrics=metrics,
add_tags=add_tags,
remove_tags=remove_tags,
):
self._operations_queue.enqueue(operation=operation)
1 change: 1 addition & 0 deletions src/neptune_scale/core/components/operations_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(

def enqueue(self, *, operation: RunOperation) -> None:
try:
# TODO: This lock could be moved to the Run class
with self._lock:
serialized_operation = operation.SerializeToString()

Expand Down
120 changes: 120 additions & 0 deletions src/neptune_scale/core/message_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

__all__ = ("MessageBuilder",)

from datetime import datetime

from google.protobuf.timestamp_pb2 import Timestamp
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
SET_OPERATION,
ModifySet,
Step,
StringSet,
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation


class MessageBuilder:
def __init__(
self,
*,
project: str,
run_id: str,
step: int | float | None,
timestamp: datetime,
fields: dict[str, float | bool | int | str | datetime | list | set],
metrics: dict[str, float],
add_tags: dict[str, list[str] | set[str]],
remove_tags: dict[str, list[str] | set[str]],
):
self._step = None if step is None else make_step(number=step)
self._timestamp = datetime_to_proto(timestamp)
self._fields = fields
self._metrics = metrics
self._add_tags = add_tags
self._remove_tags = remove_tags
self._project = project
self._run_id = run_id

self._was_produced: bool = False

def __iter__(self) -> MessageBuilder:
return self

def __next__(self) -> RunOperation:
if not self._was_produced:
self._was_produced = True

modify_sets = {key: mod_tags(add=add) for key, add in self._add_tags.items()}
modify_sets.update({key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()})
update = UpdateRunSnapshot(
step=self._step,
timestamp=self._timestamp,
assign={key: make_value(value) for key, value in self._fields.items()},
append={key: make_value(value) for key, value in self._metrics.items()},
modify_sets=modify_sets,
)

return RunOperation(project=self._project, run_id=self._run_id, update=update)

raise StopIteration


def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet:
mod_set = ModifySet()
if add is not None:
for tag in add:
mod_set.string.values[tag] = SET_OPERATION.ADD
if remove is not None:
for tag in remove:
mod_set.string.values[tag] = SET_OPERATION.REMOVE
return mod_set


def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value:
if isinstance(value, Value):
return value
if isinstance(value, float):
return Value(float64=value)
elif isinstance(value, bool):
return Value(bool=value)
elif isinstance(value, int):
return Value(int64=value)
elif isinstance(value, str):
return Value(string=value)
elif isinstance(value, datetime):
return Value(timestamp=datetime_to_proto(value))
elif isinstance(value, (list, set)):
fv = Value(string_set=StringSet(values=value))
return fv
else:
raise ValueError(f"Unsupported ingest field value type: {type(value)}")


def datetime_to_proto(dt: datetime) -> Timestamp:
dt_ts = dt.timestamp()
return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9))


def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
"""
Converts a number to protobuf Step value. Example:
>>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)
Args:
number: step expressed as number
raise_on_step_precision_loss: inform converter whether it should silently drop precision and
round down to 6 decimal places or raise an error.
Returns: Step protobuf used in Neptune API.
"""
m = int(1e6)
micro: int = int(number * m)
if raise_on_step_precision_loss and number * m - micro != 0:
raise ValueError(f"step must not use more than 6-decimal points, got: {number}")

whole = micro // m
micro = micro % m

return Step(whole=whole, micro=micro)
19 changes: 13 additions & 6 deletions src/neptune_scale/core/validation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from __future__ import annotations

__all__ = (
"verify_type",
"verify_non_empty",
"verify_max_length",
"verify_project_qualified_name",
"verify_collection_type",
)

from typing import (
Any,
Union,
)
from typing import Any


def get_type_name(var_type: Union[type, tuple]) -> str:
def get_type_name(var_type: type | tuple) -> str:
return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type)


def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None:
def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None:
try:
if isinstance(expected_type, tuple):
type_name = " or ".join(get_type_name(t) for t in expected_type)
Expand Down Expand Up @@ -46,3 +46,10 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None:
project_parts = var.split("/")
if len(project_parts) != 2:
raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name")


def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None:
verify_type(var_name, var, (list, set, tuple))

for value in var:
verify_type(f"elements of collection '{var_name}'", value, expected_type)
Loading

0 comments on commit 2faf3a5

Please sign in to comment.