-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
404 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
pre-commit | ||
pytest | ||
pytest-timeout | ||
freezegun |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("datetime_to_proto", "make_step", "mod_tags") | ||
|
||
from datetime import datetime | ||
|
||
from google.protobuf.timestamp_pb2 import Timestamp | ||
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( | ||
SET_OPERATION, | ||
ModifySet, | ||
Step, | ||
StringSet, | ||
Value, | ||
) | ||
|
||
|
||
def datetime_to_proto(dt: datetime) -> Timestamp: | ||
""" | ||
Converts datetime object to google.protobuf.Timestamp used in Neptune Ingest API | ||
Args: | ||
dt: datetime - timezone is irrelevant and will be lost to unix time | ||
Returns: protobuf Timestamp | ||
""" | ||
dt_ts = dt.timestamp() | ||
return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9)) | ||
|
||
|
||
def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step: | ||
""" | ||
Converts a number to protobuf Step value. Example: | ||
>>> assert make_step(7.654321, True) == Step(whole=7, micro=654321) | ||
Args: | ||
number: step expressed as number | ||
raise_on_step_precision_loss: inform converter whether it should silently drop precision and | ||
round down to 6 decimal places or raise an error. | ||
Returns: Step protobuf used in Neptune Ingest API. | ||
""" | ||
m = int(1e6) | ||
micro: int = int(number * m) | ||
if raise_on_step_precision_loss and number * m - micro != 0: | ||
raise ValueError(f"step must not use more than 6-decimal points, got: {number}") | ||
|
||
whole = micro // m | ||
micro = micro % m | ||
|
||
return Step(whole=whole, micro=micro) | ||
|
||
|
||
def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value: | ||
""" | ||
Converts python value to Ingest API Value protobuf message. | ||
Make sure not to mix float and int types in the same field, especially literals - this may | ||
cause data points to be dropped. | ||
""" | ||
if isinstance(value, Value): | ||
return value | ||
if isinstance(value, float): | ||
return Value(float64=value) | ||
elif isinstance(value, bool): | ||
return Value(bool=value) | ||
elif isinstance(value, int): | ||
return Value(int64=value) | ||
elif isinstance(value, str): | ||
return Value(string=value) | ||
elif isinstance(value, datetime): | ||
return Value(timestamp=datetime_to_proto(value)) | ||
elif isinstance(value, (list, set)): | ||
fv = Value(string_set=StringSet(values=value)) | ||
return fv | ||
else: | ||
raise ValueError(f"Unsupported ingest field value type: {type(value)}") | ||
|
||
|
||
def mod_tags(add: list[str] | set[str] | None = None, remove: list[str] | set[str] | None = None) -> ModifySet: | ||
""" | ||
Shorthand to apply string set modifications. Example: | ||
>>> from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import * | ||
... | ||
... update = UpdateRunSnapshot( | ||
... modify_sets={ | ||
... "sys/tags": mod_tags(add=["tag1", "tag2"], remove=["tag0"]) | ||
... }) | ||
Args: | ||
add: list of tags to be added | ||
remove: list tags to be removed | ||
Returns: ModifySet with changes | ||
""" | ||
mod_set = ModifySet() | ||
if add is not None: | ||
for tag in add: | ||
mod_set.string.values[tag] = SET_OPERATION.ADD | ||
if remove is not None: | ||
for tag in remove: | ||
mod_set.string.values[tag] = SET_OPERATION.REMOVE | ||
return mod_set |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("MessageBuilder",) | ||
|
||
from datetime import datetime | ||
|
||
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import UpdateRunSnapshot | ||
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation | ||
|
||
from neptune_scale.api.helpers import ( | ||
datetime_to_proto, | ||
make_step, | ||
make_value, | ||
mod_tags, | ||
) | ||
|
||
|
||
class MessageBuilder: | ||
def __init__( | ||
self, | ||
*, | ||
common_message: RunOperation, | ||
step: int | float | None, | ||
timestamp: datetime, | ||
fields: dict[str, float | int | str], | ||
metrics: dict[str, float], | ||
add_tags: dict[str, list[str] | set[str]], | ||
remove_tags: dict[str, list[str] | set[str]], | ||
): | ||
# TODO: Warning instead of raise on step precision loss | ||
self._step = None if step is None else make_step(number=step, raise_on_step_precision_loss=True) | ||
self._timestamp = datetime_to_proto(timestamp) | ||
self._fields = fields | ||
self._metrics = metrics | ||
self._add_tags = add_tags | ||
self._remove_tags = remove_tags | ||
self._common_message = common_message | ||
|
||
self._ith: int = 0 | ||
|
||
def __iter__(self) -> MessageBuilder: | ||
return self | ||
|
||
def __next__(self) -> RunOperation: | ||
if self._ith == 0: | ||
self._ith += 1 | ||
|
||
operation = self._common_message | ||
update = UpdateRunSnapshot( | ||
step=self._step, | ||
timestamp=self._timestamp, | ||
assign={key: make_value(value) for key, value in self._fields.items()}, | ||
append={key: make_value(value) for key, value in self._metrics.items()}, | ||
modify_sets={key: mod_tags(add=add) for key, add in self._add_tags.items()}, | ||
) | ||
remove_tags = UpdateRunSnapshot( | ||
modify_sets={key: mod_tags(remove=remove) for key, remove in self._remove_tags.items()} | ||
) | ||
update.MergeFrom(remove_tags) | ||
operation.MergeFrom(RunOperation(update=update)) | ||
return operation | ||
raise StopIteration |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.