Skip to content

Commit

Permalink
Atexit and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Aug 16, 2024
1 parent e9efaec commit 67ef786
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 30 deletions.
90 changes: 65 additions & 25 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

__all__ = ["Run"]

import atexit
import multiprocessing
import os
import threading
import time
from contextlib import AbstractContextManager
from datetime import datetime
from multiprocessing.sharedctypes import Synchronized
Expand Down Expand Up @@ -50,6 +52,8 @@
MAX_FAMILY_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)


Expand Down Expand Up @@ -156,7 +160,6 @@ def __init__(

self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_put_seq_wait: ConditionT = multiprocessing.Condition()
# TODO: Rethink
self._sync_process = SyncProcess(
family=self._family,
operations_queue=self._operations_queue.queue,
Expand All @@ -167,6 +170,12 @@ def __init__(
max_queue_size=max_queue_size,
)

self._errors_monitor.start()
with self._lock:
self._sync_process.start()

self._exit_func: Callable[[], None] | None = atexit.register(self._close)

if not resume:
self._create_run(
creation_time=datetime.now() if creation_time is None else creation_time,
Expand All @@ -175,12 +184,6 @@ def __init__(
from_step=from_step,
)

def __enter__(self) -> Run:
# TODO: Rethink
self._errors_monitor.start()
# self._sync_process.start()
return self

@property
def resources(self) -> tuple[Resource, ...]:
return (
Expand All @@ -189,21 +192,28 @@ def resources(self) -> tuple[Resource, ...]:
self._errors_monitor,
)

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
# TODO: Rethink
def _close(self) -> None:
# TODO: Change to wait for all operations to be processed
# self.wait_for_submission()
# self._sync_process.terminate()
# self._sync_process.join()
print("Closing")
with self._lock:
if self._sync_process.is_alive():
self.wait_for_submission()
self._sync_process.terminate()
self._sync_process.join()

self._errors_monitor.interrupt()
self._errors_monitor.join()
print("Errors Monitor closed")

super().close()

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
if self._exit_func is not None:
atexit.unregister(self._exit_func)
self._exit_func = None
self._close()

def _create_run(
self,
creation_time: datetime,
Expand Down Expand Up @@ -295,18 +305,48 @@ def log(
for operation in splitter:
self._operations_queue.enqueue(operation=operation)

def wait_for_submission(self) -> None:
def wait_for_submission(self, timeout: float | None = None) -> None:
"""
Waits until all metadata is submitted to Neptune.
"""
begin_time = time.time()
logger.info("Waiting for all operations to be processed")
if timeout is None:
logger.warning("No timeout specified. Waiting indefinitely")

with self._lock:
if not self._sync_process.is_alive():
logger.warning("Sync process is not running")
return # No need to wait if the sync process is not running

sleep_time_wait = (
min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME
)
last_queued_sequence_id = self._operations_queue.last_sequence_id
last_message_printed: float | None = None
while True:
with self._last_put_seq_wait:
self._last_put_seq_wait.wait()
logger.debug(
"Waiting until the sync process will populate up to %s, %s is currently processed",
self._last_put_seq.value,
last_queued_sequence_id,
)
if self._last_put_seq.value >= last_queued_sequence_id:
self._last_put_seq_wait.wait(timeout=sleep_time_wait)
value = self._last_put_seq.value
if value == -1:
if self._operations_queue.last_sequence_id != -1:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting. No operations were processed yet. Operations to sync: %s",
self._operations_queue.last_sequence_id,
)
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info("Waiting. No operations were processed yet")
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting until remaining %d operations will be synced", last_queued_sequence_id - value
)
if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
break

logger.info("All operations were processed")
8 changes: 5 additions & 3 deletions src/neptune_scale/api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,23 @@
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.components.abstract import Resource
from neptune_scale.core.logger import logger


class ApiClient(Resource):
def __init__(self, api_token: str) -> None:
credentials = Credentials.from_api_key(api_key=api_token)

logger.debug("Trying to connect to Neptune API")
config, token_urls = get_config_and_token_urls(credentials=credentials)
self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls)
logger.debug("Connected to Neptune API")

def submit(self, operation: RunOperation, family: str) -> None:
_ = submit_operation.sync(client=self._backend, body=operation, family=family)

def cleanup(self) -> None:
pass

def close(self) -> None:
logger.debug("Closing API client")
self._backend.__exit__()


Expand Down
1 change: 1 addition & 0 deletions src/neptune_scale/core/components/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
self._max_queue_size: int = max_queue_size

def run(self) -> None:
logger.info("Data synchronization started")
worker = SyncProcessWorker(
family=self._family,
api_token=self._api_token,
Expand Down
2 changes: 1 addition & 1 deletion src/neptune_scale/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_logger() -> logging.Logger:

if os.environ.get(DEBUG_MODE, "False").lower() in ("true", "1"):
file_handler = logging.FileHandler(NEPTUNE_DEBUG_FILE_NAME)
file_handler.setLevel(logging.INFO)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter(DEBUG_FORMAT))
neptune_logger.addHandler(file_handler)

Expand Down
4 changes: 3 additions & 1 deletion src/neptune_scale/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@
SYNC_THREAD_SLEEP_TIME = 0.1
EXTERNAL_TO_INTERNAL_THREAD_SLEEP_TIME = 0.1
ERRORS_MONITOR_THREAD_SLEEP_TIME = 0.1
SHUTDOWN_TIMEOUT = 10
SHUTDOWN_TIMEOUT = 60 # 1 minute
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME = 10
STOP_MESSAGE_FREQUENCY = 5

0 comments on commit 67ef786

Please sign in to comment.