
Commit

Merge branch 'main' into kg/run-error-handling-fixes
kgodlewski authored Jan 13, 2025
2 parents e0923f1 + 4c2e36e commit b6e03cd
Showing 16 changed files with 232 additions and 137 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
args: [ --config-file, pyproject.toml ]
pass_filenames: false
additional_dependencies:
- neptune-api==0.7.0b
- neptune-api
- more-itertools
- backoff
default_language_version:
54 changes: 42 additions & 12 deletions src/neptune_scale/api/attribute.py
@@ -1,5 +1,6 @@
import functools
import itertools
import threading
import warnings
from collections.abc import (
Collection,
@@ -14,6 +15,7 @@
cast,
)

from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing
from neptune_scale.sync.metadata_splitter import MetadataSplitter
from neptune_scale.sync.operations_queue import OperationsQueue

@@ -59,6 +61,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue)
self._run_id = run_id
self._operations_queue = operations_queue
self._attributes: dict[str, Attribute] = {}
# Keep a mapping of path -> (last step, last value) to detect non-increasing steps
# at the call site. The backend detects this error as well, but it's more convenient
# for the user to get the error as soon as possible.
self._metric_state: dict[str, tuple[float, float]] = {}
self._lock = threading.RLock()

def __getitem__(self, path: str) -> "Attribute":
path = cleanup_path(path)
@@ -85,22 +92,45 @@ def log(
) -> None:
if timestamp is None:
timestamp = datetime.now()
elif isinstance(timestamp, float):
elif isinstance(timestamp, (float, int)):
timestamp = datetime.fromtimestamp(timestamp)

splitter: MetadataSplitter = MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
# MetadataSplitter is an iterator, so gather everything into a list instead of iterating over
# it in the critical section, to avoid holding the lock for too long.
# TODO: Move splitting into the worker process. Here we should just send messages as they are.
chunks = list(
MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
)
)

for operation, metadata_size in splitter:
self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
with self._lock:
self._verify_and_update_metrics_state(step, metrics)

for operation, metadata_size in chunks:
self._operations_queue.enqueue(operation=operation, size=metadata_size)

def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[dict[str, float]]) -> None:
"""Check if step in provided metrics is increasing, raise `NeptuneSeriesStepNonIncreasing` if not."""

if step is None or metrics is None:
return

for metric, value in metrics.items():
if (state := self._metric_state.get(metric)) is not None:
last_step, last_value = state
# Repeating a step is fine as long as the value does not change
if step < last_step or (step == last_step and value != last_value):
raise NeptuneSeriesStepNonIncreasing()

self._metric_state[metric] = (step, value)


class Attribute:
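The step check added above enforces a simple per-metric rule: steps must never decrease, and a repeated step is accepted only when it carries the same value. A minimal standalone sketch of that rule (plain Python, independent of the Neptune client; a generic ValueError stands in for NeptuneSeriesStepNonIncreasing):

class StepMonotonicityGuard:
    """Tracks (last step, last value) per metric path, mirroring AttributeStore._metric_state."""

    def __init__(self) -> None:
        self._state: dict[str, tuple[float, float]] = {}

    def check_and_update(self, step: float, metrics: dict[str, float]) -> None:
        for path, value in metrics.items():
            if (state := self._state.get(path)) is not None:
                last_step, last_value = state
                # Repeating a step is fine as long as the value does not change.
                if step < last_step or (step == last_step and value != last_value):
                    raise ValueError(f"non-increasing step for {path!r}: {step} after {last_step}")
            self._state[path] = (step, value)

guard = StepMonotonicityGuard()
guard.check_and_update(1.0, {"loss": 0.9})
guard.check_and_update(2.0, {"loss": 0.7})
guard.check_and_update(2.0, {"loss": 0.7})  # same step, same value: allowed
guard.check_and_update(1.0, {"loss": 0.5})  # step moved backwards: raises ValueError
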
29 changes: 29 additions & 0 deletions src/neptune_scale/net/api_client.py
@@ -27,9 +27,11 @@
from typing import (
Any,
Literal,
cast,
)

import httpx
import neptune_retrieval_api.client
from httpx import Timeout
from neptune_api import (
AuthenticatedClient,
@@ -64,6 +66,11 @@
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus
from neptune_api.types import Response
from neptune_retrieval_api.api.default import search_leaderboard_entries_proto
from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO
from neptune_retrieval_api.proto.neptune_pb.api.v1.model.leaderboard_entries_pb2 import (
ProtoLeaderboardEntriesSearchResultDTO,
)

from neptune_scale.exceptions import (
NeptuneConnectionLostError,
@@ -129,6 +136,11 @@ def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]:
@abc.abstractmethod
def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ...

@abc.abstractmethod
def search_entries(
self, project_id: str, body: SearchLeaderboardEntriesParamsDTO
) -> ProtoLeaderboardEntriesSearchResultDTO: ...


class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None:
@@ -141,6 +153,9 @@ def __init__(self, api_token: str) -> None:
self.backend = create_auth_api_client(
credentials=credentials, config=config, token_refreshing_urls=token_urls, verify_ssl=verify_ssl
)
# This is required only to silence mypy. The two client objects are compatible, because they're
# generated by swagger codegen.
self.retrieval_backend = cast(neptune_retrieval_api.client.AuthenticatedClient, self.backend)
logger.debug("Connected to Neptune API")

def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]:
@@ -153,6 +168,15 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]:
body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]),
)

def search_entries(
self, project_id: str, body: SearchLeaderboardEntriesParamsDTO
) -> ProtoLeaderboardEntriesSearchResultDTO:
resp = search_leaderboard_entries_proto.sync_detailed(
client=self.retrieval_backend, project_identifier=project_id, type=["run"], body=body
)
result = ProtoLeaderboardEntriesSearchResultDTO.FromString(resp.content)
return result

def close(self) -> None:
logger.debug("Closing API client")
self.backend.__exit__()
@@ -181,6 +205,11 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]:
)
return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={})

def search_entries(
self, project_id: str, body: SearchLeaderboardEntriesParamsDTO
) -> ProtoLeaderboardEntriesSearchResultDTO:
return ProtoLeaderboardEntriesSearchResultDTO()


def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient:
if mode == "disabled":
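A note on the cast to neptune_retrieval_api.client.AuthenticatedClient above: it is a typing aid only and has no runtime effect. A small sketch of the pattern with hypothetical classes (not the generated Neptune clients) shows why the cast satisfies mypy:

from typing import cast

class GeneratedClientA:
    def request(self, path: str) -> bytes:
        return b""

class GeneratedClientB:  # structurally identical, but a distinct nominal type
    def request(self, path: str) -> bytes:
        return b""

def call_endpoint(client: GeneratedClientB) -> bytes:
    return client.request("/api/ping")

client_a = GeneratedClientA()
# cast() is a no-op at runtime; it only tells mypy to treat client_a as GeneratedClientB.
call_endpoint(cast(GeneratedClientB, client_a))
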
18 changes: 3 additions & 15 deletions src/neptune_scale/net/projects.py
@@ -1,4 +1,3 @@
import os
import re
from enum import Enum
from json import JSONDecodeError
@@ -11,15 +10,14 @@
import httpx

from neptune_scale.exceptions import (
NeptuneApiTokenNotProvided,
NeptuneBadRequestError,
NeptuneProjectAlreadyExists,
)
from neptune_scale.net.api_client import (
HostedApiClient,
with_api_errors_handling,
)
from neptune_scale.util.envs import API_TOKEN_ENV_NAME
from neptune_scale.sync.util import ensure_api_token

PROJECTS_PATH_BASE = "/api/backend/v1/projects"

@@ -33,14 +31,6 @@ class ProjectVisibility(Enum):
ORGANIZATION_NOT_FOUND_RE = re.compile(r"Organization .* not found")


def _get_api_token(api_token: Optional[str]) -> str:
api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME)
if api_token is None:
raise NeptuneApiTokenNotProvided()

return api_token


@with_api_errors_handling
def create_project(
workspace: str,
@@ -52,9 +42,7 @@ def create_project(
fail_if_exists: bool = False,
api_token: Optional[str] = None,
) -> None:
api_token = _get_api_token(api_token)

client = HostedApiClient(api_token=api_token)
client = HostedApiClient(api_token=ensure_api_token(api_token))
visibility = ProjectVisibility(visibility)

body = {
@@ -92,7 +80,7 @@ def _safe_json(response: httpx.Response) -> Any:


def get_project_list(*, api_token: Optional[str] = None) -> list[dict]:
client = HostedApiClient(api_token=_get_api_token(api_token))
client = HostedApiClient(api_token=ensure_api_token(api_token))

params = {
"userRelation": "viewerOrHigher",
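The deleted _get_api_token helper is superseded by ensure_api_token from neptune_scale.sync.util, whose body is not part of this diff. Judging from the removed helper, it presumably looks roughly like the following (a sketch of the assumed implementation, not the actual code):

import os
from typing import Optional

from neptune_scale.exceptions import NeptuneApiTokenNotProvided
from neptune_scale.util.envs import API_TOKEN_ENV_NAME

def ensure_api_token(api_token: Optional[str]) -> str:
    """Return the explicit token, or fall back to the API token environment variable."""
    api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME)
    if api_token is None:
        raise NeptuneApiTokenNotProvided()
    return api_token
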
29 changes: 29 additions & 0 deletions src/neptune_scale/net/runs.py
@@ -0,0 +1,29 @@
from typing import Optional

from neptune_retrieval_api.models import SearchLeaderboardEntriesParamsDTO

from neptune_scale.exceptions import NeptuneScaleError
from neptune_scale.net.api_client import HostedApiClient
from neptune_scale.net.util import escape_nql_criterion
from neptune_scale.sync.util import ensure_api_token


def run_exists(project: str, run_id: str, api_token: Optional[str] = None) -> bool:
"""Query the backend for the existence of a Run with the given ID.
Returns True if the Run exists, False otherwise.
"""

client = HostedApiClient(api_token=ensure_api_token(api_token))
body = SearchLeaderboardEntriesParamsDTO.from_dict(
{
"query": {"query": f'`sys/custom_run_id`:string = "{escape_nql_criterion(run_id)}"'},
}
)

try:
result = client.search_entries(project, body)
except Exception as e:
raise NeptuneScaleError(reason=e)

return bool(result.entries)
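
A minimal usage sketch of the new helper; "workspace/project" and "my-custom-run-id" are placeholders, and the API token is presumably resolved from the environment when not passed explicitly:

from neptune_scale.net.runs import run_exists

if run_exists("workspace/project", "my-custom-run-id"):
    print("Run already exists")
else:
    print("Run not found")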
6 changes: 6 additions & 0 deletions src/neptune_scale/net/util.py
@@ -0,0 +1,6 @@
def escape_nql_criterion(criterion: str) -> str:
"""
Escape backslash and (double-)quotes in the string, to match what the NQL engine expects.
"""

return criterion.replace("\\", r"\\").replace('"', r"\"")
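
For illustration, the escaping behaves as follows (interpreter-style; the second line is the repr of the escaped string):

>>> escape_nql_criterion(r'run "A\B"')
'run \\"A\\\\B\\"'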
62 changes: 15 additions & 47 deletions src/neptune_scale/sync/aggregating_queue.py
@@ -71,7 +71,7 @@ def commit(self) -> None:
def get(self) -> BatchedOperations:
start = time.monotonic()

batch_operations: dict[Optional[float], RunOperation] = {}
batch_operations: list[RunOperation] = []
batch_sequence_id: Optional[int] = None
batch_timestamp: Optional[float] = None

@@ -95,7 +95,7 @@ def get(self) -> BatchedOperations:
if not batch_operations:
new_operation = RunOperation()
new_operation.ParseFromString(element.operation)
batch_operations[element.batch_key] = new_operation
batch_operations.append(new_operation)
batch_bytes += len(element.operation)
else:
if not element.is_batchable:
@@ -110,10 +110,7 @@

new_operation = RunOperation()
new_operation.ParseFromString(element.operation)
if element.batch_key not in batch_operations:
batch_operations[element.batch_key] = new_operation
else:
merge_run_operation(batch_operations[element.batch_key], new_operation)
batch_operations.append(new_operation)
batch_bytes += element.metadata_size

batch_sequence_id = element.sequence_id
@@ -157,54 +154,25 @@ def get(self) -> BatchedOperations:
)


def create_run_batch(operations: dict[Optional[float], RunOperation]) -> RunOperation:
def create_run_batch(operations: list[RunOperation]) -> RunOperation:
if not operations:
raise Empty

if len(operations) == 1:
return next(iter(operations.values()))
return operations[0]

batch = None
for _, operation in sorted(operations.items(), key=lambda x: (x[0] is not None, x[0])):
if batch is None:
batch = RunOperation()
batch.project = operation.project
batch.run_id = operation.run_id
batch.create_missing_project = operation.create_missing_project
batch.api_key = operation.api_key
head = operations[0]
batch = RunOperation()
batch.project = head.project
batch.run_id = head.run_id
batch.create_missing_project = head.create_missing_project
batch.api_key = head.api_key

for operation in operations:
operation_type = operation.WhichOneof("operation")
if operation_type == "update":
batch.update_batch.snapshots.append(operation.update)
else:
raise ValueError(f"Cannot batch operation of type {operation_type}")

if batch is None:
raise Empty
return batch


def merge_run_operation(batch: RunOperation, operation: RunOperation) -> None:
"""
Merge the `operation` into `batch`, taking into account the special case of `modify_sets`.
Protobuf merges existing map keys by simply overwriting values, instead of calling
`MergeFrom` on the existing value, eg: A['foo'] = B['foo'].
We want this instead:
batch = {'sys/tags': 'string': { 'values': {'foo': ADD}}}
operation = {'sys/tags': 'string': { 'values': {'bar': ADD}}}
result = {'sys/tags': 'string': { 'values': {'foo': ADD, 'bar': ADD}}}
If we called `batch.MergeFrom(operation)` we would get an overwritten value:
result = {'sys/tags': 'string': { 'values': {'bar': ADD}}}
This function ensures that the `modify_sets` are merged correctly, leaving the default
behaviour for all other fields.
"""

modify_sets = operation.update.modify_sets
operation.update.ClearField("modify_sets")

batch.MergeFrom(operation)

for k, v in modify_sets.items():
batch.update.modify_sets[k].MergeFrom(v)
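
With merge_run_operation removed, batching no longer groups operations by step and merges them per key: create_run_batch copies the common header fields from the first operation and appends every update as its own snapshot in update_batch. A simplified model of the new behaviour, using plain dataclasses instead of the generated protobuf messages:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeSnapshot:
    step: Optional[float]
    values: dict[str, float]

@dataclass
class FakeRunOperation:
    project: str
    run_id: str
    update: Optional[FakeSnapshot] = None
    update_batch: list[FakeSnapshot] = field(default_factory=list)

def create_run_batch_sketch(operations: list[FakeRunOperation]) -> FakeRunOperation:
    if not operations:
        raise ValueError("nothing to batch")
    if len(operations) == 1:
        return operations[0]
    head = operations[0]
    batch = FakeRunOperation(project=head.project, run_id=head.run_id)
    for op in operations:
        # Every update becomes a separate snapshot; nothing is merged by step anymore.
        assert op.update is not None
        batch.update_batch.append(op.update)
    return batch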
3 changes: 1 addition & 2 deletions src/neptune_scale/sync/operations_queue.py
@@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]:
with self._lock:
return self._last_timestamp

def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None:
def enqueue(self, *, operation: RunOperation, size: Optional[int] = None) -> None:
try:
is_metadata_update = operation.HasField("update")
serialized_operation = operation.SerializeToString()
@@ -75,7 +75,6 @@ def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None:
operation=serialized_operation,
metadata_size=size,
is_batchable=is_metadata_update,
batch_key=key,
),
block=True,
timeout=None,
2 changes: 0 additions & 2 deletions src/neptune_scale/sync/queue_element.py
@@ -26,5 +26,3 @@ class SingleOperation(NamedTuple):
is_batchable: bool
# Size of the metadata in the operation (without project, family, run_id etc.)
metadata_size: Optional[int]
# Update metadata key
batch_key: Optional[float]
2 changes: 2 additions & 0 deletions src/neptune_scale/sync/sync_process.py
@@ -412,6 +412,8 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:

def work(self) -> None:
try:
# TODO: is there a point in serializing the data on AggregatingQueue? It does not move between processes,
# so we could just pass around instances of RunOperation
while (operation := self.get_next()) is not None:
sequence_id, timestamp, data = operation
