Skip to content

Commit

Permalink
Merge pull request #109 from neptune-ai/ms/check_batch-backoff
Browse files Browse the repository at this point in the history
fix: retry on all NeptuneRetryableError in StatusTrackingThread.check_batch
  • Loading branch information
michalsosn authored Jan 14, 2025
2 parents 259f789 + 71e167c commit 5d9fcbd
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def get_next(self) -> Optional[list[StatusTrackingElement]]:
except queue.Empty:
return None

@backoff.on_exception(backoff.expo, NeptuneConnectionLostError, max_time=MAX_REQUEST_RETRY_SECONDS)
@backoff.on_exception(backoff.expo, NeptuneRetryableError, max_time=MAX_REQUEST_RETRY_SECONDS)
@with_api_errors_handling
def check_batch(self, *, request_ids: list[str]) -> Optional[BulkRequestStatus]:
if self._backend is None:
Expand Down
156 changes: 145 additions & 11 deletions tests/unit/test_sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,48 @@
from unittest.mock import Mock

import pytest
from neptune_api.proto.google_rpc.code_pb2 import Code
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import SubmitResponse
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import (
BulkRequestStatus,
SubmitResponse,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus

from neptune_scale.exceptions import NeptuneSynchronizationStopped
from neptune_scale.sync.queue_element import (
BatchedOperations,
SingleOperation,
)
from neptune_scale.sync.sync_process import SenderThread
from neptune_scale.util.shared_var import SharedInt
from neptune_scale.sync.sync_process import (
SenderThread,
StatusTrackingElement,
StatusTrackingThread,
)
from neptune_scale.util.shared_var import (
SharedFloat,
SharedInt,
)


def response(request_ids: list[str], status_code: int = 200):
def status_response(code_count_batch: list[dict[Code.ValueType, int]], status_code: int = 200):
    """Build a mocked response carrying a ``BulkRequestStatus`` payload.

    Args:
        code_count_batch: One mapping per request status; each mapping pairs a
            status code with the count stored in a ``RequestStatus.CodeByCount``.
        status_code: Value exposed as the mock's ``status_code`` attribute.

    Returns:
        A ``Mock`` exposing ``status_code``, ``content`` (the serialized
        message) and ``parsed`` (the message object itself).
    """
    statuses = []
    for code_counts in code_count_batch:
        entries = [RequestStatus.CodeByCount(code=code, count=count) for code, count in code_counts.items()]
        statuses.append(RequestStatus(code_by_count=entries))

    payload = BulkRequestStatus(statuses=statuses)
    serialized = payload.SerializeToString()
    return Mock(status_code=status_code, content=serialized, parsed=payload)


def submit_response(request_ids: list[str], status_code: int = 200):
    """Build a mocked response carrying a ``SubmitResponse`` payload.

    Args:
        request_ids: Request ids placed in the message's ``request_ids`` field.
            The last one is also used as ``request_id``; ``None`` is passed
            for that field when the list is empty.
        status_code: Value exposed as the mock's ``status_code`` attribute.

    Returns:
        A ``Mock`` exposing ``status_code``, ``content`` (the serialized
        message) and ``parsed`` (the message object itself).
    """
    last_request_id = request_ids[-1] if request_ids else None
    payload = SubmitResponse(request_ids=request_ids, request_id=last_request_id)
    serialized = payload.SerializeToString()
    return Mock(status_code=status_code, content=serialized, parsed=payload)
Expand Down Expand Up @@ -90,7 +115,7 @@ def test_sender_thread_processes_single_element(sender):
]

# and
sender.backend.submit.side_effect = [response(["1"])]
sender.backend.submit.side_effect = [submit_response(["1"])]

# when
sender.sender_thread.work()
Expand All @@ -110,8 +135,8 @@ def test_sender_thread_processes_element_on_single_retryable_error(sender):

# and
sender.backend.submit.side_effect = [
response([], status_code=503),
response(["a"], status_code=200),
submit_response([], status_code=503),
submit_response(["a"], status_code=200),
]

# when
Expand All @@ -132,7 +157,7 @@ def test_sender_thread_fails_on_regular_error(sender):

# and
sender.backend.submit.side_effect = [
response([], status_code=200),
submit_response([], status_code=200),
]

# when
Expand All @@ -154,13 +179,122 @@ def test_sender_thread_processes_element_on_429_and_408_http_statuses(sender):

# and
sender.backend.submit.side_effect = [
response([], status_code=408),
response([], status_code=429),
response(["a"], status_code=200),
submit_response([], status_code=408),
submit_response([], status_code=429),
submit_response(["a"], status_code=200),
]

# when
sender.sender_thread.work()

# then
assert sender.backend.submit.call_count == 3


def test_status_tracking_thread_processes_single_element():
    """One queued element with a single "OK" status is acknowledged after one ``check_batch`` call."""
    # given
    tracking_queue = Mock()
    acked_seq = SharedInt(initial_value=0)
    acked_timestamp = SharedFloat(initial_value=0)
    api = Mock()
    thread = StatusTrackingThread(
        api_token="",
        mode="disabled",
        project="",
        errors_queue=Mock(),
        status_tracking_queue=tracking_queue,
        last_ack_seq=acked_seq,
        last_ack_timestamp=acked_timestamp,
    )
    # Inject the mocked backend directly into the thread's private attribute.
    thread._backend = api

    # and
    tracked = StatusTrackingElement(sequence_id=1, timestamp=time.process_time(), request_id="a")
    tracking_queue.peek.side_effect = [[tracked], queue.Empty]

    # and
    ok_status = status_response(code_count_batch=[{"OK": 1}])
    api.check_batch.side_effect = [ok_status]

    # when
    thread.work()

    # then
    assert api.check_batch.call_count == 1
    assert acked_seq.value == 1
    assert acked_timestamp.value == tracked.timestamp


def test_status_tracking_thread_processes_element_on_single_retryable_error():
    """A 408 response followed by an "OK" status leads to two ``check_batch`` calls and an acknowledgement."""
    # given
    tracking_queue = Mock()
    acked_seq = SharedInt(initial_value=0)
    acked_timestamp = SharedFloat(initial_value=0)
    api = Mock()
    thread = StatusTrackingThread(
        api_token="",
        mode="disabled",
        project="",
        errors_queue=Mock(),
        status_tracking_queue=tracking_queue,
        last_ack_seq=acked_seq,
        last_ack_timestamp=acked_timestamp,
    )
    # Inject the mocked backend directly into the thread's private attribute.
    thread._backend = api

    # and
    tracked = StatusTrackingElement(sequence_id=1, timestamp=time.process_time(), request_id="a")
    tracking_queue.peek.side_effect = [[tracked], queue.Empty]

    # and
    first_attempt = status_response(code_count_batch=[], status_code=408)
    second_attempt = status_response(code_count_batch=[{"OK": 1}])
    api.check_batch.side_effect = [first_attempt, second_attempt]

    # when
    thread.work()

    # then
    assert api.check_batch.call_count == 2
    assert acked_seq.value == 1
    assert acked_timestamp.value == tracked.timestamp


def test_status_tracking_thread_fails_on_regular_error():
    """A 403 response stops the thread: one call, one reported error, nothing acknowledged."""
    # given
    tracking_queue = Mock()
    reported_errors = Mock()
    acked_seq = SharedInt(initial_value=0)
    acked_timestamp = SharedFloat(initial_value=0)
    api = Mock()
    thread = StatusTrackingThread(
        api_token="",
        mode="disabled",
        project="",
        errors_queue=reported_errors,
        status_tracking_queue=tracking_queue,
        last_ack_seq=acked_seq,
        last_ack_timestamp=acked_timestamp,
    )
    # Inject the mocked backend directly into the thread's private attribute.
    thread._backend = api

    # and
    tracked = StatusTrackingElement(sequence_id=1, timestamp=time.process_time(), request_id="a")
    tracking_queue.peek.side_effect = [[tracked], queue.Empty]

    # and
    forbidden = status_response(code_count_batch=[], status_code=403)
    api.check_batch.side_effect = [forbidden]

    # when
    with pytest.raises(NeptuneSynchronizationStopped):
        thread.work()

    # then
    reported_errors.put.assert_called_once()
    assert api.check_batch.call_count == 1
    assert acked_seq.value == 0
    assert acked_timestamp.value == 0

0 comments on commit 5d9fcbd

Please sign in to comment.