diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b421be2a..73ef26dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: args: [ --config-file, pyproject.toml ] pass_filenames: false additional_dependencies: - - neptune-api==0.7.0b + - neptune-api - more-itertools - backoff default_language_version: diff --git a/src/neptune_scale/net/api_client.py b/src/neptune_scale/net/api_client.py index a6cf90c2..1bbcccfd 100644 --- a/src/neptune_scale/net/api_client.py +++ b/src/neptune_scale/net/api_client.py @@ -27,9 +27,11 @@ from typing import ( Any, Literal, + cast, ) import httpx +import neptune_retrieval_api.client from httpx import Timeout from neptune_api import ( AuthenticatedClient, @@ -64,6 +66,11 @@ from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus from neptune_api.types import Response +from neptune_retrieval_api.api.default import search_leaderboard_entries_proto +from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO +from neptune_retrieval_api.proto.neptune_pb.api.v1.model.leaderboard_entries_pb2 import ( + ProtoLeaderboardEntriesSearchResultDTO, +) from neptune_scale.exceptions import ( NeptuneConnectionLostError, @@ -129,6 +136,11 @@ def submit(self, operation: RunOperation, family: str) -> Response[SubmitRespons @abc.abstractmethod def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ... + @abc.abstractmethod + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: ... + class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: @@ -141,6 +153,9 @@ def __init__(self, api_token: str) -> None: self.backend = create_auth_api_client( credentials=credentials, config=config, token_refreshing_urls=token_urls, verify_ssl=verify_ssl ) + # This is required only to silence mypy. The two client objects are compatible, because they're + # generated by swagger codegen. + self.retrieval_backend = cast(neptune_retrieval_api.client.AuthenticatedClient, self.backend) logger.debug("Connected to Neptune API") def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]: @@ -153,6 +168,15 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]), ) + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: + resp = search_leaderboard_entries_proto.sync_detailed( + client=self.retrieval_backend, project_identifier=project_id, type=["run"], body=body + ) + result = ProtoLeaderboardEntriesSearchResultDTO.FromString(resp.content) + return result + def close(self) -> None: logger.debug("Closing API client") self.backend.__exit__() @@ -181,6 +205,11 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ ) return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={}) + def search_entries( + self, project_id: str, body: SearchLeaderboardEntriesParamsDTO + ) -> ProtoLeaderboardEntriesSearchResultDTO: + return ProtoLeaderboardEntriesSearchResultDTO() + def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient: if mode == "disabled": diff --git a/src/neptune_scale/net/projects.py b/src/neptune_scale/net/projects.py index ce698ada..4a97de2d 100644 --- a/src/neptune_scale/net/projects.py +++ b/src/neptune_scale/net/projects.py @@ -1,4 +1,3 @@ -import os import re from enum import Enum from json import JSONDecodeError @@ -11,7 +10,6 @@ import httpx from neptune_scale.exceptions import ( - NeptuneApiTokenNotProvided, NeptuneBadRequestError, NeptuneProjectAlreadyExists, ) @@ -19,7 +17,7 @@ HostedApiClient, with_api_errors_handling, ) -from neptune_scale.util.envs import API_TOKEN_ENV_NAME +from neptune_scale.sync.util import ensure_api_token PROJECTS_PATH_BASE = "/api/backend/v1/projects" @@ -33,14 +31,6 @@ class ProjectVisibility(Enum): ORGANIZATION_NOT_FOUND_RE = re.compile(r"Organization .* not found") -def _get_api_token(api_token: Optional[str]) -> str: - api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) - if api_token is None: - raise NeptuneApiTokenNotProvided() - - return api_token - - @with_api_errors_handling def create_project( workspace: str, @@ -52,9 +42,7 @@ def create_project( fail_if_exists: bool = False, api_token: Optional[str] = None, ) -> None: - api_token = _get_api_token(api_token) - - client = HostedApiClient(api_token=api_token) + client = HostedApiClient(api_token=ensure_api_token(api_token)) visibility = ProjectVisibility(visibility) body = { @@ -92,7 +80,7 @@ def _safe_json(response: httpx.Response) -> Any: def get_project_list(*, api_token: Optional[str] = None) -> list[dict]: - client = HostedApiClient(api_token=_get_api_token(api_token)) + client = HostedApiClient(api_token=ensure_api_token(api_token)) params = { "userRelation": "viewerOrHigher", diff --git a/src/neptune_scale/net/runs.py b/src/neptune_scale/net/runs.py new file mode 100644 index 00000000..bf14fe86 --- /dev/null +++ b/src/neptune_scale/net/runs.py @@ -0,0 +1,29 @@ +from typing import Optional + +from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO + +from neptune_scale.exceptions import NeptuneScaleError +from neptune_scale.net.api_client import HostedApiClient +from neptune_scale.net.util import escape_nql_criterion +from neptune_scale.sync.util import ensure_api_token + + +def run_exists(project: str, run_id: str, api_token: Optional[str] = None) -> bool: + """Query the backend for the existence of a Run with the given ID. + + Returns True if the Run exists, False otherwise. + """ + + client = HostedApiClient(api_token=ensure_api_token(api_token)) + body = SearchLeaderboardEntriesParamsDTO.from_dict( + { + "query": {"query": f'`sys/custom_run_id`:string = "{escape_nql_criterion(run_id)}"'}, + } + ) + + try: + result = client.search_entries(project, body) + except Exception as e: + raise NeptuneScaleError(reason=e) + + return bool(result.entries) diff --git a/src/neptune_scale/net/util.py b/src/neptune_scale/net/util.py new file mode 100644 index 00000000..cbc25d2e --- /dev/null +++ b/src/neptune_scale/net/util.py @@ -0,0 +1,6 @@ +def escape_nql_criterion(criterion: str) -> str: + """ + Escape backslash and (double-)quotes in the string, to match what the NQL engine expects. + """ + + return criterion.replace("\\", r"\\").replace('"', r"\"") diff --git a/src/neptune_scale/sync/util.py b/src/neptune_scale/sync/util.py index 60fe4b0b..2d5ecf96 100644 --- a/src/neptune_scale/sync/util.py +++ b/src/neptune_scale/sync/util.py @@ -1,4 +1,9 @@ +import os import signal +from typing import Optional + +from neptune_scale.exceptions import NeptuneApiTokenNotProvided +from neptune_scale.util.envs import API_TOKEN_ENV_NAME def safe_signal_name(signum: int) -> str: @@ -8,3 +13,13 @@ def safe_signal_name(signum: int) -> str: signame = str(signum) return signame + + +def ensure_api_token(api_token: Optional[str]) -> str: + """Ensure the API token is provided via either explicit argument, or env variable.""" + + api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) + if api_token is None: + raise NeptuneApiTokenNotProvided() + + return api_token diff --git a/tests/e2e/test_net.py b/tests/e2e/test_net.py new file mode 100644 index 00000000..263def20 --- /dev/null +++ b/tests/e2e/test_net.py @@ -0,0 +1,10 @@ +import os + +from neptune_scale.net.runs import run_exists + +NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT") + + +def test_run_exists_true(run): + assert run_exists(run._project, run._run_id) + assert not run_exists(run._project, "nonexistent_run_id") diff --git a/tests/unit/test_process_link.py b/tests/unit/test_process_link.py index f3f68ec3..c030a268 100644 --- a/tests/unit/test_process_link.py +++ b/tests/unit/test_process_link.py @@ -169,7 +169,7 @@ def on_closed(_): link.start(on_link_closed=on_closed) # We should never finish the sleep call, as on_closed raises SystemExit - time.sleep(5) + time.sleep(10) assert False, "on_closed callback was not called" @@ -184,5 +184,5 @@ def test_parent_termination(): p = multiprocessing.Process(target=parent, args=(var, event)) p.start() - assert event.wait(1) + assert event.wait(5) assert var.value == 1