Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error handling during wait*() and Run.__init__() #118

Merged
merged 11 commits into from
Jan 14, 2025
124 changes: 93 additions & 31 deletions src/neptune_scale/api/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
__all__ = ["Run"]

import atexit
import math
import os
import threading
import time
Expand Down Expand Up @@ -50,6 +51,8 @@
from neptune_scale.exceptions import (
NeptuneApiTokenNotProvided,
NeptuneProjectNotProvided,
NeptuneScaleError,
NeptuneSynchronizationStopped,
)
from neptune_scale.net.serialization import (
datetime_to_proto,
Expand All @@ -58,6 +61,7 @@
from neptune_scale.sync.errors_tracking import (
ErrorsMonitor,
ErrorsQueue,
default_error_callback,
)
from neptune_scale.sync.lag_tracking import LagTracker
from neptune_scale.sync.operations_queue import OperationsQueue
Expand All @@ -70,6 +74,7 @@
STOP_MESSAGE_FREQUENCY,
)
from neptune_scale.sync.sync_process import SyncProcess
from neptune_scale.types import RunCallback
from neptune_scale.util.abstract import (
Resource,
WithResources,
Expand Down Expand Up @@ -207,6 +212,12 @@ def __init__(
# This flag is used to signal that we're closed or being closed (and most likely waiting for sync), and no
# new data should be logged.
self._is_closing = False
# This is used to signal that the close/termination operation is completed and block user code until it is so
self._close_completed = threading.Event()
# Thread that initially called _close()
self._closing_thread: Optional[threading.Thread] = None
# Mark that __init__() is finished, so that the error callback can act upon it (see _wrap_error_callback())
self._init_completed = False

self._project: str = input_project
self._run_id: str = run_id
Expand All @@ -220,11 +231,13 @@ def __init__(
self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)

self._errors_queue: ErrorsQueue = ErrorsQueue()
# Note that for the duration of __init__ we use a special error callback that
# is guaranteed to terminate the run in case of an error.
self._errors_monitor = ErrorsMonitor(
errors_queue=self._errors_queue,
on_queue_full_callback=on_queue_full_callback,
on_network_error_callback=on_network_error_callback,
on_error_callback=on_error_callback,
on_error_callback=self._wrap_error_callback(on_error_callback),
on_warning_callback=on_warning_callback,
)

Expand Down Expand Up @@ -271,14 +284,36 @@ def __init__(
fork_run_id=fork_run_id,
fork_step=fork_step,
)
self.wait_for_processing(verbose=False)

# Wait in short periods to return from __init__ if run creation fails
# and Run.terminate() is called in self._wrap_error_callback()
while not self._is_closing:
if self.wait_for_processing(verbose=False, timeout=1):
break

with self._lock:
self._init_completed = True

def _wrap_error_callback(self, user_callback: Optional[RunCallback]) -> RunCallback:
"""Wrapp the provided user error callback so that we can handle errors during __init__()"""
callback = user_callback or default_error_callback

def wrapped_callback(error: BaseException, last_seen_at: Optional[float]) -> None:
with self._lock:
if not self._init_completed:
self.terminate()

callback(error, last_seen_at)

return wrapped_callback

def _on_child_link_closed(self, _: ProcessLink) -> None:
with self._lock:
if not self._is_closing:
logger.error("Child process closed unexpectedly. Terminating.")
self._is_closing = True
self.terminate()

# Make sure all the error handling is done from a single thread - self._errors_monitor
self._errors_queue.put(NeptuneSynchronizationStopped())

@property
def resources(self) -> tuple[Resource, ...]:
Expand All @@ -297,12 +332,21 @@ def resources(self) -> tuple[Resource, ...]:

def _close(self, *, wait: bool = True) -> None:
with self._lock:
if self._is_closing:
return

self._is_closing = True
was_closing = self._is_closing
if not self._is_closing:
self._is_closing = True
self._closing_thread = threading.current_thread()

if was_closing:
logger.debug("Waiting for run to be closed from a different thread")
# TODO: we should probably have a reasonable timeout here, same one as a default one in
# wait_for_processing(). Or just accept indefinite wait here in both cases.
# if not self._close_completed.wait(timeout=...):
if not self._close_completed.wait():
raise NeptuneScaleError(reason="Run close operation timed out")
return

logger.debug(f"Run is closing, wait={wait}")
logger.debug(f"Run is closing, wait={wait}")

if self._sync_process.is_alive():
if wait:
Expand All @@ -324,6 +368,8 @@ def _close(self, *, wait: bool = True) -> None:
if threading.current_thread() != self._errors_monitor:
self._errors_monitor.join()

self._close_completed.set()

super().close()

def terminate(self) -> None:
Expand Down Expand Up @@ -565,30 +611,37 @@ def _wait(
wait_seq: SharedInt,
timeout: Optional[float] = None,
verbose: bool = True,
) -> None:
) -> bool:
if verbose:
logger.info(f"Waiting for all operations to be {phrase}")

if timeout is None and verbose:
logger.warning("No timeout specified. Waiting indefinitely")
if timeout is None:
if verbose:
logger.warning("No timeout specified. Waiting indefinitely")
timeout = math.inf

begin_time = time.time()
wait_time = min(sleep_time, timeout) if timeout is not None else sleep_time
begin_time = time.monotonic()
wait_time = min(sleep_time, timeout)
last_print_timestamp: Optional[float] = None

while True:
try:
with self._lock:
if not self._sync_process.is_alive():
if verbose and not self._is_closing:
# TODO: error out here?
logger.warning("Sync process is not running")
return # No need to wait if the sync process is not running
is_closing = self._is_closing

# Handle the case where we get notified on `wait_seq` before we actually wait.
# Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens.
if wait_seq.value >= self._operations_queue.last_sequence_id:
break
return True

if is_closing:
if threading.current_thread() != self._closing_thread:
if verbose:
logger.warning("Waiting interrupted by run termination")

self._close_completed.wait(wait_time)

return False

with wait_seq:
wait_seq.wait(timeout=wait_time)
Expand Down Expand Up @@ -617,48 +670,57 @@ def _wait(
last_print=last_print_timestamp,
verbose=verbose,
)
else:
# Reaching the last queued sequence ID means that all operations were submitted
if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
break
# Reaching the last queued sequence ID means that all operations were submitted
elif value >= last_queued_sequence_id:
if verbose:
logger.info(f"All operations were {phrase}")
return True

if time.monotonic() - begin_time > timeout:
return False
except KeyboardInterrupt:
if verbose:
logger.warning("Waiting interrupted by user")
return
return False

if verbose:
logger.info(f"All operations were {phrase}")
return False

def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> bool:
"""
Waits until all metadata is submitted to Neptune for processing.

When submitted, the data is not yet saved in Neptune until fully processed.
See wait_for_processing().

Returns True if all currently queued operations were submitted, False if timeout was reached
or Run is closing.

Args:
timeout (float, optional): In seconds, the maximum time to wait for submission.
verbose (bool): If True (default), prints messages about the waiting process.
"""
self._wait(
return self._wait(
phrase="submitted",
sleep_time=MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
wait_seq=self._last_queued_seq,
timeout=timeout,
verbose=verbose,
)

def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> bool:
"""
Waits until all metadata is processed by Neptune.

Once the call is complete, the data is saved in Neptune.

Returns True if all currently queued operations were processed, False if timeout was reached
or Run is closing.

Args:
timeout (float, optional): In seconds, the maximum time to wait for processing.
verbose (bool): If True (default), prints messages about the waiting process.
"""
self._wait(
return self._wait(
phrase="processed",
sleep_time=MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
wait_seq=self._last_ack_seq,
Expand Down
5 changes: 4 additions & 1 deletion src/neptune_scale/sync/errors_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def __init__(
def get_next(self) -> Optional[BaseException]:
try:
return self._errors_queue.get(block=False)
except queue.Empty:
except (queue.Empty, ValueError):
# Catch ValueError which is raised when reading from an already closed queue.
# This happens sometimes on abnormal termination, so silence the error message.
# TODO: we should synchronize here properly instead
return None

def work(self) -> None:
Expand Down
21 changes: 21 additions & 0 deletions src/neptune_scale/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
# Copyright (c) 2025, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import (
Callable,
Optional,
)

RunCallback = Callable[[BaseException, Optional[float]], None]
3 changes: 3 additions & 0 deletions src/neptune_scale/util/process_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def join(self, timeout: Optional[float] = None) -> None:
if not self._worker:
raise RuntimeError("You must call start() before calling join()")

if threading.current_thread() == self._worker:
raise RuntimeError("Cannot join() from a callback")

self._worker.join(timeout=timeout)

def send(self, message: Any) -> None:
Expand Down
Loading
Loading