diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b421be2a..73ef26dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: args: [ --config-file, pyproject.toml ] pass_filenames: false additional_dependencies: - - neptune-api==0.7.0b + - neptune-api - more-itertools - backoff default_language_version: diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py index ea433137..e3dbbf7d 100644 --- a/src/neptune_scale/api/attribute.py +++ b/src/neptune_scale/api/attribute.py @@ -1,5 +1,6 @@ import functools import itertools +import threading import warnings from collections.abc import ( Collection, @@ -14,6 +15,7 @@ cast, ) +from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing from neptune_scale.sync.metadata_splitter import MetadataSplitter from neptune_scale.sync.operations_queue import OperationsQueue @@ -59,6 +61,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) self._run_id = run_id self._operations_queue = operations_queue self._attributes: dict[str, Attribute] = {} + # Keep a list of path -> (last step, last value) mappings to detect non-increasing steps + # at call site. The backend will detect this error as well, but it's more convenient for the user + # to get the error as soon as possible. + self._metric_state: dict[str, tuple[float, float]] = {} + self._lock = threading.RLock() def __getitem__(self, path: str) -> "Attribute": path = cleanup_path(path) @@ -85,22 +92,45 @@ def log( ) -> None: if timestamp is None: timestamp = datetime.now() - elif isinstance(timestamp, float): + elif isinstance(timestamp, (float, int)): timestamp = datetime.fromtimestamp(timestamp) - splitter: MetadataSplitter = MetadataSplitter( - project=self._project, - run_id=self._run_id, - step=step, - timestamp=timestamp, - configs=configs, - metrics=metrics, - add_tags=tags_add, - remove_tags=tags_remove, + # MetadataSplitter is an iterator, so gather everything into a list instead of iterating over + # it in the critical section, to avoid holding the lock for too long. + # TODO: Move splitting into the worker process. Here we should just send messages as they are. + chunks = list( + MetadataSplitter( + project=self._project, + run_id=self._run_id, + step=step, + timestamp=timestamp, + configs=configs, + metrics=metrics, + add_tags=tags_add, + remove_tags=tags_remove, + ) ) - for operation, metadata_size in splitter: - self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step) + with self._lock: + self._verify_and_update_metrics_state(step, metrics) + + for operation, metadata_size in chunks: + self._operations_queue.enqueue(operation=operation, size=metadata_size) + + def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[dict[str, float]]) -> None: + """Check if step in provided metrics is increasing, raise `NeptuneSeriesStepNonIncreasing` if not.""" + + if step is None or metrics is None: + return + + for metric, value in metrics.items(): + if (state := self._metric_state.get(metric)) is not None: + last_step, last_value = state + # Repeating a step is fine as long as the value does not change + if step < last_step or (step == last_step and value != last_value): + raise NeptuneSeriesStepNonIncreasing() + + self._metric_state[metric] = (step, value) class Attribute: diff --git a/src/neptune_scale/net/api_client.py b/src/neptune_scale/net/api_client.py index a6cf90c2..1bbcccfd 100644 --- a/src/neptune_scale/net/api_client.py +++ b/src/neptune_scale/net/api_client.py @@ -27,9 +27,11 @@ from typing import ( Any, Literal, + cast, ) import httpx +import neptune_retrieval_api.client from httpx import Timeout from neptune_api import ( AuthenticatedClient, @@ -64,6 +66,11 @@ from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus from neptune_api.types import Response +from neptune_retrieval_api.api.default import search_leaderboard_entries_proto +from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO +from neptune_retrieval_api.proto.neptune_pb.api.v1.model.leaderboard_entries_pb2 import ( + ProtoLeaderboardEntriesSearchResultDTO, +) from neptune_scale.exceptions import ( NeptuneConnectionLostError, @@ -129,6 +136,11 @@ def submit(self, operation: RunOperation, family: str) -> Response[SubmitRespons @abc.abstractmethod def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ... + @abc.abstractmethod + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: ... + class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: @@ -141,6 +153,9 @@ def __init__(self, api_token: str) -> None: self.backend = create_auth_api_client( credentials=credentials, config=config, token_refreshing_urls=token_urls, verify_ssl=verify_ssl ) + # This is required only to silence mypy. The two client objects are compatible, because they're + # generated by swagger codegen. + self.retrieval_backend = cast(neptune_retrieval_api.client.AuthenticatedClient, self.backend) logger.debug("Connected to Neptune API") def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]: @@ -153,6 +168,15 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]), ) + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: + resp = search_leaderboard_entries_proto.sync_detailed( + client=self.retrieval_backend, project_identifier=project_id, type=["run"], body=body + ) + result = ProtoLeaderboardEntriesSearchResultDTO.FromString(resp.content) + return result + def close(self) -> None: logger.debug("Closing API client") self.backend.__exit__() @@ -181,6 +205,11 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ ) return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={}) + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: + return ProtoLeaderboardEntriesSearchResultDTO() + def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient: if mode == "disabled": diff --git a/src/neptune_scale/net/projects.py b/src/neptune_scale/net/projects.py index ce698ada..4a97de2d 100644 --- a/src/neptune_scale/net/projects.py +++ b/src/neptune_scale/net/projects.py @@ -1,4 +1,3 @@ -import os import re from enum import Enum from json import JSONDecodeError @@ -11,7 +10,6 @@ import httpx from neptune_scale.exceptions import ( - NeptuneApiTokenNotProvided, NeptuneBadRequestError, NeptuneProjectAlreadyExists, ) @@ -19,7 +17,7 @@ HostedApiClient, with_api_errors_handling, ) -from neptune_scale.util.envs import API_TOKEN_ENV_NAME +from neptune_scale.sync.util import ensure_api_token PROJECTS_PATH_BASE = "/api/backend/v1/projects" @@ -33,14 +31,6 @@ class ProjectVisibility(Enum): ORGANIZATION_NOT_FOUND_RE = re.compile(r"Organization .* not found") -def _get_api_token(api_token: Optional[str]) -> str: - api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) - if api_token is None: - raise NeptuneApiTokenNotProvided() - - return api_token - - @with_api_errors_handling def create_project( workspace: str, @@ -52,9 +42,7 @@ def create_project( fail_if_exists: bool = False, api_token: Optional[str] = None, ) -> None: - api_token = _get_api_token(api_token) - - client = HostedApiClient(api_token=api_token) + client = HostedApiClient(api_token=ensure_api_token(api_token)) visibility = ProjectVisibility(visibility) body = { @@ -92,7 +80,7 @@ def _safe_json(response: httpx.Response) -> Any: def get_project_list(*, api_token: Optional[str] = None) -> list[dict]: - client = HostedApiClient(api_token=_get_api_token(api_token)) + client = HostedApiClient(api_token=ensure_api_token(api_token)) params = { "userRelation": "viewerOrHigher", diff --git a/src/neptune_scale/net/runs.py b/src/neptune_scale/net/runs.py new file mode 100644 index 00000000..bf14fe86 --- /dev/null +++ b/src/neptune_scale/net/runs.py @@ -0,0 +1,29 @@ +from typing import Optional + +from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO + +from neptune_scale.exceptions import NeptuneScaleError +from neptune_scale.net.api_client import HostedApiClient +from neptune_scale.net.util import escape_nql_criterion +from neptune_scale.sync.util import ensure_api_token + + +def run_exists(project: str, run_id: str, api_token: Optional[str] = None) -> bool: + """Query the backend for the existence of a Run with the given ID. + + Returns True if the Run exists, False otherwise. + """ + + client = HostedApiClient(api_token=ensure_api_token(api_token)) + body = SearchLeaderboardEntriesParamsDTO.from_dict( + { + "query": {"query": f'`sys/custom_run_id`:string = "{escape_nql_criterion(run_id)}"'}, + } + ) + + try: + result = client.search_entries(project, body) + except Exception as e: + raise NeptuneScaleError(reason=e) + + return bool(result.entries) diff --git a/src/neptune_scale/net/util.py b/src/neptune_scale/net/util.py new file mode 100644 index 00000000..cbc25d2e --- /dev/null +++ b/src/neptune_scale/net/util.py @@ -0,0 +1,6 @@ +def escape_nql_criterion(criterion: str) -> str: + """ + Escape backslash and (double-)quotes in the string, to match what the NQL engine expects. + """ + + return criterion.replace("\\", r"\\").replace('"', r"\"") diff --git a/src/neptune_scale/sync/aggregating_queue.py b/src/neptune_scale/sync/aggregating_queue.py index 8e3fc40e..0bf47226 100644 --- a/src/neptune_scale/sync/aggregating_queue.py +++ b/src/neptune_scale/sync/aggregating_queue.py @@ -71,7 +71,7 @@ def commit(self) -> None: def get(self) -> BatchedOperations: start = time.monotonic() - batch_operations: dict[Optional[float], RunOperation] = {} + batch_operations: list[RunOperation] = [] batch_sequence_id: Optional[int] = None batch_timestamp: Optional[float] = None @@ -95,7 +95,7 @@ def get(self) -> BatchedOperations: if not batch_operations: new_operation = RunOperation() new_operation.ParseFromString(element.operation) - batch_operations[element.batch_key] = new_operation + batch_operations.append(new_operation) batch_bytes += len(element.operation) else: if not element.is_batchable: @@ -110,10 +110,7 @@ def get(self) -> BatchedOperations: new_operation = RunOperation() new_operation.ParseFromString(element.operation) - if element.batch_key not in batch_operations: - batch_operations[element.batch_key] = new_operation - else: - merge_run_operation(batch_operations[element.batch_key], new_operation) + batch_operations.append(new_operation) batch_bytes += element.metadata_size batch_sequence_id = element.sequence_id @@ -157,54 +154,25 @@ def get(self) -> BatchedOperations: ) -def create_run_batch(operations: dict[Optional[float], RunOperation]) -> RunOperation: +def create_run_batch(operations: list[RunOperation]) -> RunOperation: + if not operations: + raise Empty + if len(operations) == 1: - return next(iter(operations.values())) + return operations[0] - batch = None - for _, operation in sorted(operations.items(), key=lambda x: (x[0] is not None, x[0])): - if batch is None: - batch = RunOperation() - batch.project = operation.project - batch.run_id = operation.run_id - batch.create_missing_project = operation.create_missing_project - batch.api_key = operation.api_key + head = operations[0] + batch = RunOperation() + batch.project = head.project + batch.run_id = head.run_id + batch.create_missing_project = head.create_missing_project + batch.api_key = head.api_key + for operation in operations: operation_type = operation.WhichOneof("operation") if operation_type == "update": batch.update_batch.snapshots.append(operation.update) else: raise ValueError("Cannot batch operation of type %s", operation_type) - if batch is None: - raise Empty return batch - - -def merge_run_operation(batch: RunOperation, operation: RunOperation) -> None: - """ - Merge the `operation` into `batch`, taking into account the special case of `modify_sets`. - - Protobuf merges existing map keys by simply overwriting values, instead of calling - `MergeFrom` on the existing value, eg: A['foo'] = B['foo']. - - We want this instead: - - batch = {'sys/tags': 'string': { 'values': {'foo': ADD}}} - operation = {'sys/tags': 'string': { 'values': {'bar': ADD}}} - result = {'sys/tags': 'string': { 'values': {'foo': ADD, 'bar': ADD}}} - - If we called `batch.MergeFrom(operation)` we would get an overwritten value: - result = {'sys/tags': 'string': { 'values': {'bar': ADD}}} - - This function ensures that the `modify_sets` are merged correctly, leaving the default - behaviour for all other fields. - """ - - modify_sets = operation.update.modify_sets - operation.update.ClearField("modify_sets") - - batch.MergeFrom(operation) - - for k, v in modify_sets.items(): - batch.update.modify_sets[k].MergeFrom(v) diff --git a/src/neptune_scale/sync/operations_queue.py b/src/neptune_scale/sync/operations_queue.py index 0069b472..d518da11 100644 --- a/src/neptune_scale/sync/operations_queue.py +++ b/src/neptune_scale/sync/operations_queue.py @@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]: with self._lock: return self._last_timestamp - def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None: + def enqueue(self, *, operation: RunOperation, size: Optional[int] = None) -> None: try: is_metadata_update = operation.HasField("update") serialized_operation = operation.SerializeToString() @@ -75,7 +75,6 @@ def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: O operation=serialized_operation, metadata_size=size, is_batchable=is_metadata_update, - batch_key=key, ), block=True, timeout=None, diff --git a/src/neptune_scale/sync/queue_element.py b/src/neptune_scale/sync/queue_element.py index 1c37ff11..521f89e4 100644 --- a/src/neptune_scale/sync/queue_element.py +++ b/src/neptune_scale/sync/queue_element.py @@ -26,5 +26,3 @@ class SingleOperation(NamedTuple): is_batchable: bool # Size of the metadata in the operation (without project, family, run_id etc.) metadata_size: Optional[int] - # Update metadata key - batch_key: Optional[float] diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py index 4329a2aa..61f22053 100644 --- a/src/neptune_scale/sync/sync_process.py +++ b/src/neptune_scale/sync/sync_process.py @@ -412,6 +412,8 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]: def work(self) -> None: try: + # TODO: is there a point in serializing the data on AggregatingQueue? It does not move between processes, + # so we could just pass around instances of RunOperation while (operation := self.get_next()) is not None: sequence_id, timestamp, data = operation diff --git a/src/neptune_scale/sync/util.py b/src/neptune_scale/sync/util.py index 60fe4b0b..2d5ecf96 100644 --- a/src/neptune_scale/sync/util.py +++ b/src/neptune_scale/sync/util.py @@ -1,4 +1,9 @@ +import os import signal +from typing import Optional + +from neptune_scale.exceptions import NeptuneApiTokenNotProvided +from neptune_scale.util.envs import API_TOKEN_ENV_NAME def safe_signal_name(signum: int) -> str: @@ -8,3 +13,13 @@ def safe_signal_name(signum: int) -> str: signame = str(signum) return signame + + +def ensure_api_token(api_token: Optional[str]) -> str: + """Ensure the API token is provided via either explicit argument, or env variable.""" + + api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) + if api_token is None: + raise NeptuneApiTokenNotProvided() + + return api_token diff --git a/tests/e2e/test_net.py b/tests/e2e/test_net.py new file mode 100644 index 00000000..263def20 --- /dev/null +++ b/tests/e2e/test_net.py @@ -0,0 +1,10 @@ +import os + +from neptune_scale.net.runs import run_exists + +NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT") + + +def test_run_exists_true(run): + assert run_exists(run._project, run._run_id) + assert not run_exists(run._project, "nonexistent_run_id") diff --git a/tests/unit/test_aggregating_queue.py b/tests/unit/test_aggregating_queue.py index de9394a0..0f736f57 100644 --- a/tests/unit/test_aggregating_queue.py +++ b/tests/unit/test_aggregating_queue.py @@ -6,6 +6,7 @@ import pytest from freezegun import freeze_time +from google.protobuf.timestamp_pb2 import Timestamp from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( Step, @@ -32,7 +33,6 @@ def test__simple(): operation=operation.SerializeToString(), is_batchable=True, metadata_size=update.ByteSize(), - batch_key=None, ) # and @@ -60,7 +60,6 @@ def test__max_size_exceeded(): operation=operation1.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -68,7 +67,6 @@ def test__max_size_exceeded(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=None, ) # and @@ -108,7 +106,6 @@ def test__batch_size_limit(): operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -116,7 +113,6 @@ def test__batch_size_limit(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=None, ) # and @@ -148,7 +144,6 @@ def test__batching(): operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -156,7 +151,6 @@ def test__batching(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=None, ) # and @@ -179,7 +173,8 @@ def test__batching(): assert batch.project == "project" assert batch.run_id == "run_id" - assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"]) + assert all(k in batch.update_batch.snapshots[0].assign for k in ["aa0", "aa1"]) + assert all(k in batch.update_batch.snapshots[1].assign for k in ["bb0", "bb1"]) @freeze_time("2024-09-01") @@ -189,21 +184,21 @@ def test__queue_element_size_limit_with_different_steps(): update2 = UpdateRunSnapshot(step=Step(whole=2), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) operation1 = RunOperation(update=update1) operation2 = RunOperation(update=update2) + timestamp1 = time.process_time() + timestamp2 = timestamp1 + 1 element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp1, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, - timestamp=time.process_time(), + timestamp=timestamp2, operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=2.0, ) # and @@ -235,7 +230,6 @@ def test__not_merge_two_run_creation(): operation=operation1.SerializeToString(), is_batchable=False, metadata_size=0, - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -243,7 +237,6 @@ def test__not_merge_two_run_creation(): operation=operation2.SerializeToString(), is_batchable=False, metadata_size=0, - batch_key=None, ) # and @@ -301,7 +294,6 @@ def test__not_merge_run_creation_with_metadata_update(): operation=operation1.SerializeToString(), is_batchable=False, metadata_size=0, - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -309,7 +301,6 @@ def test__not_merge_run_creation_with_metadata_update(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update.ByteSize(), - batch_key=None, ) # and @@ -351,7 +342,7 @@ def test__not_merge_run_creation_with_metadata_update(): @freeze_time("2024-09-01") -def test__merge_same_key(): +def test__batch_same_key(): # given update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) update2 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) @@ -361,21 +352,20 @@ def test__merge_same_key(): operation2 = RunOperation(update=update2, project="project", run_id="run_id") # and + timestamp0 = time.process_time() element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp0, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, - timestamp=time.process_time(), + timestamp=timestamp0, operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=1.0, ) # and @@ -390,7 +380,7 @@ def test__merge_same_key(): # then assert result.sequence_id == 2 - assert result.timestamp == element2.timestamp + assert result.timestamp == timestamp0 # and batch = RunOperation() @@ -398,12 +388,14 @@ def test__merge_same_key(): assert batch.project == "project" assert batch.run_id == "run_id" - assert batch.update.step == Step(whole=1, micro=0) - assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"]) + assert batch.update_batch.snapshots[0].step == Step(whole=1, micro=0) + assert batch.update_batch.snapshots[1].step == Step(whole=1, micro=0) + assert all(k in batch.update_batch.snapshots[0].assign for k in ["aa0", "aa1"]) + assert all(k in batch.update_batch.snapshots[1].assign for k in ["bb0", "bb1"]) @freeze_time("2024-09-01") -def test__merge_two_different_steps(): +def test__batch_two_different_steps(): # given update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) update2 = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) @@ -413,21 +405,21 @@ def test__merge_two_different_steps(): operation2 = RunOperation(update=update2, project="project", run_id="run_id") # and + timestamp1 = time.process_time() + timestamp2 = timestamp1 + 1 element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp1, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, - timestamp=time.process_time(), + timestamp=timestamp2, operation=operation2.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=2.0, ) # and @@ -454,7 +446,7 @@ def test__merge_two_different_steps(): @freeze_time("2024-09-01") -def test__merge_step_with_none(): +def test__batch_step_with_none(): # given update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) @@ -464,13 +456,13 @@ def test__merge_step_with_none(): operation2 = RunOperation(update=update2, project="project", run_id="run_id") # and + timestamp1 = time.process_time() element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp1, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, @@ -478,7 +470,6 @@ def test__merge_step_with_none(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=None, ) # and @@ -501,16 +492,33 @@ def test__merge_step_with_none(): assert batch.project == "project" assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update2, update1] # None is always first + assert batch.update_batch.snapshots == [update1, update2] @freeze_time("2024-09-01") -def test__merge_two_steps_two_metrics(): +def test__batch_two_steps_two_metrics(): # given - update1a = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"aa": Value(int64=10)}) - update2a = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"aa": Value(int64=20)}) - update1b = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"bb": Value(int64=100)}) - update2b = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"bb": Value(int64=200)}) + timestamp0 = int(time.process_time()) + update1a = UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 1, nanos=0), + assign={"aa": Value(int64=10)}, + ) + update2a = UpdateRunSnapshot( + step=Step(whole=2, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 2, nanos=0), + assign={"aa": Value(int64=20)}, + ) + update1b = UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 3, nanos=0), + assign={"bb": Value(int64=100)}, + ) + update2b = UpdateRunSnapshot( + step=Step(whole=2, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 4, nanos=0), + assign={"bb": Value(int64=200)}, + ) # and operations = [ @@ -522,13 +530,12 @@ def test__merge_two_steps_two_metrics(): elements = [ SingleOperation( sequence_id=sequence_id, - timestamp=time.process_time(), + timestamp=timestamp0 + sequence_id, operation=operation.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=batch_key, ) - for sequence_id, batch_key, operation in [ + for sequence_id, step, operation in [ (1, 1.0, operations[0]), (2, 2.0, operations[1]), (3, 1.0, operations[2]), @@ -554,13 +561,6 @@ def test__merge_two_steps_two_metrics(): batch = RunOperation() batch.ParseFromString(result.operation) - update1_merged = UpdateRunSnapshot( - step=Step(whole=1, micro=0), assign={"aa": Value(int64=10), "bb": Value(int64=100)} - ) - update2_merged = UpdateRunSnapshot( - step=Step(whole=2, micro=0), assign={"aa": Value(int64=20), "bb": Value(int64=200)} - ) - assert batch.project == "project" assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update1_merged, update2_merged] + assert batch.update_batch.snapshots == [update1a, update2a, update1b, update2b] diff --git a/tests/unit/test_attribute.py b/tests/unit/test_attribute.py index 481651c7..072aa84a 100644 --- a/tests/unit/test_attribute.py +++ b/tests/unit/test_attribute.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from unittest.mock import Mock @@ -8,13 +9,15 @@ ) from neptune_scale.api.attribute import cleanup_path +from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing from neptune_scale.legacy import Run @fixture def run(api_token): run = Run(project="dummy/project", run_id="dummy-run", mode="disabled", api_token=api_token) - run._attr_store.log = Mock() + # Mock log to be able to assert calls, but also proxy to the actual method so it does its job + run._attr_store.log = Mock(side_effect=run._attr_store.log) with run: yield run @@ -67,11 +70,30 @@ def test_tags(run, store): def test_series(run, store): - run["sys/series"].append(1, step=1, timestamp=10) - store.log.assert_called_with(metrics={"sys/series": 1}, step=1, timestamp=10) + now = time.time() + run["my/series"].append(1, step=1, timestamp=now) + store.log.assert_called_with(metrics={"my/series": 1}, step=1, timestamp=now) - run["sys/series"].append({"foo": 1, "bar": 2}, step=2) - store.log.assert_called_with(metrics={"sys/series/foo": 1, "sys/series/bar": 2}, step=2, timestamp=None) + run["my/series"].append({"foo": 1, "bar": 2}, step=2) + store.log.assert_called_with(metrics={"my/series/foo": 1, "my/series/bar": 2}, step=2, timestamp=None) + + +def test_error_on_non_increasing_step(run): + run["series"].append(1, step=2) + + # Step lower than previous + with pytest.raises(NeptuneSeriesStepNonIncreasing): + run["series"].append(2, step=1) + + # Equal to previous, but different value + with pytest.raises(NeptuneSeriesStepNonIncreasing): + run["series"].append(3, step=2) + + # Equal to previous, same value -> should pass + run["series"].append(1, step=2) + + # None should pass, as it means auto-increment + run["series"].append(4, step=None) @pytest.mark.parametrize( diff --git a/tests/unit/test_process_link.py b/tests/unit/test_process_link.py index f3f68ec3..c030a268 100644 --- a/tests/unit/test_process_link.py +++ b/tests/unit/test_process_link.py @@ -169,7 +169,7 @@ def on_closed(_): link.start(on_link_closed=on_closed) # We should never finish the sleep call, as on_closed raises SystemExit - time.sleep(5) + time.sleep(10) assert False, "on_closed callback was not called" @@ -184,5 +184,5 @@ def test_parent_termination(): p = multiprocessing.Process(target=parent, args=(var, event)) p.start() - assert event.wait(1) + assert event.wait(5) assert var.value == 1 diff --git a/tests/unit/test_sync_process.py b/tests/unit/test_sync_process.py index 1709510a..25b17918 100644 --- a/tests/unit/test_sync_process.py +++ b/tests/unit/test_sync_process.py @@ -35,7 +35,6 @@ def single_operation(update: UpdateRunSnapshot, sequence_id): operation=operation.SerializeToString(), is_batchable=True, metadata_size=update.ByteSize(), - batch_key=None, )