diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b421be2a..73ef26dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: args: [ --config-file, pyproject.toml ] pass_filenames: false additional_dependencies: - - neptune-api==0.7.0b + - neptune-api - more-itertools - backoff default_language_version: diff --git a/CHANGELOG.md b/CHANGELOG.md index 556c1cf8..208b7ba4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.8.0] - 2024-11-26 +## 0.9.0 - 2025-01-07 + +### Changes +* Removed support for Python 3.8 (https://github.com/neptune-ai/neptune-client-scale/pull/105) + +### Added +* Added `projects.list_projects()` method to list projects accessible to the current user (https://github.com/neptune-ai/neptune-client-scale/pull/97) + +### Fixed +* Fixed retry behavior on encountering a `NeptuneRetryableError` (https://github.com/neptune-ai/neptune-client-scale/pull/99) +* Fixed batching of metrics when logged with steps out of order (https://github.com/neptune-ai/neptune-client-scale/pull/91) + +### Chores +* Not invoking `on_error_callback` on encountering 408 and 429 HTTP statuses (https://github.com/neptune-ai/neptune-client-scale/pull/110) + +## 0.8.0 - 2024-11-26 ### Added - Added function `neptune_scale.projects.create_project()` to programmatically create Neptune projects ([#92](https://github.com/neptune-ai/neptune-client-scale/pull/92)) @@ -24,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed batching of steps ([#82](https://github.com/neptune-ai/neptune-client-scale/pull/82)) -## [0.7.2] - 2024-11-07 +## 0.7.2 - 2024-11-07 ### Added @@ -32,63 +47,54 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tuple support for tags ([#67](https://github.com/neptune-ai/neptune-client-scale/pull/67)) ### Changed - - Performance improvements - Change the logger's configuration to be more resilient ([#66](https://github.com/neptune-ai/neptune-client-scale/pull/66)) - Update docs: info about timestamp and timezones ([#69](https://github.com/neptune-ai/neptune-client-scale/pull/69)) - Strip quotes from the `NEPTUNE_PROJECT` env variable ([#51](https://github.com/neptune-ai/neptune-client-scale/pull/51)) -## [0.7.1] - 2024-10-28 +## 0.7.1 - 2024-10-28 ### Changed - Removed `family` from run initialization parameters ([#62](https://github.com/neptune-ai/neptune-client-scale/pull/62)) - Made `timestamp` keyword-only in `log_metrics()` ([#58](https://github.com/neptune-ai/neptune-client-scale/pull/58)) -## [0.6.3] - 2024-10-23 +## 0.6.3 - 2024-10-23 ### Changed - - Changed the signature of `Run.log_metrics`: - `date` is now the first parameter in line with other logging methods ([#58](https://github.com/neptune-ai/neptune-client-scale/pull/58)) - `step` and `data` are now mandatory ([#55](https://github.com/neptune-ai/neptune-client-scale/pull/55)) - - Removed iterables from `log_config` value type hints ([#53](https://github.com/neptune-ai/neptune-client-scale/pull/53)) -## [0.6.0] - 2024-09-09 +## 0.6.0 - 2024-09-09 ### Added - - Dedicated exceptions for missing project or API token ([#44](https://github.com/neptune-ai/neptune-client-scale/pull/44)) ### Changed - - Removed `timestamp` parameter from 
`add_tags()`, `remove_tags()` and `log_configs()` methods ([#37](https://github.com/neptune-ai/neptune-client-scale/pull/37)) - Performance improvements of metadata logging ([#42](https://github.com/neptune-ai/neptune-client-scale/pull/42)) ## [0.5.0] - 2024-09-05 ### Added - - Added docstrings to logging methods ([#40](https://github.com/neptune-ai/neptune-client-scale/pull/40)) ## [0.4.0] - 2024-09-03 ### Added - - Added support for integer values when logging metric values ([#33](https://github.com/neptune-ai/neptune-client-scale/pull/33)) - Added support for async lag threshold ([#22](https://github.com/neptune-ai/neptune-client-scale/pull/22)) ## [0.3.0] - 2024-09-03 ### Added - - Package renamed to `neptune-scale` ([#31](https://github.com/neptune-ai/neptune-client-scale/pull/31)) ## [0.2.0] - 2024-09-02 ### Added - - Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6)) - Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7)) - Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8)) diff --git a/pyproject.toml b/pyproject.toml index f2953e94..b2e285cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ pattern = "default-unprefixed" [tool.poetry.dependencies] python = "^3.9" -neptune-api = "^0.9.0" +neptune-api = "^0.10.0" more-itertools = "^10.0.0" psutil = "^5.0.0" backoff = "^2.0.0" @@ -19,7 +19,7 @@ backoff = "^2.0.0" [tool.poetry] name = "neptune-scale" version = "0.1.0" -description = "A minimal client library" +description = "Python logging API for Neptune Scale" authors = ["neptune.ai "] repository = "https://github.com/neptune-ai/neptune-client-scale" readme = "README.md" diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py index ea433137..e3dbbf7d 100644 --- a/src/neptune_scale/api/attribute.py +++ b/src/neptune_scale/api/attribute.py @@ -1,5 +1,6 @@ import functools import itertools +import threading import warnings from collections.abc import ( Collection, @@ -14,6 +15,7 @@ cast, ) +from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing from neptune_scale.sync.metadata_splitter import MetadataSplitter from neptune_scale.sync.operations_queue import OperationsQueue @@ -59,6 +61,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) self._run_id = run_id self._operations_queue = operations_queue self._attributes: dict[str, Attribute] = {} + # Keep a list of path -> (last step, last value) mappings to detect non-increasing steps + # at call site. The backend will detect this error as well, but it's more convenient for the user + # to get the error as soon as possible. 
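+        # For example (illustrative values): after metrics={"loss": 0.5} is logged at step 2.0,
+        # self._metric_state["loss"] == (2.0, 0.5); a later call for "loss" with step 1.0, or with
+        # step 2.0 and a different value, raises NeptuneSeriesStepNonIncreasing before anything is enqueued.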
+ self._metric_state: dict[str, tuple[float, float]] = {} + self._lock = threading.RLock() def __getitem__(self, path: str) -> "Attribute": path = cleanup_path(path) @@ -85,22 +92,45 @@ def log( ) -> None: if timestamp is None: timestamp = datetime.now() - elif isinstance(timestamp, float): + elif isinstance(timestamp, (float, int)): timestamp = datetime.fromtimestamp(timestamp) - splitter: MetadataSplitter = MetadataSplitter( - project=self._project, - run_id=self._run_id, - step=step, - timestamp=timestamp, - configs=configs, - metrics=metrics, - add_tags=tags_add, - remove_tags=tags_remove, + # MetadataSplitter is an iterator, so gather everything into a list instead of iterating over + # it in the critical section, to avoid holding the lock for too long. + # TODO: Move splitting into the worker process. Here we should just send messages as they are. + chunks = list( + MetadataSplitter( + project=self._project, + run_id=self._run_id, + step=step, + timestamp=timestamp, + configs=configs, + metrics=metrics, + add_tags=tags_add, + remove_tags=tags_remove, + ) ) - for operation, metadata_size in splitter: - self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step) + with self._lock: + self._verify_and_update_metrics_state(step, metrics) + + for operation, metadata_size in chunks: + self._operations_queue.enqueue(operation=operation, size=metadata_size) + + def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[dict[str, float]]) -> None: + """Check if step in provided metrics is increasing, raise `NeptuneSeriesStepNonIncreasing` if not.""" + + if step is None or metrics is None: + return + + for metric, value in metrics.items(): + if (state := self._metric_state.get(metric)) is not None: + last_step, last_value = state + # Repeating a step is fine as long as the value does not change + if step < last_step or (step == last_step and value != last_value): + raise NeptuneSeriesStepNonIncreasing() + + self._metric_state[metric] = (step, value) class Attribute: diff --git a/src/neptune_scale/exceptions.py b/src/neptune_scale/exceptions.py index f1e6aa7e..61510384 100644 --- a/src/neptune_scale/exceptions.py +++ b/src/neptune_scale/exceptions.py @@ -36,6 +36,7 @@ "NeptuneAsyncLagThresholdExceeded", "NeptuneProjectNotProvided", "NeptuneApiTokenNotProvided", + "NeptuneTooManyRequestsResponseError", ) from typing import Any @@ -191,6 +192,15 @@ class NeptuneUnexpectedResponseError(NeptuneRetryableError): """ +class NeptuneTooManyRequestsResponseError(NeptuneRetryableError): + message = """ +{h1} +NeptuneTooManyRequestsResponseError: The Neptune server reported receiving too many requests. +{end} +This is a temporary problem. If the problem persists, please contact us at support@neptune.ai. 
+""" + + class NeptuneInternalServerError(NeptuneRetryableError): message = """ {h1} diff --git a/src/neptune_scale/net/api_client.py b/src/neptune_scale/net/api_client.py index a6cf90c2..1bbcccfd 100644 --- a/src/neptune_scale/net/api_client.py +++ b/src/neptune_scale/net/api_client.py @@ -27,9 +27,11 @@ from typing import ( Any, Literal, + cast, ) import httpx +import neptune_retrieval_api.client from httpx import Timeout from neptune_api import ( AuthenticatedClient, @@ -64,6 +66,11 @@ from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus from neptune_api.types import Response +from neptune_retrieval_api.api.default import search_leaderboard_entries_proto +from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO +from neptune_retrieval_api.proto.neptune_pb.api.v1.model.leaderboard_entries_pb2 import ( + ProtoLeaderboardEntriesSearchResultDTO, +) from neptune_scale.exceptions import ( NeptuneConnectionLostError, @@ -129,6 +136,11 @@ def submit(self, operation: RunOperation, family: str) -> Response[SubmitRespons @abc.abstractmethod def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ... + @abc.abstractmethod + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: ... + class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: @@ -141,6 +153,9 @@ def __init__(self, api_token: str) -> None: self.backend = create_auth_api_client( credentials=credentials, config=config, token_refreshing_urls=token_urls, verify_ssl=verify_ssl ) + # This is required only to silence mypy. The two client objects are compatible, because they're + # generated by swagger codegen. 
+ self.retrieval_backend = cast(neptune_retrieval_api.client.AuthenticatedClient, self.backend) logger.debug("Connected to Neptune API") def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]: @@ -153,6 +168,15 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]), ) + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: + resp = search_leaderboard_entries_proto.sync_detailed( + client=self.retrieval_backend, project_identifier=project_id, type=["run"], body=body + ) + result = ProtoLeaderboardEntriesSearchResultDTO.FromString(resp.content) + return result + def close(self) -> None: logger.debug("Closing API client") self.backend.__exit__() @@ -181,6 +205,11 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ ) return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={}) + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: + return ProtoLeaderboardEntriesSearchResultDTO() + def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient: if mode == "disabled": diff --git a/src/neptune_scale/net/projects.py b/src/neptune_scale/net/projects.py index ce698ada..4a97de2d 100644 --- a/src/neptune_scale/net/projects.py +++ b/src/neptune_scale/net/projects.py @@ -1,4 +1,3 @@ -import os import re from enum import Enum from json import JSONDecodeError @@ -11,7 +10,6 @@ import httpx from neptune_scale.exceptions import ( - NeptuneApiTokenNotProvided, NeptuneBadRequestError, NeptuneProjectAlreadyExists, ) @@ -19,7 +17,7 @@ HostedApiClient, with_api_errors_handling, ) -from neptune_scale.util.envs import API_TOKEN_ENV_NAME +from neptune_scale.sync.util import ensure_api_token PROJECTS_PATH_BASE = "/api/backend/v1/projects" @@ -33,14 +31,6 @@ class ProjectVisibility(Enum): ORGANIZATION_NOT_FOUND_RE = re.compile(r"Organization .* not found") -def _get_api_token(api_token: Optional[str]) -> str: - api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) - if api_token is None: - raise NeptuneApiTokenNotProvided() - - return api_token - - @with_api_errors_handling def create_project( workspace: str, @@ -52,9 +42,7 @@ def create_project( fail_if_exists: bool = False, api_token: Optional[str] = None, ) -> None: - api_token = _get_api_token(api_token) - - client = HostedApiClient(api_token=api_token) + client = HostedApiClient(api_token=ensure_api_token(api_token)) visibility = ProjectVisibility(visibility) body = { @@ -92,7 +80,7 @@ def _safe_json(response: httpx.Response) -> Any: def get_project_list(*, api_token: Optional[str] = None) -> list[dict]: - client = HostedApiClient(api_token=_get_api_token(api_token)) + client = HostedApiClient(api_token=ensure_api_token(api_token)) params = { "userRelation": "viewerOrHigher", diff --git a/src/neptune_scale/net/runs.py b/src/neptune_scale/net/runs.py new file mode 100644 index 00000000..bf14fe86 --- /dev/null +++ b/src/neptune_scale/net/runs.py @@ -0,0 +1,29 @@ +from typing import Optional + +from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO + +from neptune_scale.exceptions import NeptuneScaleError +from neptune_scale.net.api_client import HostedApiClient +from neptune_scale.net.util import escape_nql_criterion +from neptune_scale.sync.util import 
ensure_api_token + + +def run_exists(project: str, run_id: str, api_token: Optional[str] = None) -> bool: + """Query the backend for the existence of a Run with the given ID. + + Returns True if the Run exists, False otherwise. + """ + + client = HostedApiClient(api_token=ensure_api_token(api_token)) + body = SearchLeaderboardEntriesParamsDTO.from_dict( + { + "query": {"query": f'`sys/custom_run_id`:string = "{escape_nql_criterion(run_id)}"'}, + } + ) + + try: + result = client.search_entries(project, body) + except Exception as e: + raise NeptuneScaleError(reason=e) + + return bool(result.entries) diff --git a/src/neptune_scale/net/util.py b/src/neptune_scale/net/util.py new file mode 100644 index 00000000..cbc25d2e --- /dev/null +++ b/src/neptune_scale/net/util.py @@ -0,0 +1,6 @@ +def escape_nql_criterion(criterion: str) -> str: + """ + Escape backslash and (double-)quotes in the string, to match what the NQL engine expects. + """ + + return criterion.replace("\\", r"\\").replace('"', r"\"") diff --git a/src/neptune_scale/sync/aggregating_queue.py b/src/neptune_scale/sync/aggregating_queue.py index 8e3fc40e..0bf47226 100644 --- a/src/neptune_scale/sync/aggregating_queue.py +++ b/src/neptune_scale/sync/aggregating_queue.py @@ -71,7 +71,7 @@ def commit(self) -> None: def get(self) -> BatchedOperations: start = time.monotonic() - batch_operations: dict[Optional[float], RunOperation] = {} + batch_operations: list[RunOperation] = [] batch_sequence_id: Optional[int] = None batch_timestamp: Optional[float] = None @@ -95,7 +95,7 @@ def get(self) -> BatchedOperations: if not batch_operations: new_operation = RunOperation() new_operation.ParseFromString(element.operation) - batch_operations[element.batch_key] = new_operation + batch_operations.append(new_operation) batch_bytes += len(element.operation) else: if not element.is_batchable: @@ -110,10 +110,7 @@ def get(self) -> BatchedOperations: new_operation = RunOperation() new_operation.ParseFromString(element.operation) - if element.batch_key not in batch_operations: - batch_operations[element.batch_key] = new_operation - else: - merge_run_operation(batch_operations[element.batch_key], new_operation) + batch_operations.append(new_operation) batch_bytes += element.metadata_size batch_sequence_id = element.sequence_id @@ -157,54 +154,25 @@ def get(self) -> BatchedOperations: ) -def create_run_batch(operations: dict[Optional[float], RunOperation]) -> RunOperation: +def create_run_batch(operations: list[RunOperation]) -> RunOperation: + if not operations: + raise Empty + if len(operations) == 1: - return next(iter(operations.values())) + return operations[0] - batch = None - for _, operation in sorted(operations.items(), key=lambda x: (x[0] is not None, x[0])): - if batch is None: - batch = RunOperation() - batch.project = operation.project - batch.run_id = operation.run_id - batch.create_missing_project = operation.create_missing_project - batch.api_key = operation.api_key + head = operations[0] + batch = RunOperation() + batch.project = head.project + batch.run_id = head.run_id + batch.create_missing_project = head.create_missing_project + batch.api_key = head.api_key + for operation in operations: operation_type = operation.WhichOneof("operation") if operation_type == "update": batch.update_batch.snapshots.append(operation.update) else: raise ValueError("Cannot batch operation of type %s", operation_type) - if batch is None: - raise Empty return batch - - -def merge_run_operation(batch: RunOperation, operation: RunOperation) -> None: - """ - Merge 
the `operation` into `batch`, taking into account the special case of `modify_sets`. - - Protobuf merges existing map keys by simply overwriting values, instead of calling - `MergeFrom` on the existing value, eg: A['foo'] = B['foo']. - - We want this instead: - - batch = {'sys/tags': 'string': { 'values': {'foo': ADD}}} - operation = {'sys/tags': 'string': { 'values': {'bar': ADD}}} - result = {'sys/tags': 'string': { 'values': {'foo': ADD, 'bar': ADD}}} - - If we called `batch.MergeFrom(operation)` we would get an overwritten value: - result = {'sys/tags': 'string': { 'values': {'bar': ADD}}} - - This function ensures that the `modify_sets` are merged correctly, leaving the default - behaviour for all other fields. - """ - - modify_sets = operation.update.modify_sets - operation.update.ClearField("modify_sets") - - batch.MergeFrom(operation) - - for k, v in modify_sets.items(): - batch.update.modify_sets[k].MergeFrom(v) diff --git a/src/neptune_scale/sync/errors_tracking.py b/src/neptune_scale/sync/errors_tracking.py index fbf48b88..c7f51bb0 100644 --- a/src/neptune_scale/sync/errors_tracking.py +++ b/src/neptune_scale/sync/errors_tracking.py @@ -12,8 +12,10 @@ NeptuneAsyncLagThresholdExceeded, NeptuneConnectionLostError, NeptuneOperationsQueueMaxSizeExceeded, + NeptuneRetryableError, NeptuneScaleError, NeptuneScaleWarning, + NeptuneTooManyRequestsResponseError, NeptuneUnexpectedError, ) from neptune_scale.sync.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME @@ -109,6 +111,10 @@ def work(self) -> None: self._on_async_lag_callback() elif isinstance(error, NeptuneScaleWarning): self._on_warning_callback(error, last_raised_at) + elif isinstance(error, NeptuneTooManyRequestsResponseError): + self._on_warning_callback(error, last_raised_at) + elif isinstance(error, NeptuneRetryableError): + self._on_warning_callback(error, last_raised_at) elif isinstance(error, NeptuneScaleError): self._on_error_callback(error, last_raised_at) else: diff --git a/src/neptune_scale/sync/operations_queue.py b/src/neptune_scale/sync/operations_queue.py index 0069b472..d518da11 100644 --- a/src/neptune_scale/sync/operations_queue.py +++ b/src/neptune_scale/sync/operations_queue.py @@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]: with self._lock: return self._last_timestamp - def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None: + def enqueue(self, *, operation: RunOperation, size: Optional[int] = None) -> None: try: is_metadata_update = operation.HasField("update") serialized_operation = operation.SerializeToString() @@ -75,7 +75,6 @@ def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: O operation=serialized_operation, metadata_size=size, is_batchable=is_metadata_update, - batch_key=key, ), block=True, timeout=None, diff --git a/src/neptune_scale/sync/queue_element.py b/src/neptune_scale/sync/queue_element.py index 1c37ff11..521f89e4 100644 --- a/src/neptune_scale/sync/queue_element.py +++ b/src/neptune_scale/sync/queue_element.py @@ -26,5 +26,3 @@ class SingleOperation(NamedTuple): is_batchable: bool # Size of the metadata in the operation (without project, family, run_id etc.) 
metadata_size: Optional[int] - # Update metadata key - batch_key: Optional[float] diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py index b1d4f567..425754a2 100644 --- a/src/neptune_scale/sync/sync_process.py +++ b/src/neptune_scale/sync/sync_process.py @@ -6,6 +6,7 @@ import queue import signal import threading +from collections.abc import Iterable from multiprocessing import ( Process, Queue, @@ -16,6 +17,7 @@ Literal, NamedTuple, Optional, + Protocol, TypeVar, ) @@ -37,7 +39,6 @@ NeptuneAttributeTypeUnsupported, NeptuneFloatValueNanInfUnsupported, NeptuneInternalServerError, - NeptuneOperationsQueueMaxSizeExceeded, NeptuneProjectInvalidName, NeptuneProjectNotFound, NeptuneRetryableError, @@ -53,6 +54,7 @@ NeptuneStringSetExceedsSizeLimit, NeptuneStringValueExceedsSizeLimit, NeptuneSynchronizationStopped, + NeptuneTooManyRequestsResponseError, NeptuneUnauthorizedError, NeptuneUnexpectedError, NeptuneUnexpectedResponseError, @@ -177,7 +179,7 @@ def __init__( ) -> None: super().__init__(name="SyncProcess") - self._external_operations_queue: Queue[SingleOperation] = operations_queue + self._input_operations_queue: Queue[SingleOperation] = operations_queue self._errors_queue: ErrorsQueue = errors_queue self._process_link: ProcessLink = process_link self._api_token: str = api_token @@ -211,7 +213,7 @@ def run(self) -> None: family=self._family, api_token=self._api_token, errors_queue=self._errors_queue, - external_operations_queue=self._external_operations_queue, + input_queue=self._input_operations_queue, last_queued_seq=self._last_queued_seq, last_ack_seq=self._last_ack_seq, max_queue_size=self._max_queue_size, @@ -233,6 +235,10 @@ def run(self) -> None: logger.info("Data synchronization finished") +class SupportsPutNowait(Protocol): + def put_nowait(self, element: SingleOperation) -> None: ... 
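# Standalone sketch, not part of this patch: `SupportsPutNowait` is a structural typing
# Protocol, so OperationDispatcherThread can fan out to any consumer that exposes
# put_nowait(), e.g. an AggregatingQueue or a plain queue.Queue. The OfflineSink consumer
# below is a hypothetical example, not an existing class in neptune-scale.
import queue
from typing import Protocol


class SupportsPutNowait(Protocol):
    def put_nowait(self, element: str) -> None: ...


class OfflineSink:
    """Hypothetical consumer that records every dispatched element."""

    def __init__(self) -> None:
        self.seen: list[str] = []

    def put_nowait(self, element: str) -> None:
        self.seen.append(element)


def dispatch(consumers: tuple[SupportsPutNowait, ...], element: str) -> None:
    # Mirrors the dispatcher loop: queue.Full raised by any consumer propagates
    # and stops further processing.
    for consumer in consumers:
        consumer.put_nowait(element)


internal: "queue.Queue[str]" = queue.Queue(maxsize=10)
sink = OfflineSink()
dispatch((internal, sink), "serialized-run-operation")
assert internal.get_nowait() == "serialized-run-operation"
assert sink.seen == ["serialized-run-operation"]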
+ + class SyncProcessWorker(WithResources): def __init__( self, @@ -242,7 +248,7 @@ def __init__( family: str, mode: Literal["async", "disabled"], errors_queue: ErrorsQueue, - external_operations_queue: multiprocessing.Queue[SingleOperation], + input_queue: multiprocessing.Queue[SingleOperation], last_queued_seq: SharedInt, last_ack_seq: SharedInt, last_ack_timestamp: SharedFloat, @@ -261,11 +267,7 @@ def __init__( last_queued_seq=last_queued_seq, mode=mode, ) - self._external_to_internal_thread = InternalQueueFeederThread( - external=external_operations_queue, - internal=self._internal_operations_queue, - errors_queue=self._errors_queue, - ) + self._status_tracking_thread = StatusTrackingThread( api_token=api_token, mode=mode, @@ -276,13 +278,19 @@ def __init__( last_ack_timestamp=last_ack_timestamp, ) + self._operation_dispatcher_thread = OperationDispatcherThread( + input_queue=input_queue, + consumers=[self._internal_operations_queue], + errors_queue=self._errors_queue, + ) + @property def threads(self) -> tuple[Daemon, ...]: - return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread + return self._operation_dispatcher_thread, self._sync_thread, self._status_tracking_thread @property def resources(self) -> tuple[Resource, ...]: - return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread + return self._operation_dispatcher_thread, self._sync_thread, self._status_tracking_thread def interrupt(self) -> None: for thread in self.threads: @@ -302,17 +310,23 @@ def join(self, timeout: Optional[int] = None) -> None: thread.join(timeout=timeout) -class InternalQueueFeederThread(Daemon, Resource): +class OperationDispatcherThread(Daemon, Resource): + """Reads incoming messages from a multiprocessing.Queue, and dispatches them to a list of consumers, + which can be of type `queue.Queue`, but also any other object that supports put_nowait() method. + + If any of the consumers' put_nowait() raises queue.Full, the thread will stop processing further operations. 
+ """ + def __init__( self, - external: multiprocessing.Queue[SingleOperation], - internal: AggregatingQueue, + input_queue: multiprocessing.Queue[SingleOperation], + consumers: Iterable[SupportsPutNowait], errors_queue: ErrorsQueue, ) -> None: - super().__init__(name="InternalQueueFeederThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME) + super().__init__(name="OperationDispatcherThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME) - self._external: multiprocessing.Queue[SingleOperation] = external - self._internal: AggregatingQueue = internal + self._input_queue: multiprocessing.Queue[SingleOperation] = input_queue + self._consumers = tuple(consumers) self._errors_queue: ErrorsQueue = errors_queue self._latest_unprocessed: Optional[SingleOperation] = None @@ -322,7 +336,7 @@ def get_next(self) -> Optional[SingleOperation]: return self._latest_unprocessed try: - self._latest_unprocessed = self._external.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME) + self._latest_unprocessed = self._input_queue.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME) return self._latest_unprocessed except queue.Empty: return None @@ -333,18 +347,22 @@ def commit(self) -> None: def work(self) -> None: try: while not self._is_interrupted(): - operation = self.get_next() - if operation is None: + if (operation := self.get_next()) is None: continue try: - self._internal.put_nowait(operation) + for consumer in self._consumers: + consumer.put_nowait(operation) self.commit() - except queue.Full: - logger.debug("Internal queue is full (%d elements), waiting for free space", self._internal.maxsize) - self._errors_queue.put(NeptuneOperationsQueueMaxSizeExceeded(max_size=self._internal.maxsize)) - # Sleep before retry - break + except queue.Full as e: + # We have two ways to deal with this situation: + # 1. Consider this a fatal error, and stop processing further operations. + # 2. Retry, assuming that any consumer that _did_ manage to receive the operation, is + # idempotent and can handle the same operation again. + # + # Currently, we choose 1. + logger.error("Operation queue overflow. Neptune will not process further operations.") + raise e except Exception as e: self._errors_queue.put(e) self.interrupt() @@ -402,20 +420,16 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]: response = self._backend.submit(operation=operation, family=self._family) - if response.status_code == 403: - raise NeptuneUnauthorizedError() - - if response.status_code != 200: - logger.error("HTTP response error: %s", response.status_code) - if response.status_code // 100 == 5: - raise NeptuneInternalServerError() - else: - raise NeptuneUnexpectedResponseError() + status_code = response.status_code + if status_code != 200: + _raise_exception(status_code) return response.parsed def work(self) -> None: try: + # TODO: is there a point in serializing the data on AggregatingQueue? 
It does not move between processes, + # so we could just pass around instances of RunOperation while (operation := self.get_next()) is not None: sequence_id, timestamp, data = operation @@ -452,6 +466,20 @@ def work(self) -> None: raise NeptuneSynchronizationStopped() from e +def _raise_exception(status_code: int) -> None: + logger.error("HTTP response error: %s", status_code) + if status_code == 403: + raise NeptuneUnauthorizedError() + elif status_code == 408: + raise NeptuneConnectionLostError() + elif status_code == 429: + raise NeptuneTooManyRequestsResponseError() + elif status_code // 100 == 5: + raise NeptuneInternalServerError() + else: + raise NeptuneUnexpectedResponseError() + + class StatusTrackingThread(Daemon, WithResources): def __init__( self, @@ -495,15 +523,10 @@ def check_batch(self, *, request_ids: list[str]) -> Optional[BulkRequestStatus]: response = self._backend.check_batch(request_ids=request_ids, project=self._project) - if response.status_code == 403: - raise NeptuneUnauthorizedError() + status_code = response.status_code - if response.status_code != 200: - logger.error("HTTP response error: %s", response.status_code) - if response.status_code // 100 == 5: - raise NeptuneInternalServerError() - else: - raise NeptuneUnexpectedResponseError() + if status_code != 200: + _raise_exception(status_code) return response.parsed diff --git a/src/neptune_scale/sync/util.py b/src/neptune_scale/sync/util.py index 60fe4b0b..2d5ecf96 100644 --- a/src/neptune_scale/sync/util.py +++ b/src/neptune_scale/sync/util.py @@ -1,4 +1,9 @@ +import os import signal +from typing import Optional + +from neptune_scale.exceptions import NeptuneApiTokenNotProvided +from neptune_scale.util.envs import API_TOKEN_ENV_NAME def safe_signal_name(signum: int) -> str: @@ -8,3 +13,13 @@ def safe_signal_name(signum: int) -> str: signame = str(signum) return signame + + +def ensure_api_token(api_token: Optional[str]) -> str: + """Ensure the API token is provided via either explicit argument, or env variable.""" + + api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) + if api_token is None: + raise NeptuneApiTokenNotProvided() + + return api_token diff --git a/tests/e2e/test_net.py b/tests/e2e/test_net.py new file mode 100644 index 00000000..263def20 --- /dev/null +++ b/tests/e2e/test_net.py @@ -0,0 +1,10 @@ +import os + +from neptune_scale.net.runs import run_exists + +NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT") + + +def test_run_exists_true(run): + assert run_exists(run._project, run._run_id) + assert not run_exists(run._project, "nonexistent_run_id") diff --git a/tests/unit/test_aggregating_queue.py b/tests/unit/test_aggregating_queue.py index de9394a0..0f736f57 100644 --- a/tests/unit/test_aggregating_queue.py +++ b/tests/unit/test_aggregating_queue.py @@ -6,6 +6,7 @@ import pytest from freezegun import freeze_time +from google.protobuf.timestamp_pb2 import Timestamp from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( Step, @@ -32,7 +33,6 @@ def test__simple(): operation=operation.SerializeToString(), is_batchable=True, metadata_size=update.ByteSize(), - batch_key=None, ) # and @@ -60,7 +60,6 @@ def test__max_size_exceeded(): operation=operation1.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -68,7 +67,6 @@ def test__max_size_exceeded(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=0, - 
batch_key=None, ) # and @@ -108,7 +106,6 @@ def test__batch_size_limit(): operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -116,7 +113,6 @@ def test__batch_size_limit(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=None, ) # and @@ -148,7 +144,6 @@ def test__batching(): operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -156,7 +151,6 @@ def test__batching(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=None, ) # and @@ -179,7 +173,8 @@ def test__batching(): assert batch.project == "project" assert batch.run_id == "run_id" - assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"]) + assert all(k in batch.update_batch.snapshots[0].assign for k in ["aa0", "aa1"]) + assert all(k in batch.update_batch.snapshots[1].assign for k in ["bb0", "bb1"]) @freeze_time("2024-09-01") @@ -189,21 +184,21 @@ def test__queue_element_size_limit_with_different_steps(): update2 = UpdateRunSnapshot(step=Step(whole=2), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) operation1 = RunOperation(update=update1) operation2 = RunOperation(update=update2) + timestamp1 = time.process_time() + timestamp2 = timestamp1 + 1 element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp1, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, - timestamp=time.process_time(), + timestamp=timestamp2, operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=2.0, ) # and @@ -235,7 +230,6 @@ def test__not_merge_two_run_creation(): operation=operation1.SerializeToString(), is_batchable=False, metadata_size=0, - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -243,7 +237,6 @@ def test__not_merge_two_run_creation(): operation=operation2.SerializeToString(), is_batchable=False, metadata_size=0, - batch_key=None, ) # and @@ -301,7 +294,6 @@ def test__not_merge_run_creation_with_metadata_update(): operation=operation1.SerializeToString(), is_batchable=False, metadata_size=0, - batch_key=None, ) element2 = SingleOperation( sequence_id=2, @@ -309,7 +301,6 @@ def test__not_merge_run_creation_with_metadata_update(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update.ByteSize(), - batch_key=None, ) # and @@ -351,7 +342,7 @@ def test__not_merge_run_creation_with_metadata_update(): @freeze_time("2024-09-01") -def test__merge_same_key(): +def test__batch_same_key(): # given update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) update2 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) @@ -361,21 +352,20 @@ def test__merge_same_key(): operation2 = RunOperation(update=update2, project="project", run_id="run_id") # and + timestamp0 = time.process_time() element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp0, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=update1.ByteSize(), - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, - 
timestamp=time.process_time(), + timestamp=timestamp0, operation=operation2.SerializeToString(), is_batchable=True, metadata_size=update2.ByteSize(), - batch_key=1.0, ) # and @@ -390,7 +380,7 @@ def test__merge_same_key(): # then assert result.sequence_id == 2 - assert result.timestamp == element2.timestamp + assert result.timestamp == timestamp0 # and batch = RunOperation() @@ -398,12 +388,14 @@ def test__merge_same_key(): assert batch.project == "project" assert batch.run_id == "run_id" - assert batch.update.step == Step(whole=1, micro=0) - assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"]) + assert batch.update_batch.snapshots[0].step == Step(whole=1, micro=0) + assert batch.update_batch.snapshots[1].step == Step(whole=1, micro=0) + assert all(k in batch.update_batch.snapshots[0].assign for k in ["aa0", "aa1"]) + assert all(k in batch.update_batch.snapshots[1].assign for k in ["bb0", "bb1"]) @freeze_time("2024-09-01") -def test__merge_two_different_steps(): +def test__batch_two_different_steps(): # given update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) update2 = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) @@ -413,21 +405,21 @@ def test__merge_two_different_steps(): operation2 = RunOperation(update=update2, project="project", run_id="run_id") # and + timestamp1 = time.process_time() + timestamp2 = timestamp1 + 1 element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp1, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, - timestamp=time.process_time(), + timestamp=timestamp2, operation=operation2.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=2.0, ) # and @@ -454,7 +446,7 @@ def test__merge_two_different_steps(): @freeze_time("2024-09-01") -def test__merge_step_with_none(): +def test__batch_step_with_none(): # given update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) @@ -464,13 +456,13 @@ def test__merge_step_with_none(): operation2 = RunOperation(update=update2, project="project", run_id="run_id") # and + timestamp1 = time.process_time() element1 = SingleOperation( sequence_id=1, - timestamp=time.process_time(), + timestamp=timestamp1, operation=operation1.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=1.0, ) element2 = SingleOperation( sequence_id=2, @@ -478,7 +470,6 @@ def test__merge_step_with_none(): operation=operation2.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=None, ) # and @@ -501,16 +492,33 @@ def test__merge_step_with_none(): assert batch.project == "project" assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update2, update1] # None is always first + assert batch.update_batch.snapshots == [update1, update2] @freeze_time("2024-09-01") -def test__merge_two_steps_two_metrics(): +def test__batch_two_steps_two_metrics(): # given - update1a = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"aa": Value(int64=10)}) - update2a = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"aa": Value(int64=20)}) - update1b = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"bb": Value(int64=100)}) - update2b = UpdateRunSnapshot(step=Step(whole=2, 
micro=0), assign={"bb": Value(int64=200)}) + timestamp0 = int(time.process_time()) + update1a = UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 1, nanos=0), + assign={"aa": Value(int64=10)}, + ) + update2a = UpdateRunSnapshot( + step=Step(whole=2, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 2, nanos=0), + assign={"aa": Value(int64=20)}, + ) + update1b = UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 3, nanos=0), + assign={"bb": Value(int64=100)}, + ) + update2b = UpdateRunSnapshot( + step=Step(whole=2, micro=0), + timestamp=Timestamp(seconds=timestamp0 + 4, nanos=0), + assign={"bb": Value(int64=200)}, + ) # and operations = [ @@ -522,13 +530,12 @@ def test__merge_two_steps_two_metrics(): elements = [ SingleOperation( sequence_id=sequence_id, - timestamp=time.process_time(), + timestamp=timestamp0 + sequence_id, operation=operation.SerializeToString(), is_batchable=True, metadata_size=0, - batch_key=batch_key, ) - for sequence_id, batch_key, operation in [ + for sequence_id, step, operation in [ (1, 1.0, operations[0]), (2, 2.0, operations[1]), (3, 1.0, operations[2]), @@ -554,13 +561,6 @@ def test__merge_two_steps_two_metrics(): batch = RunOperation() batch.ParseFromString(result.operation) - update1_merged = UpdateRunSnapshot( - step=Step(whole=1, micro=0), assign={"aa": Value(int64=10), "bb": Value(int64=100)} - ) - update2_merged = UpdateRunSnapshot( - step=Step(whole=2, micro=0), assign={"aa": Value(int64=20), "bb": Value(int64=200)} - ) - assert batch.project == "project" assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update1_merged, update2_merged] + assert batch.update_batch.snapshots == [update1a, update2a, update1b, update2b] diff --git a/tests/unit/test_attribute.py b/tests/unit/test_attribute.py index 481651c7..072aa84a 100644 --- a/tests/unit/test_attribute.py +++ b/tests/unit/test_attribute.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from unittest.mock import Mock @@ -8,13 +9,15 @@ ) from neptune_scale.api.attribute import cleanup_path +from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing from neptune_scale.legacy import Run @fixture def run(api_token): run = Run(project="dummy/project", run_id="dummy-run", mode="disabled", api_token=api_token) - run._attr_store.log = Mock() + # Mock log to be able to assert calls, but also proxy to the actual method so it does its job + run._attr_store.log = Mock(side_effect=run._attr_store.log) with run: yield run @@ -67,11 +70,30 @@ def test_tags(run, store): def test_series(run, store): - run["sys/series"].append(1, step=1, timestamp=10) - store.log.assert_called_with(metrics={"sys/series": 1}, step=1, timestamp=10) + now = time.time() + run["my/series"].append(1, step=1, timestamp=now) + store.log.assert_called_with(metrics={"my/series": 1}, step=1, timestamp=now) - run["sys/series"].append({"foo": 1, "bar": 2}, step=2) - store.log.assert_called_with(metrics={"sys/series/foo": 1, "sys/series/bar": 2}, step=2, timestamp=None) + run["my/series"].append({"foo": 1, "bar": 2}, step=2) + store.log.assert_called_with(metrics={"my/series/foo": 1, "my/series/bar": 2}, step=2, timestamp=None) + + +def test_error_on_non_increasing_step(run): + run["series"].append(1, step=2) + + # Step lower than previous + with pytest.raises(NeptuneSeriesStepNonIncreasing): + run["series"].append(2, step=1) + + # Equal to previous, but different value + with pytest.raises(NeptuneSeriesStepNonIncreasing): + 
run["series"].append(3, step=2) + + # Equal to previous, same value -> should pass + run["series"].append(1, step=2) + + # None should pass, as it means auto-increment + run["series"].append(4, step=None) @pytest.mark.parametrize( diff --git a/tests/unit/test_errors_monitor.py b/tests/unit/test_errors_monitor.py index ba2bc872..fe11b78e 100644 --- a/tests/unit/test_errors_monitor.py +++ b/tests/unit/test_errors_monitor.py @@ -11,6 +11,7 @@ NeptuneScaleError, NeptuneScaleWarning, NeptuneSeriesPointDuplicate, + NeptuneTooManyRequestsResponseError, ) from neptune_scale.sync.errors_tracking import ( ErrorsMonitor, @@ -22,13 +23,14 @@ ["error", "callback_name"], [ (NeptuneScaleError("error1"), "on_error_callback"), - (NeptuneRetryableError("error1"), "on_error_callback"), + (NeptuneRetryableError("error1"), "on_warning_callback"), (ValueError("error2"), "on_error_callback"), (NeptuneScaleWarning("error3"), "on_warning_callback"), (NeptuneSeriesPointDuplicate("error4"), "on_warning_callback"), (NeptuneOperationsQueueMaxSizeExceeded("error5"), "on_queue_full_callback"), (NeptuneConnectionLostError("error6"), "on_network_error_callback"), (NeptuneAsyncLagThresholdExceeded("error7"), "on_async_lag_callback"), + (NeptuneTooManyRequestsResponseError(), "on_warning_callback"), ], ) def test_errors_monitor_callbacks_called(error, callback_name): diff --git a/tests/unit/test_process_link.py b/tests/unit/test_process_link.py index f3f68ec3..c030a268 100644 --- a/tests/unit/test_process_link.py +++ b/tests/unit/test_process_link.py @@ -169,7 +169,7 @@ def on_closed(_): link.start(on_link_closed=on_closed) # We should never finish the sleep call, as on_closed raises SystemExit - time.sleep(5) + time.sleep(10) assert False, "on_closed callback was not called" @@ -184,5 +184,5 @@ def test_parent_termination(): p = multiprocessing.Process(target=parent, args=(var, event)) p.start() - assert event.wait(1) + assert event.wait(5) assert var.value == 1 diff --git a/tests/unit/test_sync_process.py b/tests/unit/test_sync_process.py index 623a4bb5..5f954947 100644 --- a/tests/unit/test_sync_process.py +++ b/tests/unit/test_sync_process.py @@ -1,5 +1,7 @@ import queue import time +from dataclasses import dataclass +from typing import Any from unittest.mock import Mock import pytest @@ -58,12 +60,21 @@ def single_operation(update: UpdateRunSnapshot, sequence_id): operation=operation.SerializeToString(), is_batchable=True, metadata_size=update.ByteSize(), - batch_key=None, ) -def test_sender_thread_work_finishes_when_queue_empty(): - # given +@dataclass +class MockedSender: + operations_queue: Any + status_tracking_queue: Any + errors_queue: Any + last_queue_seq: Any + backend: Any + sender_thread: Any + + +@pytest.fixture +def sender() -> MockedSender: operations_queue = Mock() status_tracking_queue = Mock() errors_queue = Mock() @@ -80,128 +91,104 @@ def test_sender_thread_work_finishes_when_queue_empty(): ) sender_thread._backend = backend - # and - operations_queue.get.side_effect = queue.Empty + return MockedSender(operations_queue, status_tracking_queue, errors_queue, last_queue_seq, backend, sender_thread) + + +def test_sender_thread_work_finishes_when_queue_empty(sender): + # given + sender.operations_queue.get.side_effect = queue.Empty # when - sender_thread.work() + sender.sender_thread.work() # then assert True -def test_sender_thread_processes_single_element(): +def test_sender_thread_processes_single_element(sender): # given - operations_queue = Mock() - status_tracking_queue = Mock() - errors_queue = 
Mock() - last_queue_seq = SharedInt(initial_value=0) - backend = Mock() - sender_thread = SenderThread( - api_token="", - family="", - operations_queue=operations_queue, - status_tracking_queue=status_tracking_queue, - errors_queue=errors_queue, - last_queued_seq=last_queue_seq, - mode="disabled", - ) - sender_thread._backend = backend - - # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ + sender.operations_queue.get.side_effect = [ BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), queue.Empty, ] # and - backend.submit.side_effect = [submit_response(["1"])] + sender.backend.submit.side_effect = [submit_response(["1"])] # when - sender_thread.work() + sender.sender_thread.work() # then - assert backend.submit.call_count == 1 + assert sender.backend.submit.call_count == 1 -def test_sender_thread_processes_element_on_single_retryable_error(): +def test_sender_thread_processes_element_on_single_retryable_error(sender): # given - operations_queue = Mock() - status_tracking_queue = Mock() - errors_queue = Mock() - last_queue_seq = SharedInt(initial_value=0) - backend = Mock() - sender_thread = SenderThread( - api_token="", - family="", - operations_queue=operations_queue, - status_tracking_queue=status_tracking_queue, - errors_queue=errors_queue, - last_queued_seq=last_queue_seq, - mode="disabled", - ) - sender_thread._backend = backend - - # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ + sender.operations_queue.get.side_effect = [ BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), queue.Empty, ] # and - backend.submit.side_effect = [ + sender.backend.submit.side_effect = [ submit_response([], status_code=503), submit_response(["a"], status_code=200), ] # when - sender_thread.work() + sender.sender_thread.work() # then - assert backend.submit.call_count == 2 + assert sender.backend.submit.call_count == 2 -def test_sender_thread_fails_on_regular_error(): +def test_sender_thread_fails_on_regular_error(sender): # given - operations_queue = Mock() - status_tracking_queue = Mock() - errors_queue = Mock() - last_queue_seq = SharedInt(initial_value=0) - backend = Mock() - sender_thread = SenderThread( - api_token="", - family="", - operations_queue=operations_queue, - status_tracking_queue=status_tracking_queue, - errors_queue=errors_queue, - last_queued_seq=last_queue_seq, - mode="disabled", - ) - sender_thread._backend = backend - - # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ + sender.operations_queue.get.side_effect = [ BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), queue.Empty, ] # and - backend.submit.side_effect = [ + sender.backend.submit.side_effect = [ submit_response([], status_code=200), ] # when with pytest.raises(NeptuneSynchronizationStopped): - sender_thread.work() + sender.sender_thread.work() # then should throw NeptuneInternalServerError - errors_queue.put.assert_called_once() + sender.errors_queue.put.assert_called_once() + + +def test_sender_thread_processes_element_on_429_and_408_http_statuses(sender): + # given + update = UpdateRunSnapshot(assign={"key": Value(string="a")}) + 
element = single_operation(update, sequence_id=2) + sender.operations_queue.get.side_effect = [ + BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), + queue.Empty, + ] + + # and + sender.backend.submit.side_effect = [ + response([], status_code=408), + response([], status_code=429), + response(["a"], status_code=200), + ] + + # when + sender.sender_thread.work() + + # then + assert sender.backend.submit.call_count == 3 def test_status_tracking_thread_processes_single_element(): @@ -311,3 +298,5 @@ def test_status_tracking_thread_fails_on_regular_error(): assert backend.check_batch.call_count == 1 assert last_ack_seq.value == 0 assert last_ack_timestamp.value == 0 + +
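Taken together, the user-facing pieces of this changeset are the client-side step validation and the new `run_exists()` helper. The sketch below shows roughly how they surface, based on the unit and e2e tests above; the project and run identifiers are placeholders, and a configured `NEPTUNE_API_TOKEN` (plus a synced run, for the existence check) is assumed.

from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing
from neptune_scale.legacy import Run
from neptune_scale.net.runs import run_exists

with Run(project="workspace/project", run_id="example-run") as run:
    run["metrics/loss"].append(0.5, step=1)
    run["metrics/loss"].append(0.4, step=2)

    try:
        # Steps are now validated on the client: going back to step 1 raises immediately,
        # instead of the error only being reported later by the backend.
        run["metrics/loss"].append(0.6, step=1)
    except NeptuneSeriesStepNonIncreasing:
        pass

# New helper that checks for a run with the given `sys/custom_run_id` via the search API.
print(run_exists("workspace/project", "example-run"))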