Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add .net.runs.run_exists() #119

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
args: [ --config-file, pyproject.toml ]
pass_filenames: false
additional_dependencies:
- neptune-api==0.7.0b
- neptune-api
- more-itertools
- backoff
default_language_version:
Expand Down
29 changes: 29 additions & 0 deletions src/neptune_scale/net/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from typing import (
Any,
Literal,
cast,
)

import httpx
import neptune_retrieval_api.client
from httpx import Timeout
from neptune_api import (
AuthenticatedClient,
Expand Down Expand Up @@ -64,6 +66,11 @@
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus
from neptune_api.types import Response
from neptune_retrieval_api.api.default import search_leaderboard_entries_proto
from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO
from neptune_retrieval_api.proto.neptune_pb.api.v1.model.leaderboard_entries_pb2 import (
ProtoLeaderboardEntriesSearchResultDTO,
)

from neptune_scale.exceptions import (
NeptuneConnectionLostError,
Expand Down Expand Up @@ -129,6 +136,11 @@ def submit(self, operation: RunOperation, family: str) -> Response[SubmitRespons
@abc.abstractmethod
def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ...

# Abstract search hook: concrete clients receive a project identifier plus a
# search-parameters DTO and must return a ProtoLeaderboardEntriesSearchResultDTO.
@abc.abstractmethod
def search_entries(
self, project_id: str, body: SearchLeaderboardEntriesParamsDTO
) -> ProtoLeaderboardEntriesSearchResultDTO: ...


class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None:
Expand All @@ -141,6 +153,9 @@ def __init__(self, api_token: str) -> None:
self.backend = create_auth_api_client(
credentials=credentials, config=config, token_refreshing_urls=token_urls, verify_ssl=verify_ssl
)
# This is required only to silence mypy. The two client objects are compatible, because they're
# generated by swagger codegen.
self.retrieval_backend = cast(neptune_retrieval_api.client.AuthenticatedClient, self.backend)
logger.debug("Connected to Neptune API")

def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]:
Expand All @@ -153,6 +168,15 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ
body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]),
)

def search_entries(
    self, project_id: str, body: SearchLeaderboardEntriesParamsDTO
) -> ProtoLeaderboardEntriesSearchResultDTO:
    """Search entries of type "run" in a project and decode the protobuf reply.

    Args:
        project_id: Passed to the endpoint as ``project_identifier``.
        body: Search parameters forwarded unchanged as the request body.

    Returns:
        The message parsed from the raw response content.
    """
    response = search_leaderboard_entries_proto.sync_detailed(
        client=self.retrieval_backend,
        project_identifier=project_id,
        type=["run"],
        body=body,
    )
    return ProtoLeaderboardEntriesSearchResultDTO.FromString(response.content)

def close(self) -> None:
logger.debug("Closing API client")
self.backend.__exit__()
Expand Down Expand Up @@ -181,6 +205,11 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ
)
return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={})

# Stub implementation: ignores both arguments and returns a freshly constructed,
# unpopulated result message without issuing any request.
def search_entries(
self, project_id: str, body: SearchLeaderboardEntriesParamsDTO
) -> ProtoLeaderboardEntriesSearchResultDTO:
return ProtoLeaderboardEntriesSearchResultDTO()


def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient:
if mode == "disabled":
Expand Down
18 changes: 3 additions & 15 deletions src/neptune_scale/net/projects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import re
from enum import Enum
from json import JSONDecodeError
Expand All @@ -11,15 +10,14 @@
import httpx

from neptune_scale.exceptions import (
NeptuneApiTokenNotProvided,
NeptuneBadRequestError,
NeptuneProjectAlreadyExists,
)
from neptune_scale.net.api_client import (
HostedApiClient,
with_api_errors_handling,
)
from neptune_scale.util.envs import API_TOKEN_ENV_NAME
from neptune_scale.sync.util import ensure_api_token

PROJECTS_PATH_BASE = "/api/backend/v1/projects"

Expand All @@ -33,14 +31,6 @@ class ProjectVisibility(Enum):
ORGANIZATION_NOT_FOUND_RE = re.compile(r"Organization .* not found")


def _get_api_token(api_token: Optional[str]) -> str:
api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME)
if api_token is None:
raise NeptuneApiTokenNotProvided()

return api_token


@with_api_errors_handling
def create_project(
workspace: str,
Expand All @@ -52,9 +42,7 @@ def create_project(
fail_if_exists: bool = False,
api_token: Optional[str] = None,
) -> None:
api_token = _get_api_token(api_token)

client = HostedApiClient(api_token=api_token)
client = HostedApiClient(api_token=ensure_api_token(api_token))
visibility = ProjectVisibility(visibility)

body = {
Expand Down Expand Up @@ -92,7 +80,7 @@ def _safe_json(response: httpx.Response) -> Any:


def get_project_list(*, api_token: Optional[str] = None) -> list[dict]:
client = HostedApiClient(api_token=_get_api_token(api_token))
client = HostedApiClient(api_token=ensure_api_token(api_token))

params = {
"userRelation": "viewerOrHigher",
Expand Down
29 changes: 29 additions & 0 deletions src/neptune_scale/net/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Optional

from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO

from neptune_scale.exceptions import NeptuneScaleError
from neptune_scale.net.api_client import HostedApiClient
from neptune_scale.net.util import escape_nql_criterion
from neptune_scale.sync.util import ensure_api_token


def run_exists(project: str, run_id: str, api_token: Optional[str] = None) -> bool:
    """Query the backend for the existence of a Run with the given ID.

    Args:
        project: Project identifier passed to the entries search.
        run_id: Value matched against the ``sys/custom_run_id`` attribute. It is
            escaped before being embedded in the query string.
        api_token: Optional explicit token, resolved via ``ensure_api_token``.

    Returns:
        True if the search returned at least one entry, False otherwise.

    Raises:
        NeptuneScaleError: If the search fails for any reason. The original
            exception is attached as the cause.
    """

    client = HostedApiClient(api_token=ensure_api_token(api_token))
    body = SearchLeaderboardEntriesParamsDTO.from_dict(
        {
            "query": {"query": f'`sys/custom_run_id`:string = "{escape_nql_criterion(run_id)}"'},
        }
    )

    try:
        result = client.search_entries(project, body)
    except Exception as e:
        # Chain explicitly so the traceback marks the original error as the cause.
        raise NeptuneScaleError(reason=e) from e
    finally:
        # The client exists only for this one query; release it on every path.
        client.close()

    return bool(result.entries)
6 changes: 6 additions & 0 deletions src/neptune_scale/net/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def escape_nql_criterion(criterion: str) -> str:
    """Backslash-escape every backslash and double quote in ``criterion``.

    The result is meant to be embedded in a double-quoted NQL string literal.
    """
    # One pass over the input; each character is mapped independently, so the
    # backslashes added for quotes are never escaped a second time.
    escapes = str.maketrans({"\\": "\\\\", '"': '\\"'})
    return criterion.translate(escapes)
kgodlewski marked this conversation as resolved.
Show resolved Hide resolved
15 changes: 15 additions & 0 deletions src/neptune_scale/sync/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import os
import signal
from typing import Optional

from neptune_scale.exceptions import NeptuneApiTokenNotProvided
from neptune_scale.util.envs import API_TOKEN_ENV_NAME


def safe_signal_name(signum: int) -> str:
Expand All @@ -8,3 +13,13 @@ def safe_signal_name(signum: int) -> str:
signame = str(signum)

return signame


def ensure_api_token(api_token: Optional[str]) -> str:
    """Ensure the API token is provided via either explicit argument, or env variable.

    Args:
        api_token: Explicit token. When it is ``None`` or empty, the
            environment variable named by ``API_TOKEN_ENV_NAME`` is used.

    Returns:
        The resolved, non-empty token.

    Raises:
        NeptuneApiTokenNotProvided: If neither source yields a non-empty token.
    """

    token = api_token or os.environ.get(API_TOKEN_ENV_NAME)
    # Falsiness check instead of `is None`: an environment variable set to an
    # empty string must count as "not provided" rather than be returned as a token.
    if not token:
        raise NeptuneApiTokenNotProvided()

    return token
10 changes: 10 additions & 0 deletions tests/e2e/test_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

from neptune_scale.net.runs import run_exists

NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT")


def test_run_exists_true(run):
    """The fixture's run is found by its ID, while a made-up ID is not."""
    project = run._project
    assert run_exists(project, run._run_id)
    assert not run_exists(project, "nonexistent_run_id")
4 changes: 2 additions & 2 deletions tests/unit/test_process_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def on_closed(_):

link.start(on_link_closed=on_closed)
# We should never finish the sleep call, as on_closed raises SystemExit
time.sleep(5)
time.sleep(10)
assert False, "on_closed callback was not called"


Expand All @@ -184,5 +184,5 @@ def test_parent_termination():
p = multiprocessing.Process(target=parent, args=(var, event))
p.start()

assert event.wait(1)
assert event.wait(5)
assert var.value == 1
Loading