Skip to content

Commit

Permalink
Merge pull request #108 from neptune-ai/kg/operation-dispatcher-thread
Browse files Browse the repository at this point in the history
Make `OperationDispatcherThread` accept multiple consumers
  • Loading branch information
PatrykGala authored Jan 14, 2025
2 parents 4c2e36e + 98e24de commit 12faace
Showing 1 changed file with 43 additions and 26 deletions.
69 changes: 43 additions & 26 deletions src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import queue
import signal
import threading
from collections.abc import Iterable
from multiprocessing import (
Process,
Queue,
Expand All @@ -16,6 +17,7 @@
Literal,
NamedTuple,
Optional,
Protocol,
TypeVar,
)

Expand All @@ -38,7 +40,6 @@
NeptuneConnectionLostError,
NeptuneFloatValueNanInfUnsupported,
NeptuneInternalServerError,
NeptuneOperationsQueueMaxSizeExceeded,
NeptuneProjectInvalidName,
NeptuneProjectNotFound,
NeptuneRetryableError,
Expand Down Expand Up @@ -179,7 +180,7 @@ def __init__(
) -> None:
super().__init__(name="SyncProcess")

self._external_operations_queue: Queue[SingleOperation] = operations_queue
self._input_operations_queue: Queue[SingleOperation] = operations_queue
self._errors_queue: ErrorsQueue = errors_queue
self._process_link: ProcessLink = process_link
self._api_token: str = api_token
Expand Down Expand Up @@ -213,7 +214,7 @@ def run(self) -> None:
family=self._family,
api_token=self._api_token,
errors_queue=self._errors_queue,
external_operations_queue=self._external_operations_queue,
input_queue=self._input_operations_queue,
last_queued_seq=self._last_queued_seq,
last_ack_seq=self._last_ack_seq,
max_queue_size=self._max_queue_size,
Expand All @@ -235,6 +236,10 @@ def run(self) -> None:
logger.info("Data synchronization finished")


class SupportsPutNowait(Protocol):
def put_nowait(self, element: SingleOperation) -> None: ...


class SyncProcessWorker(WithResources):
def __init__(
self,
Expand All @@ -244,7 +249,7 @@ def __init__(
family: str,
mode: Literal["async", "disabled"],
errors_queue: ErrorsQueue,
external_operations_queue: multiprocessing.Queue[SingleOperation],
input_queue: multiprocessing.Queue[SingleOperation],
last_queued_seq: SharedInt,
last_ack_seq: SharedInt,
last_ack_timestamp: SharedFloat,
Expand All @@ -263,11 +268,7 @@ def __init__(
last_queued_seq=last_queued_seq,
mode=mode,
)
self._external_to_internal_thread = InternalQueueFeederThread(
external=external_operations_queue,
internal=self._internal_operations_queue,
errors_queue=self._errors_queue,
)

self._status_tracking_thread = StatusTrackingThread(
api_token=api_token,
mode=mode,
Expand All @@ -278,13 +279,19 @@ def __init__(
last_ack_timestamp=last_ack_timestamp,
)

self._operation_dispatcher_thread = OperationDispatcherThread(
input_queue=input_queue,
consumers=[self._internal_operations_queue],
errors_queue=self._errors_queue,
)

@property
def threads(self) -> tuple[Daemon, ...]:
return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread
return self._operation_dispatcher_thread, self._sync_thread, self._status_tracking_thread

@property
def resources(self) -> tuple[Resource, ...]:
return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread
return self._operation_dispatcher_thread, self._sync_thread, self._status_tracking_thread

def interrupt(self) -> None:
for thread in self.threads:
Expand All @@ -304,17 +311,23 @@ def join(self, timeout: Optional[int] = None) -> None:
thread.join(timeout=timeout)


class InternalQueueFeederThread(Daemon, Resource):
class OperationDispatcherThread(Daemon, Resource):
"""Reads incoming messages from a multiprocessing.Queue, and dispatches them to a list of consumers,
which can be of type `queue.Queue`, but also any other object that supports put_nowait() method.
If any of the consumers' put_nowait() raises queue.Full, the thread will stop processing further operations.
"""

def __init__(
self,
external: multiprocessing.Queue[SingleOperation],
internal: AggregatingQueue,
input_queue: multiprocessing.Queue[SingleOperation],
consumers: Iterable[SupportsPutNowait],
errors_queue: ErrorsQueue,
) -> None:
super().__init__(name="InternalQueueFeederThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
super().__init__(name="OperationDispatcherThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)

self._external: multiprocessing.Queue[SingleOperation] = external
self._internal: AggregatingQueue = internal
self._input_queue: multiprocessing.Queue[SingleOperation] = input_queue
self._consumers = tuple(consumers)
self._errors_queue: ErrorsQueue = errors_queue

self._latest_unprocessed: Optional[SingleOperation] = None
Expand All @@ -324,7 +337,7 @@ def get_next(self) -> Optional[SingleOperation]:
return self._latest_unprocessed

try:
self._latest_unprocessed = self._external.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
self._latest_unprocessed = self._input_queue.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
return self._latest_unprocessed
except queue.Empty:
return None
Expand All @@ -335,18 +348,22 @@ def commit(self) -> None:
def work(self) -> None:
try:
while not self._is_interrupted():
operation = self.get_next()
if operation is None:
if (operation := self.get_next()) is None:
continue

try:
self._internal.put_nowait(operation)
for consumer in self._consumers:
consumer.put_nowait(operation)
self.commit()
except queue.Full:
logger.debug("Internal queue is full (%d elements), waiting for free space", self._internal.maxsize)
self._errors_queue.put(NeptuneOperationsQueueMaxSizeExceeded(max_size=self._internal.maxsize))
# Sleep before retry
break
except queue.Full as e:
# We have two ways to deal with this situation:
# 1. Consider this a fatal error, and stop processing further operations.
# 2. Retry, assuming that any consumer that _did_ manage to receive the operation, is
# idempotent and can handle the same operation again.
#
# Currently, we choose 1.
logger.error("Operation queue overflow. Neptune will not process further operations.")
raise e
except Exception as e:
self._errors_queue.put(e)
self.interrupt()
Expand Down

0 comments on commit 12faace

Please sign in to comment.