Skip to content

Commit

Permalink
Final
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Aug 9, 2024
1 parent 4cf332e commit c6e204f
Show file tree
Hide file tree
Showing 6 changed files with 430 additions and 945 deletions.
7 changes: 2 additions & 5 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
WithResources,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.metadata_splitters import (
MetadataSplitter,
NoSplitting,
)
from neptune_scale.core.metadata_splitter import MetadataSplitter
from neptune_scale.core.serialization import (
datetime_to_proto,
make_step,
Expand Down Expand Up @@ -245,7 +242,7 @@ def log(
verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set))
verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set))

splitter: MetadataSplitter = NoSplitting(
splitter: MetadataSplitter = MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
Expand Down
161 changes: 161 additions & 0 deletions src/neptune_scale/core/metadata_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import annotations

__all__ = ("MetadataSplitter",)

from datetime import datetime
from typing import (
Any,
Callable,
Iterator,
TypeVar,
)

from more_itertools import peekable
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
SET_OPERATION,
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.serialization import (
datetime_to_proto,
make_step,
make_value,
pb_key_size,
)

T = TypeVar("T", bound=Any)


class MetadataSplitter(Iterator[RunOperation]):
def __init__(
self,
*,
project: str,
run_id: str,
step: int | float | None,
timestamp: datetime,
fields: dict[str, float | bool | int | str | datetime | list | set],
metrics: dict[str, float],
add_tags: dict[str, list[str] | set[str]],
remove_tags: dict[str, list[str] | set[str]],
max_message_bytes_size: int = 1024 * 1024,
):
self._step = None if step is None else make_step(number=step)
self._timestamp = datetime_to_proto(timestamp)
self._project = project
self._run_id = run_id
self._fields = peekable(fields.items())
self._metrics = peekable(metrics.items())
self._add_tags = peekable(add_tags.items())
self._remove_tags = peekable(remove_tags.items())

self._max_update_bytes_size = (
max_message_bytes_size
- RunOperation(
project=self._project,
run_id=self._run_id,
update=UpdateRunSnapshot(step=self._step, timestamp=self._timestamp),
).ByteSize()
)

self._has_returned = False

def __iter__(self) -> MetadataSplitter:
self._has_returned = False
return self

def __next__(self) -> RunOperation:
size = 0
update = UpdateRunSnapshot(
step=self._step,
timestamp=self._timestamp,
assign={},
append={},
modify_sets={},
)

size = self.populate(
assets=self._fields,
update_producer=lambda key, value: update.assign[key].MergeFrom(value),
size=size,
)
size = self.populate(
assets=self._metrics,
update_producer=lambda key, value: update.append[key].MergeFrom(value),
size=size,
)
size = self.populate_tags(
update=update,
assets=self._add_tags,
operation=SET_OPERATION.ADD,
size=size,
)
_ = self.populate_tags(
update=update,
assets=self._remove_tags,
operation=SET_OPERATION.REMOVE,
size=size,
)

if not self._has_returned or update.assign or update.append or update.modify_sets:
self._has_returned = True
return RunOperation(project=self._project, run_id=self._run_id, update=update)
else:
raise StopIteration

def populate(
self,
assets: peekable[Any],
update_producer: Callable[[str, Value], None],
size: int,
) -> int:
while size < self._max_update_bytes_size:
try:
key, value = assets.peek()
except StopIteration:
break

proto_value = make_value(value)
new_size = size + pb_key_size(key) + proto_value.ByteSize() + 6

if new_size > self._max_update_bytes_size:
break

update_producer(key, proto_value)
size, _ = new_size, next(assets)

return size

def populate_tags(
self, update: UpdateRunSnapshot, assets: peekable[Any], operation: SET_OPERATION.ValueType, size: int
) -> int:
while size < self._max_update_bytes_size:
try:
key, values = assets.peek()
except StopIteration:
break

if not isinstance(values, peekable):
values = peekable(values)

is_full = False
new_size = size + pb_key_size(key) + 6
for value in values:
tag_size = pb_key_size(value) + 6
if new_size + tag_size > self._max_update_bytes_size:
values.prepend(value)
is_full = True
break

update.modify_sets[key].string.values[value] = operation
new_size += tag_size

size, _ = new_size, next(assets)

if is_full:
assets.prepend((key, list(values)))
break

return size
Loading

0 comments on commit c6e204f

Please sign in to comment.