Sending operations asynchronously (#17)
Raalsky authored Aug 21, 2024
1 parent 100a9b4 commit 5a12490
Showing 23 changed files with 1,051 additions and 243 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -29,5 +29,6 @@ repos:
additional_dependencies:
- neptune-api==0.4.0
- more-itertools
- backoff
default_language_version:
python: python3
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -13,6 +13,8 @@ python = "^3.8"

neptune-api = "0.4.0"
more-itertools = "^10.0.0"
psutil = "^5.0.0"
backoff = "^2.0.0"

[tool.poetry]
name = "neptune-client-scale"
@@ -74,6 +76,8 @@ force_grid_wrap = 2

[tool.ruff]
line-length = 120
target-version = "py38"
ignore = ["UP006", "UP007"]

[tool.ruff.lint]
select = ["F", "UP"]
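A note on the changes above: target-version = "py38" together with ignoring UP006 and UP007 (the pyupgrade rules that would rewrite annotations to built-in generics and X | Y unions) is consistent with the switch from "str | None" to "Optional[str]" annotations in the __init__.py diff below, since PEP 604 union syntax is unavailable at runtime on Python 3.8. The new backoff dependency is presumably there to retry transient network failures in the new sync process; its call sites are not among the files shown, so the following is only a minimal sketch of typical backoff usage, with a hypothetical submit_operation function and endpoint:

import urllib.error
import urllib.request

import backoff

# Hypothetical endpoint; the commit's real retry logic lives in files not shown here.
INGEST_URL = "https://example.com/ingest"

@backoff.on_exception(backoff.expo, urllib.error.URLError, max_tries=5)
def submit_operation(payload: bytes) -> None:
    # Retried with exponential backoff on transient network failures.
    urllib.request.urlopen(INGEST_URL, data=payload, timeout=10)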
175 changes: 135 additions & 40 deletions src/neptune_scale/__init__.py
@@ -6,24 +6,40 @@

__all__ = ["Run"]

import atexit
import multiprocessing
import os
import threading
import time
from contextlib import AbstractContextManager
from datetime import datetime
from typing import Callable
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Condition as ConditionT
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Set,
Union,
)

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.api_client import ApiClient
from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.errors_monitor import ErrorsMonitor
from neptune_scale.core.components.errors_queue import ErrorsQueue
from neptune_scale.core.components.errors_tracking import (
ErrorsMonitor,
ErrorsQueue,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.components.sync_process import SyncProcess
from neptune_scale.core.logger import logger
from neptune_scale.core.metadata_splitter import MetadataSplitter
from neptune_scale.core.serialization import (
datetime_to_proto,
@@ -44,6 +60,8 @@
MAX_FAMILY_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)


@@ -57,15 +75,17 @@ def __init__(
*,
family: str,
run_id: str,
project: str | None = None,
api_token: str | None = None,
project: Optional[str] = None,
api_token: Optional[str] = None,
resume: bool = False,
as_experiment: str | None = None,
creation_time: datetime | None = None,
from_run_id: str | None = None,
from_step: int | float | None = None,
mode: Literal["async", "disabled"] = "async",
as_experiment: Optional[str] = None,
creation_time: Optional[datetime] = None,
from_run_id: Optional[str] = None,
from_step: Optional[Union[int, float]] = None,
max_queue_size: int = MAX_QUEUE_SIZE,
max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None,
max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None,
on_network_error_callback: Optional[Callable[[BaseException], None]] = None,
) -> None:
"""
Initializes a run that logs the model-building metadata to Neptune.
@@ -79,15 +99,15 @@ def __init__(
api_token: Your Neptune API token. If not provided, the value of the `NEPTUNE_API_TOKEN` environment
variable is used.
resume: Whether to resume an existing run.
mode: Mode of operation. If set to "disabled", the run doesn't log any metadata.
as_experiment: If creating a run as an experiment, ID of an experiment to be associated with the run.
creation_time: Custom creation time of the run.
from_run_id: If forking from an existing run, ID of the run to fork from.
from_step: If forking from an existing run, step number to fork from.
max_queue_size: Maximum number of operations in a queue.
max_queue_size_exceeded_callback: Callback function triggered when a queue is full.
Accepts two arguments:
- Maximum size of the queue.
- Exception that made the queue full.
max_queue_size_exceeded_callback: Callback function triggered when the queue is full. The function should take the exception
that made the queue full as its argument.
on_network_error_callback: Callback function triggered when a network error occurs.
"""
verify_type("family", family, str)
verify_type("run_id", run_id, str)
@@ -143,13 +163,33 @@ def __init__(

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
lock=self._lock,
max_size=max_queue_size,
)
self._errors_queue: ErrorsQueue = ErrorsQueue()
self._errors_monitor = ErrorsMonitor(errors_queue=self._errors_queue)
self._backend: ApiClient = ApiClient(api_token=input_api_token)
self._errors_monitor = ErrorsMonitor(
errors_queue=self._errors_queue,
max_queue_size_exceeded_callback=max_queue_size_exceeded_callback,
on_network_error_callback=on_network_error_callback,
)
self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_put_seq_wait: ConditionT = multiprocessing.Condition()
self._sync_process = SyncProcess(
family=self._family,
operations_queue=self._operations_queue.queue,
errors_queue=self._errors_queue,
api_token=input_api_token,
last_put_seq=self._last_put_seq,
last_put_seq_wait=self._last_put_seq_wait,
max_queue_size=max_queue_size,
mode=mode,
)

self._errors_monitor.start()
with self._lock:
self._sync_process.start()

self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close)

if not resume:
self._create_run(
@@ -159,32 +199,44 @@
from_step=from_step,
)

def __enter__(self) -> Run:
return self

@property
def resources(self) -> tuple[Resource, ...]:
return (
self._errors_queue,
self._operations_queue,
self._backend,
self._errors_monitor,
self._errors_queue,
)

def _close(self) -> None:
# TODO: Change to wait for all operations to be processed
with self._lock:
if self._sync_process.is_alive():
self.wait_for_submission()
self._sync_process.terminate()
self._sync_process.join()

self._errors_monitor.interrupt()
self._errors_monitor.join()

super().close()

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()
if self._exit_func is not None:
atexit.unregister(self._exit_func)
self._exit_func = None
self._close()

def _create_run(
self,
creation_time: datetime,
as_experiment: str | None,
from_run_id: str | None,
from_step: int | float | None,
as_experiment: Optional[str],
from_run_id: Optional[str],
from_step: Optional[Union[int, float]],
) -> None:
fork_point: ForkPoint | None = None
fork_point: Optional[ForkPoint] = None
if from_run_id is not None and from_step is not None:
fork_point = ForkPoint(
parent_project=self._project, parent_run_id=from_run_id, step=make_step(number=from_step)
@@ -200,18 +252,16 @@ def _create_run(
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
)
self._backend.submit(operation=operation, family=self._family)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
self._operations_queue.enqueue(operation=operation)

def log(
self,
step: float | int | None = None,
timestamp: datetime | None = None,
fields: dict[str, float | bool | int | str | datetime | list | set] | None = None,
metrics: dict[str, float] | None = None,
add_tags: dict[str, list[str] | set[str]] | None = None,
remove_tags: dict[str, list[str] | set[str]] | None = None,
step: Optional[Union[float, int]] = None,
timestamp: Optional[datetime] = None,
fields: Optional[Dict[str, Union[float, bool, int, str, datetime, list, set]]] = None,
metrics: Optional[Dict[str, float]] = None,
add_tags: Optional[Dict[str, Union[List[str], Set[str]]]] = None,
remove_tags: Optional[Dict[str, Union[List[str], Set[str]]]] = None,
) -> None:
"""
Logs the specified metadata to Neptune.
@@ -268,6 +318,51 @@ def log(
)

for operation in splitter:
self._backend.submit(operation=operation, family=self._family)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
self._operations_queue.enqueue(operation=operation)

def wait_for_submission(self, timeout: Optional[float] = None) -> None:
"""
Waits until all metadata is submitted to Neptune.
"""
begin_time = time.time()
logger.info("Waiting for all operations to be processed")
if timeout is None:
logger.warning("No timeout specified. Waiting indefinitely")

with self._lock:
if not self._sync_process.is_alive():
logger.warning("Sync process is not running")
return # No need to wait if the sync process is not running

sleep_time_wait = (
min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME
)
last_queued_sequence_id = self._operations_queue.last_sequence_id
last_message_printed: Optional[float] = None
while True:
with self._last_put_seq_wait:
self._last_put_seq_wait.wait(timeout=sleep_time_wait)
value = self._last_put_seq.value
if value == -1:
if self._operations_queue.last_sequence_id != -1:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting. No operations processed yet. Operations to sync: %s",
self._operations_queue.last_sequence_id + 1,
)
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info("Waiting. No operations processed yet")
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting for remaining %d operation(s) to be synced",
last_queued_sequence_id - value + 1,
)
if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
break

logger.info("All operations processed")
(The remaining 20 changed files are not shown here.)
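The wait_for_submission loop above coordinates with the sync process through a shared multiprocessing.Value sequence counter (last_put_seq) and a multiprocessing.Condition: the sync process bumps the counter and notifies after each operation it sends, while the waiter sleeps on the condition and compares the counter against the queue's last sequence id. A standalone sketch of that pattern, with simplified names and timings rather than the commit's actual sync loop:

import multiprocessing
import time

def worker(last_put_seq, cond, n_ops):
    # Consumer side: "send" each operation, then publish progress under the condition.
    for seq in range(n_ops):
        time.sleep(0.1)  # stand-in for submitting one operation to the backend
        with cond:
            last_put_seq.value = seq
            cond.notify_all()

if __name__ == "__main__":
    last_put_seq = multiprocessing.Value("i", -1)  # -1 means nothing processed yet
    cond = multiprocessing.Condition()
    n_ops = 5
    proc = multiprocessing.Process(target=worker, args=(last_put_seq, cond, n_ops))
    proc.start()
    # Waiter side: sleep on the condition until the last sequence id is confirmed.
    while last_put_seq.value < n_ops - 1:
        with cond:
            cond.wait(timeout=1.0)
    proc.join()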

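And an end-to-end usage sketch of the API this commit introduces; the family, run id, and project values are placeholders, and the API token can instead come from the NEPTUNE_API_TOKEN environment variable:

from neptune_scale import Run

with Run(
    family="my-run-family",
    run_id="my-run-1",
    project="my-workspace/my-project",
    mode="async",
) as run:
    run.log(step=1, metrics={"loss": 0.25}, fields={"lr": 1e-3})
    run.log(step=2, metrics={"loss": 0.21})
    # Block until the sync process confirms everything queued so far (or 60 s pass).
    run.wait_for_submission(timeout=60.0)
# Leaving the context closes the run; close() is also registered via atexit.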
