Skip to content

Commit

Permalink
Merge pull request #104 from neptune-ai/kg/raise-on-non-increasing-step
Browse files Browse the repository at this point in the history
Raise on logging non-increasing series step
  • Loading branch information
kgodlewski authored Jan 13, 2025
2 parents 3a05c3a + e9855e1 commit 4c2e36e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 17 deletions.
54 changes: 42 additions & 12 deletions src/neptune_scale/api/attribute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import itertools
import threading
import warnings
from collections.abc import (
Collection,
Expand All @@ -14,6 +15,7 @@
cast,
)

from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing
from neptune_scale.sync.metadata_splitter import MetadataSplitter
from neptune_scale.sync.operations_queue import OperationsQueue

Expand Down Expand Up @@ -59,6 +61,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue)
self._run_id = run_id
self._operations_queue = operations_queue
self._attributes: dict[str, Attribute] = {}
# Map each metric path to its (last step, last value) pair, so that non-increasing steps
# can be detected at the call site. The backend detects this error as well, but it is more
# convenient for the user to get the error as soon as possible.
self._metric_state: dict[str, tuple[float, float]] = {}
self._lock = threading.RLock()

def __getitem__(self, path: str) -> "Attribute":
path = cleanup_path(path)
Expand All @@ -85,22 +92,45 @@ def log(
) -> None:
if timestamp is None:
timestamp = datetime.now()
elif isinstance(timestamp, float):
elif isinstance(timestamp, (float, int)):
timestamp = datetime.fromtimestamp(timestamp)

splitter: MetadataSplitter = MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
# MetadataSplitter is an iterator, so gather everything into a list instead of iterating over
# it in the critical section, to avoid holding the lock for too long.
# TODO: Move splitting into the worker process. Here we should just send messages as they are.
chunks = list(
MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
)
)

for operation, metadata_size in splitter:
self._operations_queue.enqueue(operation=operation, size=metadata_size)
with self._lock:
self._verify_and_update_metrics_state(step, metrics)

for operation, metadata_size in chunks:
self._operations_queue.enqueue(operation=operation, size=metadata_size)

def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[dict[str, float]]) -> None:
    """Check that `step` is not going backwards for any metric, then record the new state.

    For every metric path in `metrics`, the provided `step` is compared with the
    last `(step, value)` pair stored in `self._metric_state`. A lower step is
    rejected. An equal step is accepted only if the value is unchanged.

    The whole batch is validated before any state is written, so a rejected call
    leaves `self._metric_state` exactly as it was. `log()` does not enqueue a
    rejected batch, so recording part of it would poison later checks with
    values that were never sent.

    Args:
        step: Step shared by all metrics in this call. When `None`, nothing is
            checked and nothing is recorded.
        metrics: Mapping of metric path to value. When `None`, nothing is
            checked and nothing is recorded.

    Raises:
        NeptuneSeriesStepNonIncreasing: If, for any metric, `step` is lower than
            the last recorded step, or equal to it with a different value.
    """

    if step is None or metrics is None:
        return

    # Pass 1: validate only. Do not mutate state until every metric is accepted.
    for metric, value in metrics.items():
        state = self._metric_state.get(metric)
        if state is None:
            # First time this metric is seen, so there is nothing to compare against.
            continue

        last_step, last_value = state
        # Repeating a step is fine as long as the value does not change
        if step < last_step or (step == last_step and value != last_value):
            raise NeptuneSeriesStepNonIncreasing()

    # Pass 2: the batch is valid, so commit it.
    self._metric_state.update((metric, (step, value)) for metric, value in metrics.items())


class Attribute:
Expand Down
2 changes: 2 additions & 0 deletions src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:

def work(self) -> None:
try:
# TODO: is there a point in serializing the data on AggregatingQueue? It does not move between processes,
# so we could just pass around instances of RunOperation
while (operation := self.get_next()) is not None:
sequence_id, timestamp, data = operation

Expand Down
32 changes: 27 additions & 5 deletions tests/unit/test_attribute.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from datetime import datetime
from unittest.mock import Mock

Expand All @@ -8,13 +9,15 @@
)

from neptune_scale.api.attribute import cleanup_path
from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing
from neptune_scale.legacy import Run


@fixture
def run(api_token):
run = Run(project="dummy/project", run_id="dummy-run", mode="disabled", api_token=api_token)
run._attr_store.log = Mock()
# Mock log to be able to assert calls, but also proxy to the actual method so it does its job
run._attr_store.log = Mock(side_effect=run._attr_store.log)
with run:
yield run

Expand Down Expand Up @@ -67,11 +70,30 @@ def test_tags(run, store):


def test_series(run, store):
run["sys/series"].append(1, step=1, timestamp=10)
store.log.assert_called_with(metrics={"sys/series": 1}, step=1, timestamp=10)
now = time.time()
run["my/series"].append(1, step=1, timestamp=now)
store.log.assert_called_with(metrics={"my/series": 1}, step=1, timestamp=now)

run["sys/series"].append({"foo": 1, "bar": 2}, step=2)
store.log.assert_called_with(metrics={"sys/series/foo": 1, "sys/series/bar": 2}, step=2, timestamp=None)
run["my/series"].append({"foo": 1, "bar": 2}, step=2)
store.log.assert_called_with(metrics={"my/series/foo": 1, "my/series/bar": 2}, step=2, timestamp=None)


def test_error_on_non_increasing_step(run):
    """Appending to a series must fail when the step moves backwards or is reused with a new value."""
    path = "series"

    # Establish the baseline: value 1 at step 2
    run[path].append(1, step=2)

    rejected_appends = (
        (2, 1),  # step lower than the previous one
        (3, 2),  # step equal to the previous one, but a different value
    )
    for value, step in rejected_appends:
        with pytest.raises(NeptuneSeriesStepNonIncreasing):
            run[path].append(value, step=step)

    # Repeating the previous step with the very same value is accepted
    run[path].append(1, step=2)

    # A missing step is accepted as well (it means auto-increment)
    run[path].append(4, step=None)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 4c2e36e

Please sign in to comment.