diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4ac066f5..ff08c31e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,5 +29,6 @@ repos: additional_dependencies: - neptune-api==0.4.0 - more-itertools + - backoff default_language_version: python: python3 diff --git a/pyproject.toml b/pyproject.toml index f5fab8f8..a6897830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,8 @@ python = "^3.8" neptune-api = "0.4.0" more-itertools = "^10.0.0" +psutil = "^5.0.0" +backoff = "^2.0.0" [tool.poetry] name = "neptune-client-scale" @@ -74,6 +76,8 @@ force_grid_wrap = 2 [tool.ruff] line-length = 120 +target-version = "py38" +ignore = ["UP006", "UP007"] [tool.ruff.lint] select = ["F", "UP"] diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 2356e5ac..12964655 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -6,24 +6,40 @@ __all__ = ["Run"] +import atexit +import multiprocessing import os import threading +import time from contextlib import AbstractContextManager from datetime import datetime -from typing import Callable +from multiprocessing.sharedctypes import Synchronized +from multiprocessing.synchronize import Condition as ConditionT +from typing import ( + Callable, + Dict, + List, + Literal, + Optional, + Set, + Union, +) from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation -from neptune_scale.api.api_client import ApiClient from neptune_scale.core.components.abstract import ( Resource, WithResources, ) -from neptune_scale.core.components.errors_monitor import ErrorsMonitor -from neptune_scale.core.components.errors_queue import ErrorsQueue +from neptune_scale.core.components.errors_tracking import ( + ErrorsMonitor, + ErrorsQueue, +) from 
neptune_scale.core.components.operations_queue import OperationsQueue +from neptune_scale.core.components.sync_process import SyncProcess +from neptune_scale.core.logger import logger from neptune_scale.core.metadata_splitter import MetadataSplitter from neptune_scale.core.serialization import ( datetime_to_proto, @@ -44,6 +60,8 @@ MAX_FAMILY_LENGTH, MAX_QUEUE_SIZE, MAX_RUN_ID_LENGTH, + MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, + STOP_MESSAGE_FREQUENCY, ) @@ -57,15 +75,17 @@ def __init__( *, family: str, run_id: str, - project: str | None = None, - api_token: str | None = None, + project: Optional[str] = None, + api_token: Optional[str] = None, resume: bool = False, - as_experiment: str | None = None, - creation_time: datetime | None = None, - from_run_id: str | None = None, - from_step: int | float | None = None, + mode: Literal["async", "disabled"] = "async", + as_experiment: Optional[str] = None, + creation_time: Optional[datetime] = None, + from_run_id: Optional[str] = None, + from_step: Optional[Union[int, float]] = None, max_queue_size: int = MAX_QUEUE_SIZE, - max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, + max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None, + on_network_error_callback: Optional[Callable[[BaseException], None]] = None, ) -> None: """ Initializes a run that logs the model-building metadata to Neptune. @@ -79,15 +99,15 @@ def __init__( api_token: Your Neptune API token. If not provided, the value of the `NEPTUNE_API_TOKEN` environment variable is used. resume: Whether to resume an existing run. + mode: Mode of operation. If set to "disabled", the run doesn't log any metadata. as_experiment: If creating a run as an experiment, ID of an experiment to be associated with the run. creation_time: Custom creation time of the run. from_run_id: If forking from an existing run, ID of the run to fork from. from_step: If forking from an existing run, step number to fork from. 
max_queue_size: Maximum number of operations in a queue. - max_queue_size_exceeded_callback: Callback function triggered when a queue is full. - Accepts two arguments: - - Maximum size of the queue. - - Exception that made the queue full. + max_queue_size_exceeded_callback: Callback function triggered when the queue is full. The function should take the exception + that made the queue full as its argument. + on_network_error_callback: Callback function triggered when a network error occurs. """ verify_type("family", family, str) verify_type("run_id", run_id, str) @@ -143,13 +163,33 @@ def __init__( self._lock = threading.RLock() self._operations_queue: OperationsQueue = OperationsQueue( - lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback + lock=self._lock, + max_size=max_queue_size, ) self._errors_queue: ErrorsQueue = ErrorsQueue() - self._errors_monitor = ErrorsMonitor(errors_queue=self._errors_queue) - self._backend: ApiClient = ApiClient(api_token=input_api_token) + self._errors_monitor = ErrorsMonitor( + errors_queue=self._errors_queue, + max_queue_size_exceeded_callback=max_queue_size_exceeded_callback, + on_network_error_callback=on_network_error_callback, + ) + self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1) + self._last_put_seq_wait: ConditionT = multiprocessing.Condition() + self._sync_process = SyncProcess( + family=self._family, + operations_queue=self._operations_queue.queue, + errors_queue=self._errors_queue, + api_token=input_api_token, + last_put_seq=self._last_put_seq, + last_put_seq_wait=self._last_put_seq_wait, + max_queue_size=max_queue_size, + mode=mode, + ) self._errors_monitor.start() + with self._lock: + self._sync_process.start() + + self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close) if not resume: self._create_run( @@ -159,32 +199,44 @@ def __init__( from_step=from_step, ) - def __enter__(self) -> Run: - return self - @property def 
resources(self) -> tuple[Resource, ...]: return ( + self._errors_queue, self._operations_queue, - self._backend, self._errors_monitor, - self._errors_queue, ) + def _close(self) -> None: + # TODO: Change to wait for all operations to be processed + with self._lock: + if self._sync_process.is_alive(): + self.wait_for_submission() + self._sync_process.terminate() + self._sync_process.join() + + self._errors_monitor.interrupt() + self._errors_monitor.join() + + super().close() + def close(self) -> None: """ Stops the connection to Neptune and synchronizes all data. """ - super().close() + if self._exit_func is not None: + atexit.unregister(self._exit_func) + self._exit_func = None + self._close() def _create_run( self, creation_time: datetime, - as_experiment: str | None, - from_run_id: str | None, - from_step: int | float | None, + as_experiment: Optional[str], + from_run_id: Optional[str], + from_step: Optional[Union[int, float]], ) -> None: - fork_point: ForkPoint | None = None + fork_point: Optional[ForkPoint] = None if from_run_id is not None and from_step is not None: fork_point = ForkPoint( parent_project=self._project, parent_run_id=from_run_id, step=make_step(number=from_step) @@ -200,18 +252,16 @@ def _create_run( creation_time=None if creation_time is None else datetime_to_proto(creation_time), ), ) - self._backend.submit(operation=operation, family=self._family) - # TODO: Enqueue on the operations queue - # self._operations_queue.enqueue(operation=operation) + self._operations_queue.enqueue(operation=operation) def log( self, - step: float | int | None = None, - timestamp: datetime | None = None, - fields: dict[str, float | bool | int | str | datetime | list | set] | None = None, - metrics: dict[str, float] | None = None, - add_tags: dict[str, list[str] | set[str]] | None = None, - remove_tags: dict[str, list[str] | set[str]] | None = None, + step: Optional[Union[float, int]] = None, + timestamp: Optional[datetime] = None, + fields: Optional[Dict[str, 
Union[float, bool, int, str, datetime, list, set]]] = None, + metrics: Optional[Dict[str, float]] = None, + add_tags: Optional[Dict[str, Union[List[str], Set[str]]]] = None, + remove_tags: Optional[Dict[str, Union[List[str], Set[str]]]] = None, ) -> None: """ Logs the specified metadata to Neptune. @@ -268,6 +318,51 @@ def log( ) for operation in splitter: - self._backend.submit(operation=operation, family=self._family) - # TODO: Enqueue on the operations queue - # self._operations_queue.enqueue(operation=operation) + self._operations_queue.enqueue(operation=operation) + + def wait_for_submission(self, timeout: Optional[float] = None) -> None: + """ + Waits until all metadata is submitted to Neptune. + """ + begin_time = time.time() + logger.info("Waiting for all operations to be processed") + if timeout is None: + logger.warning("No timeout specified. Waiting indefinitely") + + with self._lock: + if not self._sync_process.is_alive(): + logger.warning("Sync process is not running") + return # No need to wait if the sync process is not running + + sleep_time_wait = ( + min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME + ) + last_queued_sequence_id = self._operations_queue.last_sequence_id + last_message_printed: Optional[float] = None + while True: + with self._last_put_seq_wait: + self._last_put_seq_wait.wait(timeout=sleep_time_wait) + value = self._last_put_seq.value + if value == -1: + if self._operations_queue.last_sequence_id != -1: + if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY: + last_message_printed = time.time() + logger.info( + "Waiting. No operations processed yet. Operations to sync: %s", + self._operations_queue.last_sequence_id + 1, + ) + else: + if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY: + last_message_printed = time.time() + logger.info("Waiting. 
No operations processed yet") + else: + if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY: + last_message_printed = time.time() + logger.info( + "Waiting for remaining %d operation(s) to be synced", + last_queued_sequence_id - value + 1, + ) + if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout): + break + + logger.info("All operations processed") diff --git a/src/neptune_scale/api/api_client.py b/src/neptune_scale/api/api_client.py index 80a15d31..0d8cf301 100644 --- a/src/neptune_scale/api/api_client.py +++ b/src/neptune_scale/api/api_client.py @@ -15,11 +15,16 @@ # from __future__ import annotations -__all__ = ["ApiClient"] - +__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient") +import abc +import os +import uuid from dataclasses import dataclass +from http import HTTPStatus +from typing import Any +from httpx import Timeout from neptune_api import ( AuthenticatedClient, Client, @@ -32,25 +37,14 @@ ClientConfig, Error, ) +from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation +from neptune_api.types import Response from neptune_scale.core.components.abstract import Resource - - -class ApiClient(Resource): - def __init__(self, api_token: str) -> None: - credentials = Credentials.from_api_key(api_key=api_token) - config, token_urls = get_config_and_token_urls(credentials=credentials) - self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls) - - def submit(self, operation: RunOperation, family: str) -> None: - _ = submit_operation.sync(client=self._backend, family=family, body=operation) - - def cleanup(self) -> None: - pass - - def close(self) -> None: - self._backend.__exit__() +from neptune_scale.core.logger import logger +from neptune_scale.envs import ALLOW_SELF_SIGNED_CERTIFICATE +from 
neptune_scale.parameters import REQUEST_TIMEOUT @dataclass @@ -65,8 +59,15 @@ def from_dict(cls, data: dict) -> TokenRefreshingURLs: ) -def get_config_and_token_urls(*, credentials: Credentials) -> tuple[ClientConfig, TokenRefreshingURLs]: - with Client(base_url=credentials.base_url) as client: +def get_config_and_token_urls( + *, credentials: Credentials, verify_ssl: bool +) -> tuple[ClientConfig, TokenRefreshingURLs]: + with Client( + base_url=credentials.base_url, + follow_redirects=True, + verify_ssl=verify_ssl, + timeout=Timeout(timeout=REQUEST_TIMEOUT), + ) as client: config = get_client_config.sync(client=client) if config is None or isinstance(config, Error): raise RuntimeError(f"Failed to get client config: {config}") @@ -76,7 +77,7 @@ def get_config_and_token_urls(*, credentials: Credentials) -> tuple[ClientConfig def create_auth_api_client( - *, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs + *, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs, verify_ssl: bool ) -> AuthenticatedClient: return AuthenticatedClient( base_url=credentials.base_url, @@ -84,4 +85,41 @@ def create_auth_api_client( client_id=config.security.client_id, token_refreshing_endpoint=token_refreshing_urls.token_endpoint, api_key_exchange_callback=exchange_api_key, + follow_redirects=True, + verify_ssl=verify_ssl, + timeout=Timeout(timeout=REQUEST_TIMEOUT), ) + + +class ApiClient(Resource, abc.ABC): + @abc.abstractmethod + def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: ... 
+ + +class HostedApiClient(ApiClient): + def __init__(self, api_token: str) -> None: + credentials = Credentials.from_api_key(api_key=api_token) + + verify_ssl: bool = os.environ.get(ALLOW_SELF_SIGNED_CERTIFICATE, "False").lower() in ("false", "0") + + logger.debug("Trying to connect to Neptune API") + config, token_urls = get_config_and_token_urls(credentials=credentials, verify_ssl=verify_ssl) + self._backend = create_auth_api_client( + credentials=credentials, config=config, token_refreshing_urls=token_urls, verify_ssl=verify_ssl + ) + logger.debug("Connected to Neptune API") + + def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: + return submit_operation.sync_detailed(client=self._backend, body=operation, family=family) + + def close(self) -> None: + logger.debug("Closing API client") + self._backend.__exit__() + + +class MockedApiClient(ApiClient): + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: + return Response(content=b"", parsed=RequestId(value=str(uuid.uuid4())), status_code=HTTPStatus.OK, headers={}) diff --git a/src/neptune_scale/core/components/abstract.py b/src/neptune_scale/core/components/abstract.py index 00242fa5..637ade59 100644 --- a/src/neptune_scale/core/components/abstract.py +++ b/src/neptune_scale/core/components/abstract.py @@ -5,6 +5,10 @@ abstractmethod, ) from types import TracebackType +from typing import ( + Optional, + Type, +) class AutoCloseable(ABC): @@ -16,16 +20,16 @@ def close(self) -> None: ... def __exit__( self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], ) -> None: self.close() class Resource(AutoCloseable): - @abstractmethod - def cleanup(self) -> None: ... 
+ def cleanup(self) -> None: + pass def flush(self) -> None: pass diff --git a/src/neptune_scale/core/components/errors_monitor.py b/src/neptune_scale/core/components/errors_monitor.py deleted file mode 100644 index dc9950be..00000000 --- a/src/neptune_scale/core/components/errors_monitor.py +++ /dev/null @@ -1,46 +0,0 @@ -__all__ = ("ErrorsMonitor",) - -import logging -import queue -from typing import Callable - -from neptune_scale.core.components.abstract import Resource -from neptune_scale.core.components.daemon import Daemon -from neptune_scale.core.components.errors_queue import ErrorsQueue - -logger = logging.getLogger("neptune") -logger.setLevel(level=logging.INFO) - - -def on_error(error: BaseException) -> None: - logger.error(error) - - -class ErrorsMonitor(Daemon, Resource): - def __init__( - self, - errors_queue: ErrorsQueue, - on_error_callback: Callable[[BaseException], None] = on_error, - ): - super().__init__(name="ErrorsMonitor", sleep_time=2) - self._errors_queue = errors_queue - self._on_error_callback = on_error_callback - - def work(self) -> None: - try: - error = self._errors_queue.get(block=False) - if error is not None: - self._on_error_callback(error) - except KeyboardInterrupt: - with self._wait_condition: - self._wait_condition.notify_all() - raise - except queue.Empty: - pass - - def cleanup(self) -> None: - pass - - def close(self) -> None: - self.interrupt() - self.join(timeout=10) diff --git a/src/neptune_scale/core/components/errors_queue.py b/src/neptune_scale/core/components/errors_queue.py deleted file mode 100644 index 33bdc38e..00000000 --- a/src/neptune_scale/core/components/errors_queue.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -__all__ = ("ErrorsQueue",) - -from multiprocessing import Queue - -from neptune_scale.core.components.abstract import Resource - - -class ErrorsQueue(Resource): - def __init__(self) -> None: - self._errors_queue: Queue[BaseException] = Queue() - - def put(self, error: 
BaseException) -> None: - self._errors_queue.put(error) - - def get(self, block: bool = True, timeout: float | None = None) -> BaseException: - return self._errors_queue.get(block=block, timeout=timeout) - - def cleanup(self) -> None: - pass - - def close(self) -> None: - self._errors_queue.close() diff --git a/src/neptune_scale/core/components/errors_tracking.py b/src/neptune_scale/core/components/errors_tracking.py new file mode 100644 index 00000000..18f82aa3 --- /dev/null +++ b/src/neptune_scale/core/components/errors_tracking.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +__all__ = ("ErrorsQueue", "ErrorsMonitor") + +import multiprocessing +import queue +from typing import ( + Callable, + Optional, +) + +from neptune_scale.core.components.abstract import Resource +from neptune_scale.core.components.daemon import Daemon +from neptune_scale.core.logger import logger +from neptune_scale.core.process_killer import kill_me +from neptune_scale.exceptions import ( + NeptuneConnectionLostError, + NeptuneOperationsQueueMaxSizeExceeded, + NeptuneScaleError, + NeptuneUnexpectedError, +) +from neptune_scale.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME + + +class ErrorsQueue(Resource): + def __init__(self) -> None: + self._errors_queue: multiprocessing.Queue[BaseException] = multiprocessing.Queue() + + def put(self, error: BaseException) -> None: + self._errors_queue.put(error) + + def get(self, block: bool = True, timeout: Optional[float] = None) -> BaseException: + return self._errors_queue.get(block=block, timeout=timeout) + + def close(self) -> None: + self._errors_queue.close() + # This is needed to avoid hanging the main process + self._errors_queue.cancel_join_thread() + + +def default_error_callback(error: BaseException) -> None: + logger.error(error) + kill_me() + + +def default_network_error_callback(error: BaseException) -> None: + logger.warning("Experiencing network issues. 
Retrying...") + + +def default_max_queue_size_exceeded_callback(error: BaseException) -> None: + logger.warning(error) + + +class ErrorsMonitor(Daemon, Resource): + def __init__( + self, + errors_queue: ErrorsQueue, + max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None, + on_network_error_callback: Optional[Callable[[BaseException], None]] = None, + on_error_callback: Optional[Callable[[BaseException], None]] = None, + ): + super().__init__(name="ErrorsMonitor", sleep_time=ERRORS_MONITOR_THREAD_SLEEP_TIME) + + self._errors_queue: ErrorsQueue = errors_queue + self._max_queue_size_exceeded_callback: Callable[[BaseException], None] = ( + max_queue_size_exceeded_callback or default_max_queue_size_exceeded_callback + ) + self._non_network_error_callback: Callable[[BaseException], None] = ( + on_network_error_callback or default_network_error_callback + ) + self._on_error_callback: Callable[[BaseException], None] = on_error_callback or default_error_callback + + def get_next(self) -> Optional[BaseException]: + try: + return self._errors_queue.get(block=False) + except queue.Empty: + return None + + def work(self) -> None: + while (error := self.get_next()) is not None: + if isinstance(error, NeptuneOperationsQueueMaxSizeExceeded): + self._max_queue_size_exceeded_callback(error) + elif isinstance(error, NeptuneConnectionLostError): + self._non_network_error_callback(error) + elif isinstance(error, NeptuneScaleError): + self._on_error_callback(error) + else: + self._on_error_callback(NeptuneUnexpectedError(reason=str(type(error)))) diff --git a/src/neptune_scale/core/components/operations_queue.py b/src/neptune_scale/core/components/operations_queue.py index bed30265..22ec3845 100644 --- a/src/neptune_scale/core/components/operations_queue.py +++ b/src/neptune_scale/core/components/operations_queue.py @@ -4,15 +4,17 @@ from multiprocessing import Queue from time import monotonic -from typing import ( - TYPE_CHECKING, - Callable, - NamedTuple, 
-) +from typing import TYPE_CHECKING from neptune_scale.core.components.abstract import Resource +from neptune_scale.core.components.queue_element import QueueElement +from neptune_scale.core.logger import logger from neptune_scale.core.validation import verify_type -from neptune_scale.parameters import MAX_QUEUE_ELEMENT_SIZE +from neptune_scale.parameters import ( + MAX_MULTIPROCESSING_QUEUE_SIZE, + MAX_QUEUE_ELEMENT_SIZE, + MAX_QUEUE_SIZE, +) if TYPE_CHECKING: from threading import RLock @@ -20,51 +22,49 @@ from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation -class QueueElement(NamedTuple): - sequence_id: int - occured_at: float - operation: bytes - - -def default_max_size_exceeded_callback(max_size: int, e: BaseException) -> None: - raise ValueError(f"Queue is full (max size: {max_size})") from e - - class OperationsQueue(Resource): def __init__( self, *, lock: RLock, - max_size: int = 0, - max_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, + max_size: int = MAX_QUEUE_SIZE, ) -> None: verify_type("max_size", max_size, int) self._lock: RLock = lock self._max_size: int = max_size - self._max_size_exceeded_callback: Callable[[int, BaseException], None] = ( - max_size_exceeded_callback if max_size_exceeded_callback is not None else default_max_size_exceeded_callback - ) self._sequence_id: int = 0 - self._queue: Queue[QueueElement] = Queue(maxsize=max_size) + self._queue: Queue[QueueElement] = Queue(maxsize=min(MAX_MULTIPROCESSING_QUEUE_SIZE, max_size)) + + @property + def queue(self) -> Queue[QueueElement]: + return self._queue + + @property + def last_sequence_id(self) -> int: + with self._lock: + return self._sequence_id - 1 def enqueue(self, *, operation: RunOperation) -> None: try: - # TODO: This lock could be moved to the Run class - with self._lock: - serialized_operation = operation.SerializeToString() + serialized_operation = operation.SerializeToString() - if len(serialized_operation) > 
MAX_QUEUE_ELEMENT_SIZE: - raise ValueError(f"Operation size exceeds the maximum allowed size ({MAX_QUEUE_ELEMENT_SIZE})") + if len(serialized_operation) > MAX_QUEUE_ELEMENT_SIZE: + raise ValueError(f"Operation size exceeds the maximum allowed size ({MAX_QUEUE_ELEMENT_SIZE})") - self._queue.put_nowait(QueueElement(self._sequence_id, monotonic(), serialized_operation)) + with self._lock: + self._queue.put( + QueueElement(self._sequence_id, monotonic(), serialized_operation), + block=True, + timeout=None, + ) self._sequence_id += 1 except Exception as e: - self._max_size_exceeded_callback(self._max_size, e) - - def cleanup(self) -> None: - pass + logger.error("Failed to enqueue operation: %s %s", e, operation) + raise e def close(self) -> None: self._queue.close() + # This is needed to avoid hanging the main process + self._queue.cancel_join_thread() diff --git a/src/neptune_scale/core/components/queue_element.py b/src/neptune_scale/core/components/queue_element.py new file mode 100644 index 00000000..e736a01e --- /dev/null +++ b/src/neptune_scale/core/components/queue_element.py @@ -0,0 +1,9 @@ +__all__ = ("QueueElement",) + +from typing import NamedTuple + + +class QueueElement(NamedTuple): + sequence_id: int + timestamp: float + operation: bytes diff --git a/src/neptune_scale/core/components/sync_process.py b/src/neptune_scale/core/components/sync_process.py new file mode 100644 index 00000000..34c4802f --- /dev/null +++ b/src/neptune_scale/core/components/sync_process.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +__all__ = ("SyncProcess",) + +import multiprocessing +import queue +from multiprocessing import ( + Process, + Queue, +) +from multiprocessing.sharedctypes import Synchronized +from multiprocessing.synchronize import Condition +from typing import ( + Any, + Callable, + Literal, + Optional, +) + +import backoff +import httpx +from neptune_api.errors import ( + InvalidApiTokenException, + UnableToDeserializeApiKeyError, + 
UnableToExchangeApiKeyError, + UnableToRefreshTokenError, + UnexpectedStatus, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation +from neptune_api.types import Response + +from neptune_scale.api.api_client import ( + ApiClient, + HostedApiClient, + MockedApiClient, +) +from neptune_scale.core.components.abstract import ( + Resource, + WithResources, +) +from neptune_scale.core.components.daemon import Daemon +from neptune_scale.core.components.errors_tracking import ErrorsQueue +from neptune_scale.core.components.queue_element import QueueElement +from neptune_scale.core.logger import logger +from neptune_scale.exceptions import ( + NeptuneConnectionLostError, + NeptuneInvalidCredentialsError, + NeptuneOperationsQueueMaxSizeExceeded, + NeptuneRetryableError, + NeptuneUnableToAuthenticateError, + NeptuneUnauthorizedError, +) +from neptune_scale.parameters import ( + EXTERNAL_TO_INTERNAL_THREAD_SLEEP_TIME, + MAX_QUEUE_SIZE, + OPERATION_TIMEOUT, + SHUTDOWN_TIMEOUT, + SYNC_THREAD_SLEEP_TIME, +) + + +def with_api_errors_handling(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return func(*args, **kwargs) + except (InvalidApiTokenException, UnableToDeserializeApiKeyError): + raise NeptuneInvalidCredentialsError() + except (UnableToRefreshTokenError, UnableToExchangeApiKeyError, UnexpectedStatus): + raise NeptuneUnableToAuthenticateError() + except (httpx.ConnectError, httpx.TimeoutException, httpx.RemoteProtocolError): + raise NeptuneConnectionLostError() + except Exception as e: + raise e + + return wrapper + + +class SyncProcess(Process): + def __init__( + self, + operations_queue: Queue, + errors_queue: ErrorsQueue, + api_token: str, + family: str, + mode: Literal["async", "disabled"], + last_put_seq: Synchronized[int], + last_put_seq_wait: Condition, + max_queue_size: int = MAX_QUEUE_SIZE, + ) -> None: 
+ super().__init__(name="SyncProcess") + + self._external_operations_queue: Queue[QueueElement] = operations_queue + self._errors_queue: ErrorsQueue = errors_queue + self._api_token: str = api_token + self._family: str = family + self._last_put_seq: Synchronized[int] = last_put_seq + self._last_put_seq_wait: Condition = last_put_seq_wait + self._max_queue_size: int = max_queue_size + self._mode: Literal["async", "disabled"] = mode + + def run(self) -> None: + logger.info("Data synchronization started") + worker = SyncProcessWorker( + family=self._family, + api_token=self._api_token, + errors_queue=self._errors_queue, + external_operations_queue=self._external_operations_queue, + last_put_seq=self._last_put_seq, + last_put_seq_wait=self._last_put_seq_wait, + max_queue_size=self._max_queue_size, + mode=self._mode, + ) + worker.start() + try: + worker.join() + except KeyboardInterrupt: + worker.interrupt() + worker.wake_up() + worker.join(timeout=SHUTDOWN_TIMEOUT) + worker.close() + + +class SyncProcessWorker(WithResources): + def __init__( + self, + *, + api_token: str, + family: str, + errors_queue: ErrorsQueue, + external_operations_queue: multiprocessing.Queue[QueueElement], + last_put_seq: Synchronized[int], + mode: Literal["async", "disabled"], + last_put_seq_wait: Condition, + max_queue_size: int = MAX_QUEUE_SIZE, + ) -> None: + self._errors_queue = errors_queue + + self._internal_operations_queue: queue.Queue[QueueElement] = queue.Queue(maxsize=max_queue_size) + self._sync_thread = SyncThread( + api_token=api_token, + operations_queue=self._internal_operations_queue, + errors_queue=self._errors_queue, + family=family, + last_put_seq=last_put_seq, + last_put_seq_wait=last_put_seq_wait, + mode=mode, + ) + self._external_to_internal_thread = ExternalToInternalOperationsThread( + external=external_operations_queue, + internal=self._internal_operations_queue, + errors_queue=self._errors_queue, + ) + + @property + def threads(self) -> tuple[Daemon, ...]: + return 
self._external_to_internal_thread, self._sync_thread + + @property + def resources(self) -> tuple[Resource, ...]: + return self._external_to_internal_thread, self._sync_thread + + def interrupt(self) -> None: + for thread in self.threads: + thread.interrupt() + + def wake_up(self) -> None: + for thread in self.threads: + thread.wake_up() + + def start(self) -> None: + for thread in self.threads: + thread.start() + + def join(self, timeout: Optional[int] = None) -> None: + for thread in self.threads: + thread.join(timeout=timeout) + + +class ExternalToInternalOperationsThread(Daemon, Resource): + def __init__( + self, + external: multiprocessing.Queue[QueueElement], + internal: queue.Queue[QueueElement], + errors_queue: ErrorsQueue, + ) -> None: + super().__init__(name="ExternalToInternalOperationsThread", sleep_time=EXTERNAL_TO_INTERNAL_THREAD_SLEEP_TIME) + + self._external: multiprocessing.Queue[QueueElement] = external + self._internal: queue.Queue[QueueElement] = internal + self._errors_queue: ErrorsQueue = errors_queue + self._latest_unprocessed: Optional[QueueElement] = None + + def get_next(self) -> Optional[QueueElement]: + if self._latest_unprocessed is not None: + return self._latest_unprocessed + + try: + return self._external.get_nowait() + except queue.Empty: + return None + + def work(self) -> None: + while (operation := self.get_next()) is not None: + logger.debug("Copying operation #%d: %s", operation.sequence_id, operation) + + self._latest_unprocessed = operation + try: + self._internal.put_nowait(operation) + self._latest_unprocessed = None + except queue.Full: + self._errors_queue.put(NeptuneOperationsQueueMaxSizeExceeded(max_size=self._internal.maxsize)) + except Exception as e: + self._errors_queue.put(e) + + +def raise_for_status(response: Response[RequestId]) -> None: + if response.status_code == 403: + raise NeptuneUnauthorizedError() + if response.status_code != 200: + raise RuntimeError(f"Unexpected status code: {response.status_code}") + 
def _ensure_backend_initialized(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient:
    """Create the API client for ``mode``.

    Returns a ``MockedApiClient`` when ``mode`` is ``"disabled"``; otherwise a
    ``HostedApiClient`` built from ``api_token``.
    """
    if mode == "disabled":
        return MockedApiClient()
    return HostedApiClient(api_token=api_token)


class SyncThread(Daemon, WithResources):
    """Daemon that submits queued operations to the backend.

    Each queue element is unpacked into ``(sequence_id, timestamp, data)``;
    ``data`` is parsed into a ``RunOperation`` and submitted. After a successful
    submission the shared ``last_put_seq`` value is set to the element's
    sequence id and every waiter on ``last_put_seq_wait`` is notified.
    """

    def __init__(
        self,
        api_token: str,
        operations_queue: queue.Queue[QueueElement],
        errors_queue: ErrorsQueue,
        family: str,
        last_put_seq: Synchronized[int],
        last_put_seq_wait: Condition,
        mode: Literal["async", "disabled"],
    ) -> None:
        """Store the collaborators; the backend client is created lazily on first submit."""
        super().__init__(name="SyncThread", sleep_time=SYNC_THREAD_SLEEP_TIME)

        self._api_token: str = api_token
        self._operations_queue: queue.Queue[QueueElement] = operations_queue
        self._errors_queue: ErrorsQueue = errors_queue
        self._backend: Optional[ApiClient] = None
        self._family: str = family
        self._last_put_seq: Synchronized[int] = last_put_seq
        self._last_put_seq_wait: Condition = last_put_seq_wait
        self._mode: Literal["async", "disabled"] = mode

        # Element taken from the queue whose submission has not succeeded yet.
        self._latest_unprocessed: Optional[QueueElement] = None

    def get_next(self) -> Optional[QueueElement]:
        """Return the pending element if there is one, else the next queued element, else ``None``."""
        if self._latest_unprocessed is not None:
            return self._latest_unprocessed

        try:
            return self._operations_queue.get_nowait()
        except queue.Empty:
            return None

    @property
    def resources(self) -> tuple[Resource, ...]:
        """The backend client once it exists, otherwise an empty tuple."""
        if self._backend is not None:
            return (self._backend,)
        return ()

    @backoff.on_exception(backoff.expo, NeptuneConnectionLostError, max_time=OPERATION_TIMEOUT)
    @with_api_errors_handling
    def submit(self, *, operation: RunOperation) -> None:
        """Send one operation to the backend and validate the response status.

        The call is retried with exponential backoff on
        ``NeptuneConnectionLostError`` for at most ``OPERATION_TIMEOUT`` seconds.
        """
        if self._backend is None:
            self._backend = _ensure_backend_initialized(api_token=self._api_token, mode=self._mode)
        response = self._backend.submit(operation=operation, family=self._family)
        # The argument needs a placeholder; without one logging reports a formatting error.
        logger.debug("Server response: %s", response)
        raise_for_status(response)

    def work(self) -> None:
        """Submit queued operations until the queue is empty or a fatal error occurs.

        A ``NeptuneRetryableError`` is reported and the same element is tried
        again. Any other error is reported, the thread is interrupted, waiters
        are woken up and the pass ends with the element still pending.
        """
        while (operation := self.get_next()) is not None:
            self._latest_unprocessed = operation
            sequence_id, _timestamp, data = operation

            try:
                run_operation = RunOperation()
                run_operation.ParseFromString(data)
                self.submit(operation=run_operation)
            except NeptuneRetryableError as e:
                self._errors_queue.put(e)
                continue
            except Exception as e:
                self._errors_queue.put(e)
                self.interrupt()
                # A condition may only be notified while it is held.
                with self._last_put_seq_wait:
                    self._last_put_seq_wait.notify_all()
                break

            self._latest_unprocessed = None

            # Update Last PUT sequence id and notify threads in the main process
            with self._last_put_seq_wait:
                self._last_put_seq.value = sequence_id
                self._last_put_seq_wait.notify_all()


# --- src/neptune_scale/core/logger.py ---

NEPTUNE_LOGGER_NAME = "neptune"
NEPTUNE_DEBUG_FILE_NAME = "neptune.log"
LOG_FORMAT = "{blue}%(name)s{end} :: {bold}%(levelname)s{end} :: %(message)s"
DEBUG_FORMAT = "%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s"


def get_logger() -> logging.Logger:
    """Configure and return the ``neptune`` logger.

    A stream handler at INFO level with the styled ``LOG_FORMAT`` is always
    attached. When the environment variable named by ``DEBUG_MODE`` is ``true``
    or ``1`` (case-insensitive), a DEBUG-level file handler writing to
    ``neptune.log`` is attached as well.
    """
    ensure_style_detected()

    neptune_logger = logging.getLogger(NEPTUNE_LOGGER_NAME)
    neptune_logger.setLevel(logging.INFO)

    if os.environ.get(DEBUG_MODE, "False").lower() in ("true", "1"):
        # The logger filters records before any handler sees them, so it has to
        # admit DEBUG itself; the stream handler below still limits the console to INFO.
        neptune_logger.setLevel(logging.DEBUG)

        file_handler = logging.FileHandler(NEPTUNE_DEBUG_FILE_NAME)
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(logging.Formatter(DEBUG_FORMAT))
        neptune_logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(LOG_FORMAT.format(**STYLES)))
    neptune_logger.addHandler(stream_handler)

    return neptune_logger


logger = get_logger()
# Seconds to wait for terminated children before force-killing them.
KILL_TIMEOUT = int(os.getenv(SUBPROCESS_KILL_TIMEOUT, "5"))


def kill_me() -> None:
    """Shut down the current process together with all of its descendants.

    Descendants are asked to terminate first; those still alive after
    ``KILL_TIMEOUT`` seconds are killed. The current process is terminated last.
    """
    me = psutil.Process(os.getpid())
    try:
        descendants = _get_process_children(me)
    except psutil.NoSuchProcess:
        descendants = []

    for proc in descendants:
        _terminate(proc)

    survivors = psutil.wait_procs(descendants, timeout=KILL_TIMEOUT)[1]
    for proc in survivors:
        _kill(proc)

    # finish with terminating self
    _terminate(me)


def _terminate(process: psutil.Process) -> None:
    """Request termination of ``process``; a process that is already gone is ignored."""
    try:
        process.terminate()
    except psutil.NoSuchProcess:
        return


def _kill(process: psutil.Process) -> None:
    """Kill ``process`` if it is still running; a process that is already gone is ignored."""
    try:
        still_running = process.is_running()
        if still_running:
            process.kill()
    except psutil.NoSuchProcess:
        return


def _get_process_children(process: psutil.Process) -> List[psutil.Process]:
    """Return all descendants of ``process`` (recursive), or an empty list if it no longer exists."""
    try:
        descendants = process.children(recursive=True)
    except psutil.NoSuchProcess:
        descendants = []
    return descendants
# ANSI escape sequences used on Linux and macOS terminals.
UNIX_STYLES = {
    "h1": "\033[95m",
    "h2": "\033[94m",
    "blue": "\033[94m",
    "python": "\033[96m",
    "bash": "\033[95m",
    "warning": "\033[93m",
    "correct": "\033[92m",
    "fail": "\033[91m",
    "bold": "\033[1m",
    "underline": "\033[4m",
    "end": "\033[0m",
}

# Same keys as UNIX_STYLES, every style mapped to an empty string.
EMPTY_STYLES = dict.fromkeys(UNIX_STYLES, "")

# Windows gets no styling either; kept as its own dict object.
WINDOWS_STYLES = dict(EMPTY_STYLES)


# Filled in place by ensure_style_detected() on first use.
STYLES: Dict[str, str] = {}


def ensure_style_detected() -> None:
    """Populate ``STYLES`` once, based on the environment and the platform.

    Does nothing when ``STYLES`` already has content. Otherwise the empty table
    is used when the environment variable named by ``DISABLE_COLORS`` is
    ``true`` or ``1`` (case-insensitive); else the table is chosen by
    ``platform.system()``, falling back to the empty table for unknown systems.
    """
    if STYLES:
        return

    colors_disabled = os.environ.get(DISABLE_COLORS, "False").lower() in ("true", "1")
    if colors_disabled:
        chosen = EMPTY_STYLES
    elif platform.system() in ["Linux", "Darwin"]:
        chosen = UNIX_STYLES
    elif platform.system() == "Windows":
        chosen = WINDOWS_STYLES
    else:
        chosen = EMPTY_STYLES

    STYLES.update(chosen)
def _restore_error(cls: Any, args: Any, state: Any) -> BaseException:
    """Rebuild an already-formatted error without running its ``__init__`` again."""
    error = cls.__new__(cls)
    error.args = args
    error.__dict__.update(state)
    return error


class NeptuneScaleError(Exception):
    """Base class of the client's errors.

    The exception text is the class-level ``message`` template formatted with
    the positional arguments, the style table and the keyword arguments.
    """

    message = "An error occurred in the Neptune Scale client."

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        ensure_style_detected()
        super().__init__(self.message.format(*args, **STYLES, **kwargs))

    def __reduce__(self) -> Any:
        # The default exception reduce re-creates the object as cls(*self.args),
        # i.e. it feeds the formatted text back into __init__. Templates that
        # need keyword fields (such as {max_size}) then fail with KeyError, so
        # the error cannot be pickled/copied. Restore the stored text instead.
        return _restore_error, (type(self), self.args, dict(self.__dict__))


class NeptuneOperationsQueueMaxSizeExceeded(NeptuneScaleError):
    """The internal operations queue is full; expects the ``max_size`` keyword."""

    message = """
{h1}
----NeptuneOperationsQueueMaxSizeExceeded--------------------------------------
{end}
The queue size for internal operations was exceeded (max allowed: {max_size}) because too much data was queued in a short time.

The synchronization is paused until the queue size drops below the maximum.

To resolve this issue, consider the following:
    - Reduce the frequency of data being sent to the queue, or throttle the rate of operations.
    - Cautiously increase the queue size through the `max_queue_size` argument.
        Note: To ensure that memory usage remains within acceptable limits,
        closely monitor your system's memory consumption.

{correct}Need help?{end}-> https://docs.neptune.ai/getting_help

Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
"""


class NeptuneUnauthorizedError(NeptuneScaleError):
    """The user has no permission to access the requested resource."""

    message = """
{h1}
----NeptuneUnauthorizedError---------------------------------------------------
{end}
You don't have permission to access the given resource.

    - Verify that your API token is correct. To find your API token:
        - Log in to Neptune Scale and open the user menu.
        - If your workspace uses service accounts, ask the project owner to provide the token.

    - Verify that the provided project name is correct.
      The correct project name should look like this: {correct}WORKSPACE_NAME/PROJECT_NAME{end}
      It has two parts:
        - {correct}WORKSPACE_NAME{end}: can be your username or your organization name
        - {correct}PROJECT_NAME{end}: the name specified for the project

    - Ask your workspace admin to grant you the necessary privileges to the project.

{correct}Need help?{end}-> https://docs.neptune.ai/getting_help

Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
"""


class NeptuneInvalidCredentialsError(NeptuneScaleError):
    """The provided API token is invalid."""

    message = """
{h1}
----NeptuneInvalidCredentialsError---------------------------------------------
{end}
The provided API token is invalid.
Make sure you copied your API token while logged in to Neptune Scale.
If your workspace uses service accounts, ask the project owner for the token.

There are two options to provide the API token:
    - Set it as an environment variable in your operating system
    - Paste it into your Python code (not recommended)

{h2}Environment variable{end} {correct}(Recommended){end}
Set the NEPTUNE_API_TOKEN environment variable depending on your operating system:

    {correct}Linux/Unix{end}
    In the terminal:
        {bash}export NEPTUNE_API_TOKEN="YOUR_API_TOKEN"{end}

    {correct}Windows{end}
    In Command Prompt or similar:
        {bash}setx NEPTUNE_API_TOKEN "YOUR_API_TOKEN"{end}

and omit the {bold}api_token{end} argument from the {bold}Run{end} constructor:
    {python}neptune_scale.Run(project="WORKSPACE_NAME/PROJECT_NAME"){end}

{h2}Option 2: Run argument{end}
Pass the token to the {bold}Run{end} constructor via the {bold}api_token{end} argument:
    {python}neptune_scale.Run(project="WORKSPACE_NAME/PROJECT_NAME", api_token="YOUR_API_TOKEN"){end}

{correct}Need help?{end}-> https://docs.neptune.ai/getting_help

Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
"""


class NeptuneUnexpectedError(NeptuneScaleError):
    """Wrapper for an error the client did not anticipate; ``reason`` is embedded in the text."""

    message = """
{h1}
----NeptuneUnexpectedError-----------------------------------------------------
{end}
An unexpected error occurred in the Neptune Scale client. For help, contact support@neptune.ai. Raw exception name: "{reason}"

Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
"""

    def __init__(self, reason: str) -> None:
        super().__init__(reason=reason)


class NeptuneRetryableError(NeptuneScaleError):
    """Base class for errors that callers treat as retryable."""


class NeptuneConnectionLostError(NeptuneRetryableError):
    """The connection to the Neptune server was lost."""

    message = """
{h1}
----NeptuneConnectionLostError-------------------------------------------------
{end}
The connection to the Neptune server was lost. Ensure that your computer is connected to the internet and that
    firewall settings aren't blocking the connection.

{correct}Need help?{end}-> https://docs.neptune.ai/getting_help

Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
"""


class NeptuneUnableToAuthenticateError(NeptuneScaleError):
    """The client could not authenticate with the Neptune server."""

    message = """
{h1}
----NeptuneUnableToAuthenticateError-------------------------------------------
{end}
The client was unable to authenticate with the Neptune server. Ensure that your API token is correct.

{correct}Need help?{end}-> https://docs.neptune.ai/getting_help

Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
"""
MagicMock() - queue = OperationsQueue(lock=lock, max_size=1, max_size_exceeded_callback=callback) - - # and - operation = RunOperation() - - # when - queue.enqueue(operation=operation) - queue.enqueue(operation=operation) - - # then - callback.assert_called_once() - - def test__max_element_size_exceeded(): # given lock = threading.RLock() diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index a3cb8dc9..78e26089 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -2,7 +2,6 @@ import json import uuid from datetime import datetime -from unittest.mock import patch import pytest from freezegun import freeze_time @@ -15,21 +14,6 @@ def api_token(): return base64.b64encode(json.dumps({"api_address": "aa", "api_url": "bb"}).encode("utf-8")).decode("utf-8") -class MockedApiClient: - def __init__(self, *args, **kwargs) -> None: - pass - - def submit(self, operation, family) -> None: - pass - - def close(self) -> None: - pass - - def cleanup(self) -> None: - pass - - -@patch("neptune_scale.ApiClient", MockedApiClient) def test_context_manager(api_token): # given project = "workspace/project" @@ -37,14 +21,13 @@ def test_context_manager(api_token): family = run_id # when - with Run(project=project, api_token=api_token, family=family, run_id=run_id): + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled"): ... 
# then assert True -@patch("neptune_scale.ApiClient", MockedApiClient) def test_close(api_token): # given project = "workspace/project" @@ -52,7 +35,7 @@ def test_close(api_token): family = run_id # and - run = Run(project=project, api_token=api_token, family=family, run_id=run_id) + run = Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled") # when run.close() @@ -61,7 +44,6 @@ def test_close(api_token): assert True -@patch("neptune_scale.ApiClient", MockedApiClient) def test_family_too_long(api_token): # given project = "workspace/project" @@ -72,11 +54,13 @@ def test_family_too_long(api_token): # when with pytest.raises(ValueError): - with Run(project=project, api_token=api_token, family=family, run_id=run_id): + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled"): ... + # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_run_id_too_long(api_token): # given project = "workspace/project" @@ -87,11 +71,13 @@ def test_run_id_too_long(api_token): # then with pytest.raises(ValueError): - with Run(project=project, api_token=api_token, family=family, run_id=run_id): + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled"): ... + # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_invalid_project_name(api_token): # given run_id = str(uuid.uuid4()) @@ -102,11 +88,13 @@ def test_invalid_project_name(api_token): # then with pytest.raises(ValueError): - with Run(project=project, api_token=api_token, family=family, run_id=run_id): + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled"): ... 
+ # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_metadata(api_token): # given project = "workspace/project" @@ -114,7 +102,7 @@ def test_metadata(api_token): family = run_id # then - with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled") as run: run.log( step=1, timestamp=datetime.now(), @@ -136,8 +124,10 @@ def test_metadata(api_token): }, ) + # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_log_without_step(api_token): # given project = "workspace/project" @@ -145,7 +135,7 @@ def test_log_without_step(api_token): family = run_id # then - with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled") as run: run.log( timestamp=datetime.now(), fields={ @@ -153,8 +143,10 @@ def test_log_without_step(api_token): }, ) + # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_log_step_float(api_token): # given project = "workspace/project" @@ -162,7 +154,7 @@ def test_log_step_float(api_token): family = run_id # then - with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled") as run: run.log( step=3.14, timestamp=datetime.now(), @@ -171,8 +163,10 @@ def test_log_step_float(api_token): }, ) + # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_log_no_timestamp(api_token): # given project = "workspace/project" @@ -180,7 +174,7 @@ def test_log_no_timestamp(api_token): family = run_id # then - with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + with Run(project=project, api_token=api_token, family=family, run_id=run_id, mode="disabled") as run: run.log( 
step=3.14, fields={ @@ -188,8 +182,10 @@ def test_log_no_timestamp(api_token): }, ) + # and + assert True + -@patch("neptune_scale.ApiClient", MockedApiClient) def test_resume(api_token): # given project = "workspace/project" @@ -197,14 +193,13 @@ def test_resume(api_token): family = run_id # when - with Run(project=project, api_token=api_token, family=family, run_id=run_id, resume=True): + with Run(project=project, api_token=api_token, family=family, run_id=run_id, resume=True, mode="disabled"): ... # then assert True -@patch("neptune_scale.ApiClient", MockedApiClient) @freeze_time("2024-07-30 12:12:12.000022") def test_creation_time(api_token): # given @@ -213,14 +208,20 @@ def test_creation_time(api_token): family = run_id # when - with Run(project=project, api_token=api_token, family=family, run_id=run_id, creation_time=datetime.now()): + with Run( + project=project, + api_token=api_token, + family=family, + run_id=run_id, + creation_time=datetime.now(), + mode="disabled", + ): ... # then assert True -@patch("neptune_scale.ApiClient", MockedApiClient) def test_assign_experiment(api_token): # given project = "workspace/project" @@ -228,14 +229,20 @@ def test_assign_experiment(api_token): family = run_id # when - with Run(project=project, api_token=api_token, family=family, run_id=run_id, as_experiment="experiment_id"): + with Run( + project=project, + api_token=api_token, + family=family, + run_id=run_id, + as_experiment="experiment_id", + mode="disabled", + ): ... # then assert True -@patch("neptune_scale.ApiClient", MockedApiClient) def test_forking(api_token): # given project = "workspace/project" @@ -244,7 +251,13 @@ def test_forking(api_token): # when with Run( - project=project, api_token=api_token, family=family, run_id=run_id, from_run_id="parent-run-id", from_step=3.14 + project=project, + api_token=api_token, + family=family, + run_id=run_id, + from_run_id="parent-run-id", + from_step=3.14, + mode="disabled", ): ...