Skip to content

Commit

Permalink
Handle Run creation errors in __init__.
Browse files Browse the repository at this point in the history
Also make `wait*` return `True` if all operations were processed, `False` on timeout
  • Loading branch information
kgodlewski authored and PatrykGala committed Jan 14, 2025
1 parent cdd03e8 commit 4747e9d
Showing 1 changed file with 51 additions and 14 deletions.
65 changes: 51 additions & 14 deletions src/neptune_scale/api/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
__all__ = ["Run"]

import atexit
import functools
import math
import os
import threading
Expand Down Expand Up @@ -58,6 +59,7 @@
STOP_MESSAGE_FREQUENCY,
)
from neptune_scale.sync.sync_process import SyncProcess
from neptune_scale.types import RunCallback
from neptune_scale.util.abstract import (
Resource,
WithResources,
Expand Down Expand Up @@ -197,6 +199,8 @@ def __init__(
self._is_closing = False
# This is used to signal that the close/termination operation is completed and block user code until it is so
self._close_completed = threading.Event()
# Thread that initially called _close()
self._closing_thread: Optional[threading.Thread] = None

self._project: str = input_project
self._run_id: str = run_id
Expand All @@ -210,6 +214,8 @@ def __init__(
self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)

self._errors_queue: ErrorsQueue = ErrorsQueue()
# Note that for the duration of __init__ we use a special error callback that
# is guaranteed to terminate the run in case of an error.
self._errors_monitor = ErrorsMonitor(
errors_queue=self._errors_queue,
on_queue_full_callback=on_queue_full_callback,
Expand All @@ -218,6 +224,12 @@ def __init__(
on_warning_callback=on_warning_callback,
)

# Grab it like that, in case on_error_callback is None -- we will get the default one then
orig_error_callback = self._errors_monitor.on_error_callback
self._errors_monitor.on_error_callback = functools.partial(
self._initialization_error_callback, orig_error_callback
)

self._last_queued_seq = SharedInt(-1)
self._last_ack_seq = SharedInt(-1)
self._last_ack_timestamp = SharedFloat(-1)
Expand Down Expand Up @@ -261,7 +273,21 @@ def __init__(
fork_run_id=fork_run_id,
fork_step=fork_step,
)
self.wait_for_processing(verbose=False)

# Wait in short periods to return from __init__ if run creation fails
# and Run.terminate() is called in self._initialization_error_callback
while not self._is_closing:
if self.wait_for_processing(verbose=False, timeout=1):
break

# Bring back the originally requested error callback
self._errors_monitor.on_error_callback = orig_error_callback

def _initialization_error_callback(
self, user_callback: RunCallback, error: BaseException, last_seen_at: Optional[float]
) -> None:
self.terminate()
user_callback(error, last_seen_at)

def _on_child_link_closed(self, _: ProcessLink) -> None:
with self._lock:
Expand Down Expand Up @@ -289,7 +315,9 @@ def resources(self) -> tuple[Resource, ...]:
def _close(self, *, wait: bool = True) -> None:
with self._lock:
was_closing = self._is_closing
self._is_closing = True
if not self._is_closing:
self._is_closing = True
self._closing_thread = threading.current_thread()

if was_closing:
logger.debug("Waiting for run to be closed from a different thread")
Expand Down Expand Up @@ -565,7 +593,7 @@ def _wait(
wait_seq: SharedInt,
timeout: Optional[float] = None,
verbose: bool = True,
) -> None:
) -> bool:
if verbose:
logger.info(f"Waiting for all operations to be {phrase}")

Expand All @@ -588,11 +616,12 @@ def _wait(
if wait_seq.value >= self._operations_queue.last_sequence_id:
break

if is_closing:
if is_closing and threading.current_thread() != self._closing_thread:
if verbose:
logger.info("Waiting interrupted by Run being closed.")
self._close_completed.wait(timeout)
return
logger.warning("Waiting interrupted by run termination")

self._close_completed.wait(wait_time)
return False

with wait_seq:
wait_seq.wait(timeout=wait_time)
Expand Down Expand Up @@ -625,45 +654,53 @@ def _wait(
elif value >= last_queued_sequence_id:
if verbose:
logger.info(f"All operations were {phrase}")
break
return True

if time.monotonic() - begin_time > timeout:
break
return False
except KeyboardInterrupt:
if verbose:
logger.warning("Waiting interrupted by user")
return
return False

def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
return False

def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> bool:
"""
Waits until all metadata is submitted to Neptune for processing.
When submitted, the data is not yet saved in Neptune until fully processed.
See wait_for_processing().
Returns True if all currently queued operations were submitted, False if timeout was reached
or Run is closing.
Args:
timeout (float, optional): In seconds, the maximum time to wait for submission.
verbose (bool): If True (default), prints messages about the waiting process.
"""
self._wait(
return self._wait(
phrase="submitted",
sleep_time=MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
wait_seq=self._last_queued_seq,
timeout=timeout,
verbose=verbose,
)

def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> bool:
"""
Waits until all metadata is processed by Neptune.
Once the call is complete, the data is saved in Neptune.
Returns True if all currently queued operations were processed, False if timeout was reached
or Run is closing.
Args:
timeout (float, optional): In seconds, the maximum time to wait for processing.
verbose (bool): If True (default), prints messages about the waiting process.
"""
self._wait(
return self._wait(
phrase="processed",
sleep_time=MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
wait_seq=self._last_ack_seq,
Expand Down

0 comments on commit 4747e9d

Please sign in to comment.