Skip to content

Commit

Permalink
Added async flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Aug 19, 2024
1 parent e35876c commit cdfecb0
Show file tree
Hide file tree
Showing 16 changed files with 754 additions and 133 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ python = "^3.8"

neptune-api = "0.4.0"
more-itertools = "^10.0.0"
psutil = "^5.0.0"

[tool.poetry]
name = "neptune-client-scale"
Expand Down
126 changes: 103 additions & 23 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,32 @@

__all__ = ["Run"]

import atexit
import multiprocessing
import os
import threading
import time
from contextlib import AbstractContextManager
from datetime import datetime
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Condition as ConditionT
from typing import Callable

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.api_client import ApiClient
from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.errors_monitor import ErrorsMonitor
from neptune_scale.core.components.errors_queue import ErrorsQueue
from neptune_scale.core.components.errors_tracking import (
ErrorsMonitor,
ErrorsQueue,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.components.sync_process import SyncProcess
from neptune_scale.core.logger import logger
from neptune_scale.core.metadata_splitter import MetadataSplitter
from neptune_scale.core.serialization import (
datetime_to_proto,
Expand All @@ -44,6 +52,8 @@
MAX_FAMILY_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)


Expand All @@ -65,7 +75,7 @@ def __init__(
from_run_id: str | None = None,
from_step: int | float | None = None,
max_queue_size: int = MAX_QUEUE_SIZE,
max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None,
max_queue_size_exceeded_callback: Callable[[BaseException], None] | None = None,
) -> None:
"""
Initializes a run that logs the model-building metadata to Neptune.
Expand All @@ -84,10 +94,8 @@ def __init__(
from_run_id: If forking from an existing run, ID of the run to fork from.
from_step: If forking from an existing run, step number to fork from.
max_queue_size: Maximum number of operations in a queue.
max_queue_size_exceeded_callback: Callback function triggered when a queue is full.
Accepts two arguments:
- Maximum size of the queue.
- Exception that made the queue full.
max_queue_size_exceeded_callback: Callback function triggered when a queue is full. Accepts the exception
that made the queue full.
"""
verify_type("family", family, str)
verify_type("run_id", run_id, str)
Expand Down Expand Up @@ -140,13 +148,31 @@ def __init__(

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
lock=self._lock,
max_size=max_queue_size,
)
self._errors_queue: ErrorsQueue = ErrorsQueue()
self._errors_monitor = ErrorsMonitor(errors_queue=self._errors_queue)
self._backend: ApiClient = ApiClient(api_token=input_api_token)
self._errors_monitor = ErrorsMonitor(
errors_queue=self._errors_queue,
max_queue_size_exceeded_callback=max_queue_size_exceeded_callback,
)
self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_put_seq_wait: ConditionT = multiprocessing.Condition()
self._sync_process = SyncProcess(
family=self._family,
operations_queue=self._operations_queue.queue,
errors_queue=self._errors_queue,
api_token=input_api_token,
last_put_seq=self._last_put_seq,
last_put_seq_wait=self._last_put_seq_wait,
max_queue_size=max_queue_size,
)

self._errors_monitor.start()
with self._lock:
self._sync_process.start()

self._exit_func: Callable[[], None] | None = atexit.register(self._close)

if not resume:
self._create_run(
Expand All @@ -156,23 +182,35 @@ def __init__(
from_step=from_step,
)

def __enter__(self) -> Run:
    """
    Enters the context manager and returns this run itself, so it can be bound with `with ... as run`.
    """
    return self

@property
def resources(self) -> tuple[Resource, ...]:
return (
self._errors_queue,
self._operations_queue,
self._backend,
self._errors_monitor,
self._errors_queue,
)

def _close(self) -> None:
    """
    Shuts the run down: drains and stops the sync process, stops the errors monitor,
    then delegates to the parent `close()`.

    This is the callable registered with `atexit` in `__init__`; the public `close()` also calls it.
    """
    # NOTE(review): nesting below is reconstructed from a view without indentation — confirm that
    # `terminate()`/`join()` belong inside the `is_alive()` branch.
    # TODO: Change to wait for all operations to be processed
    with self._lock:
        if self._sync_process.is_alive():
            # Wait for queued operations first; no timeout is passed, so this waits indefinitely.
            self.wait_for_submission()
            self._sync_process.terminate()
            self._sync_process.join()

    # The errors monitor is stopped only after the sync process has been handled.
    self._errors_monitor.interrupt()
    self._errors_monitor.join()

    super().close()

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()
if self._exit_func is not None:
atexit.unregister(self._exit_func)
self._exit_func = None
self._close()

def _create_run(
self,
Expand All @@ -197,9 +235,7 @@ def _create_run(
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
)
self._backend.submit(operation=operation, family=self._family)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
self._operations_queue.enqueue(operation=operation)

def log(
self,
Expand Down Expand Up @@ -265,6 +301,50 @@ def log(
)

for operation in splitter:
self._backend.submit(operation=operation, family=self._family)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
self._operations_queue.enqueue(operation=operation)

def wait_for_submission(self, timeout: float | None = None) -> None:
    """
    Waits until all metadata queued so far is submitted to Neptune.

    The method snapshots the sequence ID of the last queued operation and returns once the
    sync process reports (via the shared `last_put_seq` value) that it has reached that ID.

    Args:
        timeout: Maximum number of seconds to wait. If `None`, waits until everything queued
            at the time of the call is submitted or the sync process stops running.
    """
    # Monotonic clock: elapsed-time measurement must not be affected by wall-clock adjustments.
    begin_time = time.monotonic()
    logger.info("Waiting for all operations to be processed")
    if timeout is None:
        logger.warning("No timeout specified. Waiting indefinitely")

    # NOTE(review): the loop is kept under `self._lock` as in the original layout (reconstructed
    # from a view without indentation) — confirm against the real file.
    with self._lock:
        if not self._sync_process.is_alive():
            logger.warning("Sync process is not running")
            return  # No need to wait if the sync process is not running

        sleep_time_wait = (
            min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME
        )
        last_queued_sequence_id = self._operations_queue.last_sequence_id
        last_message_printed: float | None = None
        while True:
            with self._last_put_seq_wait:
                # Check the predicate before waiting: if the sync process already caught up,
                # its notification was sent earlier and waiting would only waste a sleep interval.
                value = self._last_put_seq.value
                if value < last_queued_sequence_id:
                    self._last_put_seq_wait.wait(timeout=sleep_time_wait)
                    value = self._last_put_seq.value

            if value >= last_queued_sequence_id:
                break

            if timeout is not None and time.monotonic() - begin_time > timeout:
                # Do not fall through to the success message: not everything was submitted.
                logger.warning("Timed out after %s seconds while waiting for operations to be processed", timeout)
                return

            if not self._sync_process.is_alive():
                # Nothing will advance the sequence any more; avoid waiting forever.
                logger.warning("Sync process stopped before all operations were processed")
                return

            # Progress messages are throttled to one per STOP_MESSAGE_FREQUENCY seconds.
            now = time.monotonic()
            if last_message_printed is None or now - last_message_printed > STOP_MESSAGE_FREQUENCY:
                last_message_printed = now
                if value != -1:
                    logger.info(
                        "Waiting until remaining %d operations will be synced", last_queued_sequence_id - value
                    )
                elif self._operations_queue.last_sequence_id != -1:
                    logger.info(
                        "Waiting. No operations were processed yet. Operations to sync: %s",
                        self._operations_queue.last_sequence_id,
                    )
                else:
                    logger.info("Waiting. No operations were processed yet")

    logger.info("All operations were processed")
14 changes: 9 additions & 5 deletions src/neptune_scale/api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,28 @@
ClientConfig,
Error,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
from neptune_api.types import Response

from neptune_scale.core.components.abstract import Resource
from neptune_scale.core.logger import logger


class ApiClient(Resource):
def __init__(self, api_token: str) -> None:
credentials = Credentials.from_api_key(api_key=api_token)

logger.debug("Trying to connect to Neptune API")
config, token_urls = get_config_and_token_urls(credentials=credentials)
self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls)
logger.debug("Connected to Neptune API")

def submit(self, operation: RunOperation, family: str) -> None:
_ = submit_operation.sync(client=self._backend, family=family, body=operation)

def cleanup(self) -> None:
pass
def submit(self, operation: RunOperation, family: str) -> Response[RequestId]:
return submit_operation.sync_detailed(client=self._backend, body=operation, family=family)

def close(self) -> None:
logger.debug("Closing API client")
self._backend.__exit__()


Expand Down
4 changes: 2 additions & 2 deletions src/neptune_scale/core/components/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __exit__(


class Resource(AutoCloseable):
@abstractmethod
def cleanup(self) -> None: ...
def cleanup(self) -> None:
pass

def flush(self) -> None:
pass
Expand Down
46 changes: 0 additions & 46 deletions src/neptune_scale/core/components/errors_monitor.py

This file was deleted.

24 changes: 0 additions & 24 deletions src/neptune_scale/core/components/errors_queue.py

This file was deleted.

Loading

0 comments on commit cdfecb0

Please sign in to comment.