Skip to content

Commit

Permalink
Allow step to be None.
Browse files Browse the repository at this point in the history
Validate if steps are increasing client-side for early error detection.
  • Loading branch information
kgodlewski committed Dec 2, 2024
1 parent 0c26195 commit edc51e4
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 50 deletions.
54 changes: 40 additions & 14 deletions src/neptune_scale/api/attribute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import itertools
import threading
import warnings
from datetime import datetime
from typing import (
Expand All @@ -16,6 +17,7 @@
cast,
)

from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing
from neptune_scale.sync.metadata_splitter import MetadataSplitter
from neptune_scale.sync.operations_queue import OperationsQueue

Expand Down Expand Up @@ -59,6 +61,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue)
self._run_id = run_id
self._operations_queue = operations_queue
self._attributes: Dict[str, Attribute] = {}
# Keep a mapping of path -> (last step, last value) to detect non-increasing steps
# at call site. The backend will detect this error as well, but it's more convenient for the user
# to get the error as soon as possible.
self._metric_state: Dict[str, Tuple[float, float]] = {}
self._lock = threading.RLock()

def __getitem__(self, key: str) -> "Attribute":
path = cleanup_path(key)
Expand All @@ -85,22 +92,41 @@ def log(
) -> None:
if timestamp is None:
timestamp = datetime.now()
elif isinstance(timestamp, float):
elif isinstance(timestamp, (float, int)):
timestamp = datetime.fromtimestamp(timestamp)

splitter: MetadataSplitter = MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
)
with self._lock:
self._verify_and_update_metrics_state(step, metrics)

# TODO: Move splitting into the worker process. Here we should just send messages as they are.
splitter: MetadataSplitter = MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
)

for operation, metadata_size in splitter:
self._operations_queue.enqueue(operation=operation, size=metadata_size, step=step)

def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[Dict[str, float]]) -> None:
"""Check if step in provided metrics is increasing, raise `NeptuneSeriesStepNonIncreasing` if not."""

for operation, metadata_size in splitter:
self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
if step is None or metrics is None:
return

for metric, value in metrics.items():
if (state := self._metric_state.get(metric)) is not None:
last_step, last_value = state
# Repeating a step is fine as long as the value does not change
if step < last_step or (step == last_step and value != last_value):
raise NeptuneSeriesStepNonIncreasing()

self._metric_state[metric] = (step, value)


class Attribute:
Expand Down Expand Up @@ -128,7 +154,7 @@ def append(
self,
value: Union[Dict[str, Any], float],
*,
step: Union[float, int],
step: Optional[Union[float, int]] = None,
timestamp: Optional[Union[float, datetime]] = None,
wait: bool = False,
**kwargs: Any,
Expand Down
2 changes: 1 addition & 1 deletion src/neptune_scale/api/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __setitem__(self, key: str, value: Any) -> None:
def log_metrics(
self,
data: Dict[str, Union[float, int]],
step: Union[float, int],
step: Optional[Union[float, int]],
*,
timestamp: Optional[datetime] = None,
) -> None:
Expand Down
25 changes: 21 additions & 4 deletions src/neptune_scale/sync/aggregating_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get(self) -> BatchedOperations:
start = time.monotonic()

batch_operations: list[RunOperation] = []
last_batch_key: Optional[float] = None
current_batch_step: Optional[float] = None
batch_sequence_id: Optional[int] = None
batch_timestamp: Optional[float] = None

Expand All @@ -97,7 +97,7 @@ def get(self) -> BatchedOperations:
new_operation = RunOperation()
new_operation.ParseFromString(element.operation)
batch_operations.append(new_operation)
last_batch_key = element.batch_key
current_batch_step = element.step
batch_bytes += len(element.operation)
else:
if not element.is_batchable:
Expand All @@ -112,9 +112,26 @@ def get(self) -> BatchedOperations:

new_operation = RunOperation()
new_operation.ParseFromString(element.operation)
if element.batch_key != last_batch_key:

# This is where we decide if we need to wrap up the current UpdateSnapshot and start a new one.
# This happens if the step changes, but also if it is None.
# On None, the backend will assign the next available step. This is why we cannot merge here,
# especially considering metrics, since we would overwrite them:
#
# log metric1=1.0, step=None
# log metric1=1.2, step=None
#
# After merging by step, we would end up with a single value (the most recent one).
#
# TODO: we could potentially keep merging until we encounter a metric already seen in this batch.
# Something to optimize in the future. Given the metrics:
# m1, m2, m3, m4, m1, m2, m3, ...
# we could batch up to m4 and close the batch when encountering m1, as long as steps are None
# We could also keep batching if there are no metrics in a given operation, although this would
# not be a common case.
if element.step is None or element.step != current_batch_step:
batch_operations.append(new_operation)
last_batch_key = element.batch_key
current_batch_step = element.step
else:
merge_run_operation(batch_operations[-1], new_operation)
batch_bytes += element.metadata_size
Expand Down
4 changes: 2 additions & 2 deletions src/neptune_scale/sync/operations_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]:
with self._lock:
return self._last_timestamp

def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None:
def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, step: Optional[float] = None) -> None:
try:
is_metadata_update = operation.HasField("update")
serialized_operation = operation.SerializeToString()
Expand All @@ -75,7 +75,7 @@ def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: O
operation=serialized_operation,
metadata_size=size,
is_batchable=is_metadata_update,
batch_key=key,
step=step,
),
block=True,
timeout=None,
Expand Down
4 changes: 2 additions & 2 deletions src/neptune_scale/sync/queue_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ class SingleOperation(NamedTuple):
is_batchable: bool
# Size of the metadata in the operation (without project, family, run_id etc.)
metadata_size: Optional[int]
# Update metadata key
batch_key: Optional[float]
# Step provided by the user
step: Optional[float]
2 changes: 2 additions & 0 deletions src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:

def work(self) -> None:
try:
# TODO: is there a point in serializing the data on AggregatingQueue? It does not move between processes,
# so we could just pass around instances of RunOperation
while (operation := self.get_next()) is not None:
sequence_id, timestamp, data = operation

Expand Down
Loading

0 comments on commit edc51e4

Please sign in to comment.