Skip to content

Commit

Permalink
Minimal Run creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Jul 31, 2024
1 parent 08f7e92 commit 0cbc6ca
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 50 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6))
- Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7))
- Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8))
- Added support for `creation_time` ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9))
- Added support for Forking ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9))
- Added support for Experiments ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9))
- Added support for Run resume ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9))
82 changes: 79 additions & 3 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@
from datetime import datetime
from typing import Callable

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.api_client import ApiClient
from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.message_builder import MessageBuilder
from neptune_scale.core.proto_utils import (
datetime_to_proto,
make_step,
)
from neptune_scale.core.validation import (
verify_collection_type,
verify_max_length,
Expand All @@ -43,6 +52,11 @@ def __init__(
api_token: str,
family: str,
run_id: str,
resume: bool = False,
as_experiment: str | None = None,
creation_time: datetime | None = None,
from_run_id: str | None = None,
from_step: int | float | None = None,
max_queue_size: int = MAX_QUEUE_SIZE,
max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None,
) -> None:
Expand All @@ -55,6 +69,11 @@ def __init__(
family: Identifies related runs. For example, the same value must apply to all runs within a run hierarchy.
Max length: 128 characters.
run_id: Unique identifier of a run. Must be unique within the project. Max length: 128 characters.
resume: Whether to resume an existing run.
as_experiment: ID of the experiment to be associated with the run.
creation_time: Time when the run was created.
from_run_id: ID if the Run to fork from.
from_step: Step number to fork from.
max_queue_size: Maximum number of operations in a queue.
max_queue_size_exceeded_callback: Callback function triggered when a queue is full.
Accepts two arguments:
Expand All @@ -64,41 +83,96 @@ def __init__(
verify_type("api_token", api_token, str)
verify_type("family", family, str)
verify_type("run_id", run_id, str)
verify_type("resume", resume, bool)
verify_type("as_experiment", as_experiment, (str, type(None)))
verify_type("creation_time", creation_time, (datetime, type(None)))
verify_type("from_run_id", from_run_id, (str, type(None)))
verify_type("from_step", from_step, (int, float, type(None)))
verify_type("max_queue_size", max_queue_size, int)
verify_type("max_queue_size_exceeded_callback", max_queue_size_exceeded_callback, (Callable, type(None)))

if resume and creation_time is not None:
raise ValueError("`resume` and `creation_time` cannot be used together.")
if resume and as_experiment is not None:
raise ValueError("`resume` and `as_experiment` cannot be used together.")
if (from_run_id is not None and from_step is None) or (from_run_id is None and from_step is not None):
raise ValueError("`from_run_id` and `from_step` must be used together.")
if resume and from_run_id is not None:
raise ValueError("`resume` and `from_run_id` cannot be used together.")
if resume and from_step is not None:
raise ValueError("`resume` and `from_step` cannot be used together.")

verify_non_empty("api_token", api_token)
verify_non_empty("family", family)
verify_non_empty("run_id", run_id)
if as_experiment is not None:
verify_non_empty("as_experiment", as_experiment)
if from_run_id is not None:
verify_non_empty("from_run_id", from_run_id)

verify_project_qualified_name("project", project)

verify_max_length("family", family, MAX_FAMILY_LENGTH)
verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH)

self._project: str = project
self._api_token: str = api_token
self._family: str = family
self._run_id: str = run_id

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
)
self._backend: ApiClient = ApiClient(api_token=api_token)

if not resume:
self._create_run(
creation_time=datetime.now() if creation_time is None else creation_time,
as_experiment=as_experiment,
from_run_id=from_run_id,
from_step=from_step,
)

def __enter__(self) -> Run:
return self

@property
def resources(self) -> tuple[Resource, ...]:
return (self._operations_queue,)
return self._operations_queue, self._backend

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()

def _create_run(
self,
creation_time: datetime,
as_experiment: str | None,
from_run_id: str | None,
from_step: int | float | None,
) -> None:
fork_point: ForkPoint | None = None
if from_run_id is not None and from_step is not None:
fork_point = ForkPoint(
parent_project=self._project, parent_run_id=from_run_id, step=make_step(number=from_step)
)

operation = RunOperation(
project=self._project,
run_id=self._run_id,
create=CreateRun(
family=self._family,
fork_point=fork_point,
experiment_id=as_experiment,
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
)
self._backend.submit(operation=operation)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)

def log(
self,
step: float | int | None = None,
Expand Down Expand Up @@ -158,4 +232,6 @@ def log(
add_tags=add_tags,
remove_tags=remove_tags,
):
self._operations_queue.enqueue(operation=operation)
self._backend.submit(operation=operation)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
Empty file.
87 changes: 87 additions & 0 deletions src/neptune_scale/api/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

__all__ = ["ApiClient"]


from dataclasses import dataclass

from neptune_api import (
AuthenticatedClient,
Client,
)
from neptune_api.api.backend import get_client_config
from neptune_api.api.data_ingestion import submit_operation
from neptune_api.auth_helpers import exchange_api_key
from neptune_api.credentials import Credentials
from neptune_api.models import (
ClientConfig,
Error,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.components.abstract import Resource


class ApiClient(Resource):
def __init__(self, api_token: str) -> None:
credentials = Credentials.from_api_key(api_key=api_token)
config, token_urls = get_config_and_token_urls(credentials=credentials)
self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls)

def submit(self, operation: RunOperation) -> None:
_ = submit_operation.sync(client=self._backend, body=operation)

def cleanup(self) -> None:
pass

def close(self) -> None:
self._backend.__exit__()


@dataclass
class TokenRefreshingURLs:
authorization_endpoint: str
token_endpoint: str

@classmethod
def from_dict(cls, data: dict) -> TokenRefreshingURLs:
return TokenRefreshingURLs(
authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"]
)


def get_config_and_token_urls(*, credentials: Credentials) -> tuple[ClientConfig, TokenRefreshingURLs]:
with Client(base_url=credentials.base_url) as client:
config = get_client_config.sync(client=client)
if config is None or isinstance(config, Error):
raise RuntimeError(f"Failed to get client config: {config}")
response = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(response.json())
return config, token_urls


def create_auth_api_client(
*, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs
) -> AuthenticatedClient:
return AuthenticatedClient(
base_url=credentials.base_url,
credentials=credentials,
client_id=config.security.client_id,
token_refreshing_endpoint=token_refreshing_urls.token_endpoint,
api_key_exchange_callback=exchange_api_key,
)
34 changes: 5 additions & 29 deletions src/neptune_scale/core/message_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

from datetime import datetime

from google.protobuf.timestamp_pb2 import Timestamp
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
SET_OPERATION,
ModifySet,
Step,
StringSet,
UpdateRunSnapshot,
Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.proto_utils import (
datetime_to_proto,
make_step,
)


class MessageBuilder:
def __init__(
Expand Down Expand Up @@ -91,30 +94,3 @@ def make_value(value: Value | float | str | int | bool | datetime | list[str] |
return fv
else:
raise ValueError(f"Unsupported ingest field value type: {type(value)}")


def datetime_to_proto(dt: datetime) -> Timestamp:
dt_ts = dt.timestamp()
return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9))


def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
"""
Converts a number to protobuf Step value. Example:
>>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)
Args:
number: step expressed as number
raise_on_step_precision_loss: inform converter whether it should silently drop precision and
round down to 6 decimal places or raise an error.
Returns: Step protobuf used in Neptune API.
"""
m = int(1e6)
micro: int = int(number * m)
if raise_on_step_precision_loss and number * m - micro != 0:
raise ValueError(f"step must not use more than 6-decimal points, got: {number}")

whole = micro // m
micro = micro % m

return Step(whole=whole, micro=micro)
35 changes: 35 additions & 0 deletions src/neptune_scale/core/proto_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

__all__ = ("datetime_to_proto", "make_step")

from datetime import datetime

from google.protobuf.timestamp_pb2 import Timestamp
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Step


def datetime_to_proto(dt: datetime) -> Timestamp:
dt_ts = dt.timestamp()
return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9))


def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
"""
Converts a number to protobuf Step value. Example:
>>> assert make_step(7.654321, True) == Step(whole=7, micro=654321)
Args:
number: step expressed as number
raise_on_step_precision_loss: inform converter whether it should silently drop precision and
round down to 6 decimal places or raise an error.
Returns: Step protobuf used in Neptune API.
"""
m = int(1e6)
micro: int = int(number * m)
if raise_on_step_precision_loss and number * m - micro != 0:
raise ValueError(f"step must not use more than 6-decimal points, got: {number}")

whole = micro // m
micro = micro % m

return Step(whole=whole, micro=micro)
Loading

0 comments on commit 0cbc6ca

Please sign in to comment.