Skip to content

Commit

Permalink
Tests added and minimal no-splitting implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Jul 30, 2024
1 parent 86cbd39 commit 6286c45
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 126 deletions.
5 changes: 2 additions & 3 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def log(
self,
step: float | int | None = None,
timestamp: datetime | None = None,
fields: dict[str, float | int | str] | None = None,
fields: dict[str, float | bool | int | str | datetime | list | set] | None = None,
metrics: dict[str, float] | None = None,
add_tags: dict[str, list[str] | set[str]] | None = None,
remove_tags: dict[str, list[str] | set[str]] | None = None,
Expand All @@ -129,8 +129,7 @@ def log(
verify_collection_type("`add_tags` keys", list(add_tags.keys()), str)
verify_collection_type("`remove_tags` keys", list(remove_tags.keys()), str)

# TODO: More types for values
verify_collection_type("`fields` values", list(fields.values()), (float, int, str))
verify_collection_type("`fields` values", list(fields.values()), (float, bool, int, str, datetime, list, set))
verify_collection_type("`metrics` values", list(metrics.values()), float)
verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set))
verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set))
Expand Down
Empty file removed src/neptune_scale/api/__init__.py
Empty file.
97 changes: 0 additions & 97 deletions src/neptune_scale/api/helpers.py

This file was deleted.

103 changes: 79 additions & 24 deletions src/neptune_scale/core/message_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

from datetime import datetime

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import UpdateRunSnapshot
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.helpers import (
datetime_to_proto,
make_step,
make_value,
mod_tags,
from google.protobuf.timestamp_pb2 import Timestamp
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
SET_OPERATION,
ModifySet,
Step,
StringSet,
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation


class MessageBuilder:
Expand All @@ -23,7 +24,7 @@ def __init__(
run_id: str,
step: int | float | None,
timestamp: datetime,
fields: dict[str, float | int | str],
fields: dict[str, float | bool | int | str | datetime | list | set],
metrics: dict[str, float],
add_tags: dict[str, list[str] | set[str]],
remove_tags: dict[str, list[str] | set[str]],
Expand All @@ -35,32 +36,86 @@ def __init__(
self._metrics = metrics
self._add_tags = add_tags
self._remove_tags = remove_tags
self._project = project
self._run_id = run_id

self._common_message = RunOperation(
project=project,
run_id=run_id,
)
self._ith: int = 0
self._was_produced: bool = False

def __iter__(self) -> MessageBuilder:
return self

def __next__(self) -> RunOperation:
if self._ith == 0:
self._ith += 1
if not self._was_produced:
self._was_produced = True

operation = self._common_message
modify_sets = {key: mod_tags(add=add) for key, add in self._add_tags.items()}
modify_sets.update({key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()})
update = UpdateRunSnapshot(
step=self._step,
timestamp=self._timestamp,
assign={key: make_value(value) for key, value in self._fields.items()},
append={key: make_value(value) for key, value in self._metrics.items()},
modify_sets={key: mod_tags(add=add) for key, add in self._add_tags.items()},
modify_sets=modify_sets,
)
remove_tags = UpdateRunSnapshot(
modify_sets={key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()}
)
update.MergeFrom(remove_tags)
operation.MergeFrom(RunOperation(update=update))
return operation

return RunOperation(project=self._project, run_id=self._run_id, update=update)

raise StopIteration


def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet:
    """Build a ``ModifySet`` describing string tags to add and/or remove.

    Args:
        add: tags to mark with ``SET_OPERATION.ADD``; ``None`` means nothing to add.
        remove: tags to mark with ``SET_OPERATION.REMOVE``; ``None`` means nothing to remove.

    Returns:
        A ``ModifySet`` whose ``string.values`` maps each tag to its operation.
        Removals are written after additions, so a tag present in both
        collections ends up marked as ``REMOVE``.
    """
    result = ModifySet()
    # Order matters: additions first, removals second (the later write wins).
    planned = ((add, SET_OPERATION.ADD), (remove, SET_OPERATION.REMOVE))
    for tags, operation in planned:
        if tags is None:
            continue
        for tag in tags:
            result.string.values[tag] = operation
    return result


def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value:
    """Wrap a Python value in the protobuf ``Value`` message.

    Args:
        value: the value to convert. A ``Value`` instance is returned unchanged.

    Returns:
        A ``Value`` with the field matching the Python type populated
        (``float64``, ``bool``, ``int64``, ``string``, ``timestamp`` or ``string_set``).

    Raises:
        ValueError: if the type of ``value`` is not one of the supported types.
    """
    if isinstance(value, Value):
        return value
    if isinstance(value, float):
        return Value(float64=value)
    # bool has to be tested ahead of int because bool is a subclass of int.
    if isinstance(value, bool):
        return Value(bool=value)
    if isinstance(value, int):
        return Value(int64=value)
    if isinstance(value, str):
        return Value(string=value)
    if isinstance(value, datetime):
        return Value(timestamp=datetime_to_proto(value))
    if isinstance(value, (list, set)):
        return Value(string_set=StringSet(values=value))
    raise ValueError(f"Unsupported ingest field value type: {type(value)}")


def datetime_to_proto(dt: datetime) -> Timestamp:
    """Convert a ``datetime`` into a protobuf ``Timestamp``.

    The instant is taken from ``dt.timestamp()``, so a naive ``datetime`` is
    interpreted the way ``datetime.timestamp()`` interprets it.

    Args:
        dt: the datetime to convert.

    Returns:
        A ``Timestamp`` whose ``seconds`` is the floor of the POSIX timestamp and
        whose ``nanos`` is the non-negative sub-second remainder, at microsecond
        resolution (the resolution of ``datetime``).
    """
    # Work in whole microseconds: rounding removes the float noise that
    # `int((ts % 1) * 1e9)` used to leak into `nanos` (for present-day dates the
    # float error is well below half a microsecond).
    total_micros = round(dt.timestamp() * 1_000_000)
    # Integer divmod floors, so pre-epoch instants get the next-lower second and
    # a non-negative remainder. Truncating with int() produced e.g.
    # seconds=-1, nanos=5e8 for -1.5 s, which denotes -0.5 s instead.
    seconds, micros = divmod(total_micros, 1_000_000)
    return Timestamp(seconds=seconds, nanos=micros * 1_000)


def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
    """
    Build the protobuf ``Step`` (whole part + millionths) for a numeric step. Example:
    >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)

    Args:
        number: the step expressed as a plain number.
        raise_on_step_precision_loss: when True, a step that does not fit in
            6 decimal places raises instead of being silently truncated.

    Returns: Step protobuf used in Neptune API.

    Raises:
        ValueError: if ``raise_on_step_precision_loss`` is set and scaling
            ``number`` by one million leaves a fractional remainder.
    """
    micros_per_unit = int(1e6)
    scaled = number * micros_per_unit
    total_micros: int = int(scaled)
    # NOTE(review): `scaled` is a float product, so the exactness check below is
    # subject to float rounding - confirm this is acceptable for all callers.
    if raise_on_step_precision_loss and scaled - total_micros != 0:
        raise ValueError(f"step must not use more than 6-decimal points, got: {number}")

    whole, micro = divmod(total_micros, micros_per_unit)
    return Step(whole=whole, micro=micro)
75 changes: 73 additions & 2 deletions tests/unit/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_invalid_project_name():
...


def test__metadata():
def test_metadata():
# given
project = "workspace/project"
api_token = "API_TOKEN"
Expand All @@ -92,4 +92,75 @@ def test__metadata():

# then
with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run:
run.log(step=1, timestamp=datetime.now(), fields={"test": 1})
run.log(
step=1,
timestamp=datetime.now(),
fields={
"int": 1,
"string": "test",
"float": 3.14,
"bool": True,
"datetime": datetime.now(),
},
metrics={
"metric": 1.0,
},
add_tags={
"tags": ["tag1"],
},
remove_tags={
"group_tags": ["tag2"],
},
)


def test_log_without_step():
    """Logging a field with a timestamp but no step must not raise."""
    # given
    run_id = str(uuid.uuid4())
    run_kwargs = {
        "project": "workspace/project",
        "api_token": "API_TOKEN",
        "family": run_id,
        "run_id": run_id,
    }

    # then
    with Run(**run_kwargs) as run:
        run.log(timestamp=datetime.now(), fields={"int": 1})


def test_log_step_float():
    """Logging with a fractional (float) step must not raise."""
    # given
    run_id = str(uuid.uuid4())
    run_kwargs = {
        "project": "workspace/project",
        "api_token": "API_TOKEN",
        "family": run_id,
        "run_id": run_id,
    }

    # then
    with Run(**run_kwargs) as run:
        run.log(step=3.14, timestamp=datetime.now(), fields={"int": 1})


def test_log_no_timestamp():
    """Logging with a step but no explicit timestamp must not raise."""
    # given
    run_id = str(uuid.uuid4())
    run_kwargs = {
        "project": "workspace/project",
        "api_token": "API_TOKEN",
        "family": run_id,
        "run_id": run_id,
    }

    # then
    with Run(**run_kwargs) as run:
        run.log(step=3.14, fields={"int": 1})

0 comments on commit 6286c45

Please sign in to comment.