Skip to content

Commit

Permalink
Minimal working implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Jul 30, 2024
1 parent b85f1c5 commit 8013147
Show file tree
Hide file tree
Showing 9 changed files with 404 additions and 6 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
pre-commit
pytest
pytest-timeout
freezegun
60 changes: 60 additions & 0 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@

import threading
from contextlib import AbstractContextManager
from datetime import datetime
from typing import Callable

from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.message_builder import MessageBuilder
from neptune_scale.core.validation import (
verify_collection_type,
verify_max_length,
verify_non_empty,
verify_project_qualified_name,
Expand Down Expand Up @@ -90,6 +95,61 @@ def __enter__(self) -> Run:
def resources(self) -> tuple[Resource, ...]:
    """Return the resources held by this object: a 1-tuple with the operations queue."""
    queue = self._operations_queue
    return (queue,)

def _prepare_common_message(self) -> RunOperation:
    """
    Build the base `RunOperation` shared by the operations of this run.

    The message carries the project and run id stored on the instance, never asks
    for a missing project to be created, and leaves the API key empty.
    """
    common_fields = {
        "project": self._project,
        "run_id": self._run_id,
        "create_missing_project": False,
        "api_key": b"",
    }
    return RunOperation(**common_fields)

def log(
    self,
    step: float | int | None = None,
    timestamp: datetime | None = None,
    fields: dict[str, float | int | str] | None = None,
    metrics: dict[str, float] | None = None,
    add_tags: dict[str, list[str] | set[str]] | None = None,
    remove_tags: dict[str, list[str] | set[str]] | None = None,
) -> None:
    """
    Validate the given data and enqueue it on the operations queue.

    Args:
        step: optional step number (float or int).
        timestamp: moment of the update; `datetime.now()` is used when omitted.
        fields: string keys mapped to float, int or str values.
        metrics: string keys mapped to values that must be `float` instances.
        add_tags: string keys mapped to lists or sets of tags to add.
        remove_tags: string keys mapped to lists or sets of tags to remove.

    Every argument is checked with `verify_type` / `verify_collection_type`
    before anything is enqueued; those helpers report invalid input.
    """
    optional_dict = (dict, type(None))

    verify_type("step", step, (float, int, type(None)))
    verify_type("timestamp", timestamp, (datetime, type(None)))
    verify_type("fields", fields, optional_dict)
    verify_type("metrics", metrics, optional_dict)
    verify_type("add_tags", add_tags, optional_dict)
    verify_type("remove_tags", remove_tags, optional_dict)

    # Replace omitted arguments with their defaults.
    if timestamp is None:
        timestamp = datetime.now()
    if fields is None:
        fields = {}
    if metrics is None:
        metrics = {}
    if add_tags is None:
        add_tags = {}
    if remove_tags is None:
        remove_tags = {}

    # All keys are validated first, in argument order.
    for label, mapping in (
        ("`fields` keys", fields),
        ("`metrics` keys", metrics),
        ("`add_tags` keys", add_tags),
        ("`remove_tags` keys", remove_tags),
    ):
        verify_collection_type(label, list(mapping.keys()), str)

    # TODO: More types for values
    for label, mapping, allowed_types in (
        ("`fields` values", fields, (float, int, str)),
        ("`metrics` values", metrics, float),
        ("`add_tags` values", add_tags, (list, set)),
        ("`remove_tags` values", remove_tags, (list, set)),
    ):
        verify_collection_type(label, list(mapping.values()), allowed_types)

    builder = MessageBuilder(
        common_message=self._prepare_common_message(),
        step=step,
        timestamp=timestamp,
        fields=fields,
        metrics=metrics,
        add_tags=add_tags,
        remove_tags=remove_tags,
    )
    for operation in builder:
        self._operations_queue.enqueue(operation=operation)

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
Expand Down
Empty file.
97 changes: 97 additions & 0 deletions src/neptune_scale/api/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

# `make_value` is imported by other modules of the package, so it is exported too.
__all__ = ("datetime_to_proto", "make_step", "make_value", "mod_tags")

from datetime import datetime

from google.protobuf.timestamp_pb2 import Timestamp
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
SET_OPERATION,
ModifySet,
Step,
StringSet,
Value,
)


def datetime_to_proto(dt: datetime) -> Timestamp:
    """
    Converts datetime object to google.protobuf.Timestamp used in Neptune Ingest API
    Args:
        dt: datetime - timezone is irrelevant and will be lost to unix time
    Returns: protobuf Timestamp
    """
    # Take the whole seconds from the datetime with its sub-second part removed and the
    # nanoseconds straight from `dt.microsecond`. Splitting the float returned by
    # `dt.timestamp()` instead is inexact (the fraction is not a clean multiple of a
    # microsecond) and wrong before the epoch, where `int()` truncates toward zero
    # while the fractional part must stay non-negative.
    seconds = int(dt.replace(microsecond=0).timestamp())
    nanos = dt.microsecond * 1000
    return Timestamp(seconds=seconds, nanos=nanos)


def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
    """
    Converts a number to protobuf Step value. Example:
    >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)
    Args:
        number: step expressed as number
        raise_on_step_precision_loss: inform converter whether it should silently drop precision and
            round down to 6 decimal places or raise an error.
    Returns: Step protobuf used in Neptune Ingest API.
    Raises:
        ValueError: if `raise_on_step_precision_loss` is set and `number` has more than
            6 decimal places.
    """
    # Imported locally because this module does not import `decimal` at the top level.
    from decimal import Decimal

    m = 1_000_000
    if isinstance(number, int):
        # Integers (bool included) have no fractional part, so nothing can be lost.
        total_micro = int(number) * m
    else:
        # Scale the shortest decimal representation of the float rather than the binary
        # value: in float arithmetic 1.001 * 1_000_000 == 1000999.9999999999, which used
        # to drop a micro-step or trigger a bogus precision-loss error.
        scaled = Decimal(str(number)) * m
        total_micro = int(scaled)  # truncates toward zero; raises for NaN / infinity
        if raise_on_step_precision_loss and scaled != total_micro:
            raise ValueError(f"step must not use more than 6-decimal points, got: {number}")

    whole, micro = divmod(total_micro, m)

    return Step(whole=whole, micro=micro)


def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value:
    """
    Wraps a python value in the matching Ingest API Value protobuf message.

    A `Value` passed in is returned unchanged. Avoid mixing float and int values
    (literals in particular) within one field - this may cause data points to be dropped.

    Raises:
        ValueError: if the type of `value` is not supported.
    """
    if isinstance(value, Value):
        return value

    if isinstance(value, float):
        return Value(float64=value)

    # bool is a subclass of int, so it has to be recognised before the int branch.
    if isinstance(value, bool):
        return Value(bool=value)

    if isinstance(value, int):
        return Value(int64=value)

    if isinstance(value, str):
        return Value(string=value)

    if isinstance(value, datetime):
        return Value(timestamp=datetime_to_proto(value))

    if isinstance(value, (list, set)):
        return Value(string_set=StringSet(values=value))

    raise ValueError(f"Unsupported ingest field value type: {type(value)}")


def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet:
    """
    Builds a ModifySet describing string set modifications. Example:
    >>> from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import *
    ...
    ... update = UpdateRunSnapshot(
    ...     modify_sets={
    ...         "sys/tags": mod_tags(add=["tag1", "tag2"], remove=["tag0"])
    ...     })
    Args:
        add: tags to be added, or None for no additions
        remove: tags to be removed, or None for no removals
    Returns: ModifySet with changes
    """
    result = ModifySet()

    # Additions are written first; a tag listed in both arguments ends up as REMOVE.
    for added_tag in () if add is None else add:
        result.string.values[added_tag] = SET_OPERATION.ADD

    for removed_tag in () if remove is None else remove:
        result.string.values[removed_tag] = SET_OPERATION.REMOVE

    return result
1 change: 1 addition & 0 deletions src/neptune_scale/core/components/operations_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(

def enqueue(self, *, operation: RunOperation) -> None:
try:
# TODO: This lock could be moved to the Run class
with self._lock:
serialized_operation = operation.SerializeToString()

Expand Down
62 changes: 62 additions & 0 deletions src/neptune_scale/core/message_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

__all__ = ("MessageBuilder",)

from datetime import datetime

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import UpdateRunSnapshot
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.helpers import (
datetime_to_proto,
make_step,
make_value,
mod_tags,
)


class MessageBuilder:
    """
    Iterator that turns the data of a single update into `RunOperation` messages.

    It yields exactly one operation: the given common message with an
    `UpdateRunSnapshot` merged into it.
    """

    def __init__(
        self,
        *,
        common_message: RunOperation,
        step: int | float | None,
        timestamp: datetime,
        fields: dict[str, float | int | str],
        metrics: dict[str, float],
        add_tags: dict[str, list[str] | set[str]],
        remove_tags: dict[str, list[str] | set[str]],
    ):
        """
        Args:
            common_message: base operation that the update is merged into (mutated in place).
            step: optional step; converted with `make_step`, which raises ValueError when
                the step uses more than 6 decimal places.
            timestamp: moment of the update, converted to a protobuf timestamp.
            fields: values sent in the `assign` part of the update.
            metrics: values sent in the `append` part of the update.
            add_tags: per-key tags to add.
            remove_tags: per-key tags to remove.
        """
        # TODO: Warning instead of raise on step precision loss
        self._step = None if step is None else make_step(number=step, raise_on_step_precision_loss=True)
        self._timestamp = datetime_to_proto(timestamp)
        self._fields = fields
        self._metrics = metrics
        self._add_tags = add_tags
        self._remove_tags = remove_tags
        self._common_message = common_message

        # Number of operations produced so far.
        self._ith: int = 0

    def __iter__(self) -> MessageBuilder:
        return self

    def __next__(self) -> RunOperation:
        """Return the single operation on the first call; raise StopIteration afterwards."""
        if self._ith != 0:
            raise StopIteration
        self._ith += 1

        # Every key from either mapping, additions first, without duplicates.
        tag_keys = dict.fromkeys((*self._add_tags, *self._remove_tags))

        update = UpdateRunSnapshot(
            step=self._step,
            timestamp=self._timestamp,
            assign={key: make_value(value) for key, value in self._fields.items()},
            append={key: make_value(value) for key, value in self._metrics.items()},
            # One ModifySet per key holding both its additions and removals. Building
            # them as two messages and merging would keep only the removals for a key
            # present in both mappings, because merging protobuf map fields replaces
            # the whole entry for a duplicate key.
            modify_sets={
                key: mod_tags(add=self._add_tags.get(key), remove=self._remove_tags.get(key)) for key in tag_keys
            },
        )

        operation = self._common_message
        operation.MergeFrom(RunOperation(update=update))
        return operation
19 changes: 13 additions & 6 deletions src/neptune_scale/core/validation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from __future__ import annotations

__all__ = (
"verify_type",
"verify_non_empty",
"verify_max_length",
"verify_project_qualified_name",
"verify_collection_type",
)

from typing import (
Any,
Union,
)
from typing import Any


def get_type_name(var_type: Union[type, tuple]) -> str:
def get_type_name(var_type: type | tuple) -> str:
    """Return `var_type.__name__` when it exists, otherwise `str(var_type)`."""
    try:
        return var_type.__name__
    except AttributeError:
        # Objects without a `__name__` (e.g. a tuple of types) fall back to their str() form.
        return str(var_type)


def verify_type(var_name: str, var: Any, expected_type: Union[type, tuple]) -> None:
def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None:
try:
if isinstance(expected_type, tuple):
type_name = " or ".join(get_type_name(t) for t in expected_type)
Expand Down Expand Up @@ -46,3 +46,10 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None:
project_parts = var.split("/")
if len(project_parts) != 2:
raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name")


def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None:
    """
    Check that `var` is a list, set or tuple and that each element has `expected_type`.

    Both checks are delegated to `verify_type`; the container is checked first, then
    the elements in iteration order.
    """
    verify_type(var_name, var, (list, set, tuple))

    element_label = f"elements of collection '{var_name}'"
    for element in var:
        verify_type(element_label, element, expected_type)
Loading

0 comments on commit 8013147

Please sign in to comment.