Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logging metadata #8

Merged
merged 11 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6))
- Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7))
- Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8))
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
pre-commit
pytest
pytest-timeout
freezegun
67 changes: 67 additions & 0 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

import threading
from contextlib import AbstractContextManager
from datetime import datetime
from typing import Callable

from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.message_builder import MessageBuilder
from neptune_scale.core.validation import (
verify_collection_type,
verify_max_length,
verify_non_empty,
verify_project_qualified_name,
Expand Down Expand Up @@ -95,3 +98,67 @@ def close(self) -> None:
Stops the connection to Neptune and synchronizes all data.
"""
super().close()

def log(
    self,
    step: float | int | None = None,
    timestamp: datetime | None = None,
    fields: dict[str, float | bool | int | str | datetime | list | set] | None = None,
    metrics: dict[str, float] | None = None,
    add_tags: dict[str, list[str] | set[str]] | None = None,
    remove_tags: dict[str, list[str] | set[str]] | None = None,
) -> None:
    """
    Logs the specified metadata to Neptune.

    Args:
        step: Index of the log entry, must be increasing. If None, the highest of the already logged indexes is used.
        timestamp: Time of logging the metadata. Defaults to the current time when None.
        fields: Dictionary of fields to log.
        metrics: Dictionary of metrics to log.
        add_tags: Dictionary of tags to add to the run.
        remove_tags: Dictionary of tags to remove from the run.

    Examples:
        ```
        >>> with Run(...) as run:
        ...     run.log(step=1, fields={"parameters/learning_rate": 0.001})
        ...     run.log(step=2, add_tags={"sys/group_tags": ["group1", "group2"]})
        ...     run.log(step=3, metrics={"metrics/loss": 0.1})
        ```

    """
    # Validate argument types up front, in a fixed order, before any defaulting.
    for arg_name, arg_value, allowed in (
        ("step", step, (float, int, type(None))),
        ("timestamp", timestamp, (datetime, type(None))),
        ("fields", fields, (dict, type(None))),
        ("metrics", metrics, (dict, type(None))),
        ("add_tags", add_tags, (dict, type(None))),
        ("remove_tags", remove_tags, (dict, type(None))),
    ):
        verify_type(arg_name, arg_value, allowed)

    timestamp = timestamp if timestamp is not None else datetime.now()
    fields = fields if fields is not None else {}
    metrics = metrics if metrics is not None else {}
    add_tags = add_tags if add_tags is not None else {}
    remove_tags = remove_tags if remove_tags is not None else {}

    collections = (
        ("fields", fields),
        ("metrics", metrics),
        ("add_tags", add_tags),
        ("remove_tags", remove_tags),
    )
    allowed_value_types = {
        "fields": (float, bool, int, str, datetime, list, set),
        "metrics": float,
        "add_tags": (list, set),
        "remove_tags": (list, set),
    }
    # All keys are checked before any values, preserving the order in which
    # a validation error is raised.
    for name, mapping in collections:
        verify_collection_type(f"`{name}` keys", list(mapping.keys()), str)
    for name, mapping in collections:
        verify_collection_type(f"`{name}` values", list(mapping.values()), allowed_value_types[name])

    # Translate the call into protobuf operations and push each onto the queue.
    for operation in MessageBuilder(
        project=self._project,
        run_id=self._run_id,
        step=step,
        timestamp=timestamp,
        fields=fields,
        metrics=metrics,
        add_tags=add_tags,
        remove_tags=remove_tags,
    ):
        self._operations_queue.enqueue(operation=operation)
1 change: 1 addition & 0 deletions src/neptune_scale/core/components/operations_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(

def enqueue(self, *, operation: RunOperation) -> None:
try:
# TODO: This lock could be moved to the Run class
with self._lock:
serialized_operation = operation.SerializeToString()

Expand Down
120 changes: 120 additions & 0 deletions src/neptune_scale/core/message_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

__all__ = ("MessageBuilder",)

from datetime import datetime

from google.protobuf.timestamp_pb2 import Timestamp
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
SET_OPERATION,
ModifySet,
Step,
StringSet,
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation


class MessageBuilder:
    """Translates one logging call into ``RunOperation`` protobuf messages.

    Acts as a single-shot iterator: iterating over an instance yields exactly
    one ``RunOperation`` that carries the assigned fields, appended metrics,
    and tag modifications, then raises ``StopIteration``.
    """

    def __init__(
        self,
        *,
        project: str,
        run_id: str,
        step: int | float | None,
        timestamp: datetime,
        fields: dict[str, float | bool | int | str | datetime | list | set],
        metrics: dict[str, float],
        add_tags: dict[str, list[str] | set[str]],
        remove_tags: dict[str, list[str] | set[str]],
    ):
        self._step = None if step is None else make_step(number=step)
        self._timestamp = datetime_to_proto(timestamp)
        self._fields = fields
        self._metrics = metrics
        self._add_tags = add_tags
        self._remove_tags = remove_tags
        self._project = project
        self._run_id = run_id

        # Guards against producing the same operation twice.
        self._was_produced: bool = False

    def __iter__(self) -> MessageBuilder:
        return self

    def __next__(self) -> RunOperation:
        if self._was_produced:
            raise StopIteration

        self._was_produced = True

        # Build one ModifySet per attribute key, merging ADD and REMOVE entries.
        # NOTE: building two dicts and dict.update()-ing the removals over the
        # additions would silently drop all ADD entries for any key present in
        # both add_tags and remove_tags.
        modify_sets = {
            key: mod_tags(add=self._add_tags.get(key), remove=self._remove_tags.get(key))
            for key in {**self._add_tags, **self._remove_tags}
        }
        update = UpdateRunSnapshot(
            step=self._step,
            timestamp=self._timestamp,
            assign={key: make_value(value) for key, value in self._fields.items()},
            append={key: make_value(value) for key, value in self._metrics.items()},
            modify_sets=modify_sets,
        )

        return RunOperation(project=self._project, run_id=self._run_id, update=update)


def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet:
    """Build a ModifySet marking *add* tags as ADD and *remove* tags as REMOVE."""
    result = ModifySet()
    for tags, operation in ((add, SET_OPERATION.ADD), (remove, SET_OPERATION.REMOVE)):
        if tags is None:
            continue
        for tag in tags:
            result.string.values[tag] = operation
    return result


def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value:
    """Wrap a Python value in the matching protobuf ``Value`` variant.

    A ``Value`` instance is passed through unchanged. ``bool`` must be tested
    before ``int`` because ``bool`` is a subclass of ``int``.
    """
    if isinstance(value, Value):
        return value
    if isinstance(value, datetime):
        return Value(timestamp=datetime_to_proto(value))
    if isinstance(value, bool):
        return Value(bool=value)
    if isinstance(value, float):
        return Value(float64=value)
    if isinstance(value, int):
        return Value(int64=value)
    if isinstance(value, str):
        return Value(string=value)
    if isinstance(value, (list, set)):
        return Value(string_set=StringSet(values=value))
    raise ValueError(f"Unsupported ingest field value type: {type(value)}")


def datetime_to_proto(dt: datetime) -> Timestamp:
    """Convert a datetime to a protobuf ``Timestamp``.

    proto3 Timestamp requires ``0 <= nanos < 1e9`` for all times, including
    pre-epoch ones. The fractional part ``ts % 1`` is floor-based in Python,
    so seconds must use floor as well: ``int(ts)`` truncates toward zero and
    would pair a wrong seconds value with the nanos for negative timestamps
    (e.g. ts=-0.5 would become seconds=0, nanos=5e8, i.e. +0.5).
    """
    dt_ts = dt.timestamp()
    seconds = int(dt_ts // 1)  # floor, not truncation
    return Timestamp(seconds=seconds, nanos=int((dt_ts % 1) * 1e9))


def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
    """
    Converts a number to protobuf Step value. Example:
    >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)

    Args:
        number: step expressed as number
        raise_on_step_precision_loss: inform converter whether it should silently drop precision and
            round down to 6 decimal places or raise an error.

    Returns: Step protobuf used in Neptune API.
    """
    m = int(1e6)
    # round() instead of int(): int() truncates binary-float representation
    # noise downward, e.g. int(0.29 * 1e6) == 289999 instead of 290000.
    micro: int = round(number * m)
    # Precision is lost iff the value changes when rounded to 6 decimals;
    # comparing number * m to micro exactly would spuriously raise for
    # 6-decimal values that are not exactly representable in binary.
    if raise_on_step_precision_loss and round(number, 6) != number:
        raise ValueError(f"step must not use more than 6-decimal points, got: {number}")

    whole = micro // m
    micro = micro % m

    return Step(whole=whole, micro=micro)
19 changes: 13 additions & 6 deletions src/neptune_scale/core/validation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from __future__ import annotations

__all__ = (
"verify_type",
"verify_non_empty",
"verify_max_length",
"verify_project_qualified_name",
"verify_collection_type",
)

from typing import (
Any,
Union,
)
from typing import Any


def get_type_name(var_type: Union[type, tuple]) -> str:
def get_type_name(var_type: type | tuple) -> str:
return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type)


def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None:
def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None:
try:
if isinstance(expected_type, tuple):
type_name = " or ".join(get_type_name(t) for t in expected_type)
Expand Down Expand Up @@ -46,3 +46,10 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None:
project_parts = var.split("/")
if len(project_parts) != 2:
raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name")


def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None:
    """Verify that *var* is a collection and that each element is of *expected_type*."""
    verify_type(var_name, var, (list, set, tuple))

    element_label = f"elements of collection '{var_name}'"
    for element in var:
        verify_type(element_label, element, expected_type)
Loading
Loading