diff --git a/src/neptune_scale/api/run.py b/src/neptune_scale/api/run.py
index 9c13ffa..c805a3a 100644
--- a/src/neptune_scale/api/run.py
+++ b/src/neptune_scale/api/run.py
@@ -22,6 +22,7 @@
 __all__ = ["Run"]

 import atexit
+import math
 import os
 import threading
 import time
@@ -50,6 +51,8 @@
 from neptune_scale.exceptions import (
     NeptuneApiTokenNotProvided,
     NeptuneProjectNotProvided,
+    NeptuneScaleError,
+    NeptuneSynchronizationStopped,
 )
 from neptune_scale.net.serialization import (
     datetime_to_proto,
@@ -58,6 +61,7 @@
 from neptune_scale.sync.errors_tracking import (
     ErrorsMonitor,
     ErrorsQueue,
+    default_error_callback,
 )
 from neptune_scale.sync.lag_tracking import LagTracker
 from neptune_scale.sync.operations_queue import OperationsQueue
@@ -70,6 +74,7 @@
     STOP_MESSAGE_FREQUENCY,
 )
 from neptune_scale.sync.sync_process import SyncProcess
+from neptune_scale.types import RunCallback
 from neptune_scale.util.abstract import (
     Resource,
     WithResources,
 )
@@ -207,6 +212,12 @@ def __init__(
         # This flag is used to signal that we're closed or being closed (and most likely waiting for sync), and no
         # new data should be logged.
         self._is_closing = False
+        # This is used to signal that the close/termination operation is completed, and to block user code until then
+        self._close_completed = threading.Event()
+        # Thread that initially called _close()
+        self._closing_thread: Optional[threading.Thread] = None
+        # Mark that __init__() is finished, so that the error callback can act upon it (see _wrap_error_callback())
+        self._init_completed = False

         self._project: str = input_project
         self._run_id: str = run_id
@@ -220,11 +231,13 @@
         self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)

         self._errors_queue: ErrorsQueue = ErrorsQueue()
+        # Note that for the duration of __init__ we use a special error callback that
+        # is guaranteed to terminate the run in case of an error.
         self._errors_monitor = ErrorsMonitor(
             errors_queue=self._errors_queue,
             on_queue_full_callback=on_queue_full_callback,
             on_network_error_callback=on_network_error_callback,
-            on_error_callback=on_error_callback,
+            on_error_callback=self._wrap_error_callback(on_error_callback),
             on_warning_callback=on_warning_callback,
         )

@@ -271,14 +284,36 @@ def __init__(
                 fork_run_id=fork_run_id,
                 fork_step=fork_step,
             )
-            self.wait_for_processing(verbose=False)
+
+            # Wait in short periods so that we can return from __init__ if run creation fails
+            # and Run.terminate() is called in self._wrap_error_callback()
+            while not self._is_closing:
+                if self.wait_for_processing(verbose=False, timeout=1):
+                    break
+
+        with self._lock:
+            self._init_completed = True
+
+    def _wrap_error_callback(self, user_callback: Optional[RunCallback]) -> RunCallback:
+        """Wrap the provided user error callback so that we can handle errors during __init__()"""
+        callback = user_callback or default_error_callback
+
+        def wrapped_callback(error: BaseException, last_seen_at: Optional[float]) -> None:
+            with self._lock:
+                if not self._init_completed:
+                    self.terminate()
+
+            callback(error, last_seen_at)
+
+        return wrapped_callback

     def _on_child_link_closed(self, _: ProcessLink) -> None:
         with self._lock:
             if not self._is_closing:
                 logger.error("Child process closed unexpectedly. Terminating.")
-                self._is_closing = True
-                self.terminate()
+
+                # Make sure all the error handling is done from a single thread - self._errors_monitor
+                self._errors_queue.put(NeptuneSynchronizationStopped())

     @property
     def resources(self) -> tuple[Resource, ...]:
@@ -297,12 +332,21 @@ def resources(self) -> tuple[Resource, ...]:

     def _close(self, *, wait: bool = True) -> None:
         with self._lock:
-            if self._is_closing:
-                return
-
-            self._is_closing = True
+            was_closing = self._is_closing
+            if not self._is_closing:
+                self._is_closing = True
+                self._closing_thread = threading.current_thread()
+
+        if was_closing:
+            logger.debug("Waiting for run to be closed from a different thread")
+            # TODO: we should probably have a reasonable timeout here, same one as a default one in
+            #       wait_for_processing(). Or just accept indefinite wait here in both cases.
+            # if not self._close_completed.wait(timeout=...):
+            if not self._close_completed.wait():
+                raise NeptuneScaleError(reason="Run close operation timed out")
+            return

-            logger.debug(f"Run is closing, wait={wait}")
+        logger.debug(f"Run is closing, wait={wait}")

         if self._sync_process.is_alive():
             if wait:
@@ -324,6 +368,8 @@ def _close(self, *, wait: bool = True) -> None:
         if threading.current_thread() != self._errors_monitor:
             self._errors_monitor.join()

+        self._close_completed.set()
+
         super().close()

     def terminate(self) -> None:
@@ -565,30 +611,37 @@ def _wait(
         wait_seq: SharedInt,
         timeout: Optional[float] = None,
         verbose: bool = True,
-    ) -> None:
+    ) -> bool:
         if verbose:
             logger.info(f"Waiting for all operations to be {phrase}")

-        if timeout is None and verbose:
-            logger.warning("No timeout specified. Waiting indefinitely")
+        if timeout is None:
+            if verbose:
+                logger.warning("No timeout specified. Waiting indefinitely")
+            timeout = math.inf

-        begin_time = time.time()
-        wait_time = min(sleep_time, timeout) if timeout is not None else sleep_time
+        begin_time = time.monotonic()
+        wait_time = min(sleep_time, timeout)
         last_print_timestamp: Optional[float] = None

         while True:
             try:
                 with self._lock:
-                    if not self._sync_process.is_alive():
-                        if verbose and not self._is_closing:
-                            # TODO: error out here?
-                            logger.warning("Sync process is not running")
-                        return  # No need to wait if the sync process is not running
+                    is_closing = self._is_closing

                     # Handle the case where we get notified on `wait_seq` before we actually wait.
                     # Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens.
                     if wait_seq.value >= self._operations_queue.last_sequence_id:
-                        break
+                        return True
+
+                    if is_closing:
+                        if threading.current_thread() != self._closing_thread:
+                            if verbose:
+                                logger.warning("Waiting interrupted by run termination")
+
+                            self._close_completed.wait(wait_time)
+
+                            return False

                 with wait_seq:
                     wait_seq.wait(timeout=wait_time)
@@ -617,30 +670,36 @@ def _wait(
                         last_print=last_print_timestamp,
                         verbose=verbose,
                     )
-                else:
-                    # Reaching the last queued sequence ID means that all operations were submitted
-                    if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
-                        break
+                # Reaching the last queued sequence ID means that all operations were submitted
+                elif value >= last_queued_sequence_id:
+                    if verbose:
+                        logger.info(f"All operations were {phrase}")
+                    return True
+
+                if time.monotonic() - begin_time > timeout:
+                    return False
             except KeyboardInterrupt:
                 if verbose:
                     logger.warning("Waiting interrupted by user")
-                return
+                return False

-        if verbose:
-            logger.info(f"All operations were {phrase}")
+        return False

-    def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
+    def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> bool:
         """
         Waits until all metadata is submitted to Neptune for processing.

         When submitted, the data is not yet saved in Neptune until fully processed. See wait_for_processing().

+        Returns True if all currently queued operations were submitted, False if timeout was reached
+        or Run is closing.
+
         Args:
             timeout (float, optional): In seconds, the maximum time to wait for submission.
             verbose (bool): If True (default), prints messages about the waiting process.
         """
-        self._wait(
+        return self._wait(
             phrase="submitted",
             sleep_time=MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
             wait_seq=self._last_queued_seq,
@@ -648,17 +707,20 @@ def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = T
             verbose=verbose,
         )

-    def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
+    def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> bool:
         """
         Waits until all metadata is processed by Neptune.

         Once the call is complete, the data is saved in Neptune.

+        Returns True if all currently queued operations were processed, False if timeout was reached
+        or Run is closing.
+
         Args:
             timeout (float, optional): In seconds, the maximum time to wait for processing.
             verbose (bool): If True (default), prints messages about the waiting process.
         """
-        self._wait(
+        return self._wait(
             phrase="processed",
             sleep_time=MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
             wait_seq=self._last_ack_seq,
diff --git a/src/neptune_scale/sync/errors_tracking.py b/src/neptune_scale/sync/errors_tracking.py
index 786f879..81b39a0 100644
--- a/src/neptune_scale/sync/errors_tracking.py
+++ b/src/neptune_scale/sync/errors_tracking.py
@@ -109,7 +109,10 @@ def __init__(
     def get_next(self) -> Optional[BaseException]:
         try:
             return self._errors_queue.get(block=False)
-        except queue.Empty:
+        except (queue.Empty, ValueError):
+            # Catch ValueError which is raised when reading from an already closed queue.
+            # This happens sometimes on abnormal termination, so silence the error message.
+            # TODO: we should synchronize here properly instead
             return None

     def work(self) -> None:
diff --git a/src/neptune_scale/types.py b/src/neptune_scale/types.py
new file mode 100644
index 0000000..40edc7e
--- /dev/null
+++ b/src/neptune_scale/types.py
@@ -0,0 +1,21 @@
+#
+# Copyright (c) 2025, Neptune Labs Sp. z o.o.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import (
+    Callable,
+    Optional,
+)
+
+RunCallback = Callable[[BaseException, Optional[float]], None]
diff --git a/src/neptune_scale/util/process_link.py b/src/neptune_scale/util/process_link.py
index f7174ca..ef3b9aa 100644
--- a/src/neptune_scale/util/process_link.py
+++ b/src/neptune_scale/util/process_link.py
@@ -163,6 +163,9 @@ def join(self, timeout: Optional[float] = None) -> None:
         if not self._worker:
             raise RuntimeError("You must call start() before calling join()")

+        if threading.current_thread() == self._worker:
+            raise RuntimeError("Cannot join() from a callback")
+
         self._worker.join(timeout=timeout)

     def send(self, message: Any) -> None:
diff --git a/tests/unit/test_run_init_and_close.py b/tests/unit/test_run_init_and_close.py
new file mode 100644
index 0000000..4a21e27
--- /dev/null
+++ b/tests/unit/test_run_init_and_close.py
@@ -0,0 +1,121 @@
+import threading
+from unittest.mock import patch
+
+import pytest
+
+from neptune_scale import Run
+from neptune_scale.exceptions import NeptuneUnexpectedError
+from neptune_scale.sync.errors_tracking import ErrorsQueue
+
+
+@pytest.fixture
+def run(api_token):
+    return Run(project="workspace/project", api_token=api_token, run_id="test", mode="disabled")
+
+
+@pytest.mark.timeout(10)
+def test_multiple_closes_single_thread(run):
+    """This should not block, hence the timeout check"""
+
+    run.close()
+    run.close()
+
+
+@pytest.mark.timeout(10)
+def test_multiple_closes_multiple_threads(run):
+    """Close in one thread should block close in another thread"""
+
+    closed = threading.Event()
+
+    def closing_thread():
+        # Should block until the first close is done, and return False, as not all operations are done
+        assert not run.close(), "Run.close() returned True"
+        assert closed.wait(timeout=1), "wait_for_processing() finished before close()"
+
+    th = threading.Thread(target=closing_thread, daemon=True)
+
+    run.close()
+    th.start()
+    closed.set()
+
+    th.join(timeout=1)
+
+    assert not th.is_alive(), "Run.wait_for_processing() did not return in time after close()"
+
+
+@pytest.mark.timeout(10)
+def test_wait_for_processing_aborts_if_closed(run):
+    closed = threading.Event()
+
+    def waiting_thread():
+        assert not run.wait_for_processing(timeout=5)
+        assert closed.wait(timeout=1), "wait_for_processing() finished before close()"
+
+    th = threading.Thread(target=waiting_thread, daemon=True)
+
+    run.close()
+    th.start()
+    closed.set()
+
+    th.join(timeout=1)
+
+    assert not th.is_alive(), "Run.wait_for_processing() did not return in time after close()"
+
+
+@pytest.mark.timeout(10)
+def test_terminate_on_error(api_token):
+    """When calling Run.terminate() from the error callback, the run should terminate properly
+    without deadlocking"""
+
+    callback_called = threading.Event()
+    callback_finished = threading.Event()
+
+    def callback(exc, ts):
+        assert isinstance(exc, NeptuneUnexpectedError)
+        assert "Expected error" in str(exc)
+
+        callback_called.set()
+        run.terminate()
+        callback_finished.set()
+
+    run = Run(
+        project="workspace/project", api_token=api_token, run_id="test", mode="disabled", on_error_callback=callback
+    )
+
+    # Pretend we've sent an operation
+    run._last_queued_seq.value += 1
+    run._errors_queue.put(ValueError("Expected error"))
+
+    assert callback_called.wait(timeout=1)
+    run.wait_for_processing(timeout=1)
+    assert callback_finished.wait(timeout=10)
+
+
+@pytest.mark.timeout(10)
+def test_run_creation_during_initialization_error(api_token):
+    """If there's an error when creating a Run (with resume=False), the error callback should be called,
+    and it should be safe to terminate the Run
+    """
+    callback_finished = threading.Event()
+
+    def callback(exc, ts):
+        run.terminate()
+        callback_finished.set()
+
+    errors_queue = ErrorsQueue()
+
+    def _create_run(*args, **kwargs):
+        # This method is called by Run.__init__ to create a run. Instead of submitting a
+        # CreateRun operation, we simulate an error
+        errors_queue.put(ValueError("Expected error"))
+
+    with (
+        patch("neptune_scale.api.run.ErrorsQueue", return_value=errors_queue),
+        patch.object(Run, "_create_run", side_effect=_create_run),
+    ):
+        run = Run(
+            project="workspace/project", api_token=api_token, run_id="test", mode="disabled", on_error_callback=callback
+        )
+
+    assert callback_finished.wait(timeout=10)
+    assert run.wait_for_processing(timeout=1)
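
Usage sketch (not part of the patch): a minimal example of how calling code might rely on the boolean that wait_for_processing() now returns, together with the on_error_callback hook touched by this change. The project name and run ID below are placeholders, and the API token is assumed to come from the NEPTUNE_API_TOKEN environment variable.

from typing import Optional

from neptune_scale import Run


def on_error(error: BaseException, last_seen_at: Optional[float]) -> None:
    # Invoked from the errors-monitoring thread; with this patch it is safe
    # to call Run.terminate() from here without deadlocking.
    print(f"Neptune reported an error: {error}")


# Placeholder identifiers; the API token is read from NEPTUNE_API_TOKEN.
run = Run(project="workspace/project", run_id="example-run", on_error_callback=on_error)

# wait_for_processing() now returns a bool: False means the timeout was reached
# or the run is closing, so the caller can react instead of assuming success.
if not run.wait_for_processing(timeout=60):
    print("Not all operations were confirmed within 60 seconds")

run.close()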