diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py index 734bf3e..e3dbbf7 100644 --- a/src/neptune_scale/api/attribute.py +++ b/src/neptune_scale/api/attribute.py @@ -1,5 +1,6 @@ import functools import itertools +import threading import warnings from collections.abc import ( Collection, @@ -14,6 +15,7 @@ cast, ) +from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing from neptune_scale.sync.metadata_splitter import MetadataSplitter from neptune_scale.sync.operations_queue import OperationsQueue @@ -59,6 +61,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) self._run_id = run_id self._operations_queue = operations_queue self._attributes: dict[str, Attribute] = {} + # Keep a dict of path -> (last step, last value) entries to detect non-increasing steps + # at the call site. The backend will detect this error as well, but it's more convenient for the user + # to get the error as soon as possible. + self._metric_state: dict[str, tuple[float, float]] = {} + self._lock = threading.RLock() def __getitem__(self, path: str) -> "Attribute": path = cleanup_path(path) @@ -85,22 +92,45 @@ def log( ) -> None: if timestamp is None: timestamp = datetime.now() - elif isinstance(timestamp, float): + elif isinstance(timestamp, (float, int)): timestamp = datetime.fromtimestamp(timestamp) - splitter: MetadataSplitter = MetadataSplitter( - project=self._project, - run_id=self._run_id, - step=step, - timestamp=timestamp, - configs=configs, - metrics=metrics, - add_tags=tags_add, - remove_tags=tags_remove, + # MetadataSplitter is an iterator, so gather everything into a list instead of iterating over + # it in the critical section, to avoid holding the lock for too long. + # TODO: Move splitting into the worker process. Here we should just send messages as they are. 
+ chunks = list( + MetadataSplitter( + project=self._project, + run_id=self._run_id, + step=step, + timestamp=timestamp, + configs=configs, + metrics=metrics, + add_tags=tags_add, + remove_tags=tags_remove, + ) ) - for operation, metadata_size in splitter: - self._operations_queue.enqueue(operation=operation, size=metadata_size) + with self._lock: + self._verify_and_update_metrics_state(step, metrics) + + for operation, metadata_size in chunks: + self._operations_queue.enqueue(operation=operation, size=metadata_size) + + def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[dict[str, float]]) -> None: + """Check that the step for each provided metric is not decreasing; raise `NeptuneSeriesStepNonIncreasing` on a lower step, or on a repeated step with a changed value.""" + + if step is None or metrics is None: + return + + for metric, value in metrics.items(): + if (state := self._metric_state.get(metric)) is not None: + last_step, last_value = state + # Repeating a step is fine as long as the value does not change + if step < last_step or (step == last_step and value != last_value): + raise NeptuneSeriesStepNonIncreasing() + + self._metric_state[metric] = (step, value) class Attribute: diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py index 098be7b..8e655ad 100644 --- a/src/neptune_scale/sync/sync_process.py +++ b/src/neptune_scale/sync/sync_process.py @@ -429,6 +429,8 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]: def work(self) -> None: try: + # TODO: is there a point in serializing the data on AggregatingQueue? 
It does not move between processes, + # so we could just pass around instances of RunOperation while (operation := self.get_next()) is not None: sequence_id, timestamp, data = operation diff --git a/tests/unit/test_attribute.py b/tests/unit/test_attribute.py index 481651c..072aa84 100644 --- a/tests/unit/test_attribute.py +++ b/tests/unit/test_attribute.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from unittest.mock import Mock @@ -8,13 +9,15 @@ ) from neptune_scale.api.attribute import cleanup_path +from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing from neptune_scale.legacy import Run @fixture def run(api_token): run = Run(project="dummy/project", run_id="dummy-run", mode="disabled", api_token=api_token) - run._attr_store.log = Mock() + # Mock log to be able to assert calls, but also proxy to the actual method so it does its job + run._attr_store.log = Mock(side_effect=run._attr_store.log) with run: yield run @@ -67,11 +70,30 @@ def test_tags(run, store): def test_series(run, store): - run["sys/series"].append(1, step=1, timestamp=10) - store.log.assert_called_with(metrics={"sys/series": 1}, step=1, timestamp=10) + now = time.time() + run["my/series"].append(1, step=1, timestamp=now) + store.log.assert_called_with(metrics={"my/series": 1}, step=1, timestamp=now) - run["sys/series"].append({"foo": 1, "bar": 2}, step=2) - store.log.assert_called_with(metrics={"sys/series/foo": 1, "sys/series/bar": 2}, step=2, timestamp=None) + run["my/series"].append({"foo": 1, "bar": 2}, step=2) + store.log.assert_called_with(metrics={"my/series/foo": 1, "my/series/bar": 2}, step=2, timestamp=None) + + +def test_error_on_non_increasing_step(run): + run["series"].append(1, step=2) + + # Step lower than previous + with pytest.raises(NeptuneSeriesStepNonIncreasing): + run["series"].append(2, step=1) + + # Equal to previous, but different value + with pytest.raises(NeptuneSeriesStepNonIncreasing): + run["series"].append(3, step=2) + + # Equal to 
previous, same value -> should pass + run["series"].append(1, step=2) + + # None should pass, as it means auto-increment + run["series"].append(4, step=None) @pytest.mark.parametrize(