Skip to content

Commit

Permalink
Merge pull request #99 from neptune-ai/ms/error-tracking-retryable-error
Browse files Browse the repository at this point in the history
fix: Retry on NeptuneRetryableError. Call on_async_lag_callback
  • Loading branch information
michalsosn authored Dec 3, 2024
2 parents 6e4b449 + f046d6c commit 27c8f84
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/neptune_scale/sync/errors_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def work(self) -> None:
self._on_queue_full_callback(error, last_raised_at)
elif isinstance(error, NeptuneConnectionLostError):
self._on_network_error_callback(error, last_raised_at)
elif isinstance(error, NeptuneAsyncLagThresholdExceeded):
self._on_async_lag_callback()
elif isinstance(error, NeptuneScaleWarning):
self._on_warning_callback(error, last_raised_at)
elif isinstance(error, NeptuneScaleError):
self._on_error_callback(error, last_raised_at)
elif isinstance(error, NeptuneAsyncLagThresholdExceeded):
self._on_async_lag_callback()
else:
self._on_error_callback(NeptuneUnexpectedError(reason=str(error)), last_raised_at)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def resources(self) -> tuple[Resource, ...]:
return (self._backend,)
return ()

@backoff.on_exception(backoff.expo, NeptuneConnectionLostError, max_time=MAX_REQUEST_RETRY_SECONDS)
@backoff.on_exception(backoff.expo, NeptuneRetryableError, max_time=MAX_REQUEST_RETRY_SECONDS)
@with_api_errors_handling
def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:
if self._backend is None:
Expand Down
35 changes: 29 additions & 6 deletions tests/unit/test_errors_monitor.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,55 @@
from threading import Event
from typing import Optional
from unittest.mock import Mock

import pytest

from neptune_scale.exceptions import (
NeptuneAsyncLagThresholdExceeded,
NeptuneConnectionLostError,
NeptuneOperationsQueueMaxSizeExceeded,
NeptuneRetryableError,
NeptuneScaleError,
NeptuneScaleWarning,
NeptuneSeriesPointDuplicate,
)
from neptune_scale.sync.errors_tracking import (
ErrorsMonitor,
ErrorsQueue,
)


def test_errors_monitor():
@pytest.mark.parametrize(
["error", "callback_name"],
[
(NeptuneScaleError("error1"), "on_error_callback"),
(NeptuneRetryableError("error1"), "on_error_callback"),
(ValueError("error2"), "on_error_callback"),
(NeptuneScaleWarning("error3"), "on_warning_callback"),
(NeptuneSeriesPointDuplicate("error4"), "on_warning_callback"),
(NeptuneOperationsQueueMaxSizeExceeded("error5"), "on_queue_full_callback"),
(NeptuneConnectionLostError("error6"), "on_network_error_callback"),
(NeptuneAsyncLagThresholdExceeded("error7"), "on_async_lag_callback"),
],
)
def test_errors_monitor_callbacks_called(error, callback_name):
# given
callback = Mock()

# Synchronization event
callback_called = Event()

# Modify the callback to set the event when called
def callback_with_event(exception: BaseException, last_called: Optional[float]) -> None:
callback(exception, last_called)
def callback_with_event(*args, **kwargs) -> None:
callback()
callback_called.set()

# and
errors_queue = ErrorsQueue()
errors_monitor = ErrorsMonitor(errors_queue=errors_queue, on_error_callback=callback_with_event)
errors_monitor = ErrorsMonitor(**{"errors_queue": errors_queue, callback_name: callback_with_event})
errors_monitor.start()

# when
errors_queue.put(ValueError("error1"))
errors_queue.put(error)
errors_queue.flush()
errors_monitor.wake_up()

Expand Down
180 changes: 180 additions & 0 deletions tests/unit/test_sync_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import queue
import time
from typing import List
from unittest.mock import Mock

import pytest
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import SubmitResponse
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.exceptions import NeptuneSynchronizationStopped
from neptune_scale.sync.queue_element import (
BatchedOperations,
SingleOperation,
)
from neptune_scale.sync.sync_process import SenderThread
from neptune_scale.util.shared_var import SharedInt


def response(request_ids: List[str], status_code: int = 200):
    """Build a mock backend response carrying a ``SubmitResponse`` payload.

    Args:
        request_ids: Request ids to place in the ``SubmitResponse``. The last
            one (or ``None`` when the list is empty) is also used as the
            singular ``request_id`` field.
        status_code: Status code exposed on the mock response.

    Returns:
        A ``Mock`` with ``status_code``, ``content`` (the serialized message)
        and ``parsed`` (the message object itself) attributes.
    """
    last_request_id = request_ids[-1] if request_ids else None
    submit_response = SubmitResponse(request_ids=request_ids, request_id=last_request_id)
    return Mock(
        status_code=status_code,
        content=submit_response.SerializeToString(),
        parsed=submit_response,
    )


def single_operation(update: UpdateRunSnapshot, sequence_id):
    """Wrap ``update`` in a serialized ``RunOperation`` and return it as a ``SingleOperation``.

    Args:
        update: Snapshot update to embed in the ``RunOperation``.
        sequence_id: Sequence id assigned to the resulting queue element.

    Returns:
        A batchable ``SingleOperation`` with no batch key, whose
        ``metadata_size`` is the byte size of ``update``.
    """
    serialized_operation = RunOperation(update=update).SerializeToString()
    return SingleOperation(
        sequence_id=sequence_id,
        # NOTE: process_time() is CPU time of the current process, not wall-clock time.
        timestamp=time.process_time(),
        operation=serialized_operation,
        is_batchable=True,
        metadata_size=update.ByteSize(),
        batch_key=None,
    )


def test_sender_thread_work_finishes_when_queue_empty():
    """``SenderThread.work`` returns without submitting when the queue is empty.

    The operations queue is mocked to raise ``queue.Empty`` on every ``get``.
    ``work()`` must return normally (any exception would fail the test) and
    must not have sent anything to the backend.
    """
    # given: a sender thread wired entirely to mocks
    operations_queue = Mock()
    status_tracking_queue = Mock()
    errors_queue = Mock()
    last_queue_seq = SharedInt(initial_value=0)
    backend = Mock()
    sender_thread = SenderThread(
        api_token="",
        family="",
        operations_queue=operations_queue,
        status_tracking_queue=status_tracking_queue,
        errors_queue=errors_queue,
        last_queued_seq=last_queue_seq,
        mode="disabled",
    )
    sender_thread._backend = backend

    # and: the queue never yields an element
    operations_queue.get.side_effect = queue.Empty

    # when
    sender_thread.work()

    # then: work() returned, and nothing was submitted to the backend.
    # This replaces the former no-op `assert True`, which verified nothing.
    assert backend.submit.call_count == 0


def test_sender_thread_processes_single_element():
    """One queued batch followed by an empty queue yields exactly one backend submit."""
    # given: a sender thread wired entirely to mocks
    ops_queue, status_queue, err_queue = Mock(), Mock(), Mock()
    fake_backend = Mock()
    sender = SenderThread(
        api_token="",
        family="",
        operations_queue=ops_queue,
        status_tracking_queue=status_queue,
        errors_queue=err_queue,
        last_queued_seq=SharedInt(initial_value=0),
        mode="disabled",
    )
    sender._backend = fake_backend

    # and: a single batch is available, after which the queue reports empty
    queued = single_operation(
        UpdateRunSnapshot(assign={"key": Value(string="a")}),
        sequence_id=2,
    )
    batch = BatchedOperations(
        sequence_id=queued.sequence_id,
        timestamp=queued.timestamp,
        operation=queued.operation,
    )
    ops_queue.get.side_effect = [batch, queue.Empty]

    # and: the backend answers the only expected submit with a 200 response
    fake_backend.submit.side_effect = [response(["1"])]

    # when
    sender.work()

    # then
    assert fake_backend.submit.call_count == 1


def test_sender_thread_processes_element_on_single_retryable_error():
    """A batch whose first submit gets a 503 is submitted a second time.

    The backend mock returns a 503 response first and a 200 response second;
    the test expects exactly two ``submit`` calls for the single queued batch.
    """
    # given: a sender thread wired entirely to mocks
    ops_queue, status_queue, err_queue = Mock(), Mock(), Mock()
    fake_backend = Mock()
    sender = SenderThread(
        api_token="",
        family="",
        operations_queue=ops_queue,
        status_tracking_queue=status_queue,
        errors_queue=err_queue,
        last_queued_seq=SharedInt(initial_value=0),
        mode="disabled",
    )
    sender._backend = fake_backend

    # and: a single batch is available, after which the queue reports empty
    queued = single_operation(
        UpdateRunSnapshot(assign={"key": Value(string="a")}),
        sequence_id=2,
    )
    batch = BatchedOperations(
        sequence_id=queued.sequence_id,
        timestamp=queued.timestamp,
        operation=queued.operation,
    )
    ops_queue.get.side_effect = [batch, queue.Empty]

    # and: the first attempt fails with 503, the second one succeeds
    # (presumably 503 is mapped to a retryable error, per the test name; verify in sync_process)
    failed_attempt = response([], status_code=503)
    successful_attempt = response(["a"], status_code=200)
    fake_backend.submit.side_effect = [failed_attempt, successful_attempt]

    # when
    sender.work()

    # then: one failed attempt plus one retry
    assert fake_backend.submit.call_count == 2


def test_sender_thread_fails_on_regular_error():
    """A 200 response with no request ids stops the sender and reports one error.

    ``work()`` is expected to raise ``NeptuneSynchronizationStopped`` and the
    errors queue is expected to receive exactly one ``put`` call.
    """
    # given: a sender thread wired entirely to mocks
    ops_queue, status_queue, err_queue = Mock(), Mock(), Mock()
    fake_backend = Mock()
    sender = SenderThread(
        api_token="",
        family="",
        operations_queue=ops_queue,
        status_tracking_queue=status_queue,
        errors_queue=err_queue,
        last_queued_seq=SharedInt(initial_value=0),
        mode="disabled",
    )
    sender._backend = fake_backend

    # and: a single batch is available, after which the queue reports empty
    queued = single_operation(
        UpdateRunSnapshot(assign={"key": Value(string="a")}),
        sequence_id=2,
    )
    batch = BatchedOperations(
        sequence_id=queued.sequence_id,
        timestamp=queued.timestamp,
        operation=queued.operation,
    )
    ops_queue.get.side_effect = [batch, queue.Empty]

    # and: the backend replies 200 but without any request ids
    fake_backend.submit.side_effect = [response([], status_code=200)]

    # when: the sender gives up instead of retrying
    with pytest.raises(NeptuneSynchronizationStopped):
        sender.work()

    # then: exactly one error was reported to the errors queue
    # (the original author's note says it should be NeptuneInternalServerError; not asserted here)
    err_queue.put.assert_called_once()

0 comments on commit 27c8f84

Please sign in to comment.