Skip to content

Commit

Permalink
Direct sync (#10)
Browse files Browse the repository at this point in the history
* Direct sync

* Tests fixed
  • Loading branch information
Raalsky authored Jul 31, 2024
1 parent 40c66a8 commit 5c4f200
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 30 deletions.
13 changes: 9 additions & 4 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.api_client import ApiClient
from neptune_scale.core.components.abstract import (
Resource,
WithResources,
Expand Down Expand Up @@ -115,14 +116,14 @@ def __init__(
verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH)

self._project: str = project
self._api_token: str = api_token
self._family: str = family
self._run_id: str = run_id

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
)
self._backend: ApiClient = ApiClient(api_token=api_token)

if not resume:
self._create_run(
Expand All @@ -137,7 +138,7 @@ def __enter__(self) -> Run:

@property
def resources(self) -> tuple[Resource, ...]:
return (self._operations_queue,)
return self._operations_queue, self._backend

def close(self) -> None:
"""
Expand Down Expand Up @@ -168,7 +169,9 @@ def _create_run(
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
)
self._operations_queue.enqueue(operation=operation)
self._backend.submit(operation=operation)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)

def log(
self,
Expand Down Expand Up @@ -229,4 +232,6 @@ def log(
add_tags=add_tags,
remove_tags=remove_tags,
):
self._operations_queue.enqueue(operation=operation)
self._backend.submit(operation=operation)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
Empty file.
87 changes: 87 additions & 0 deletions src/neptune_scale/api/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

__all__ = ["ApiClient"]


from dataclasses import dataclass

from neptune_api import (
AuthenticatedClient,
Client,
)
from neptune_api.api.backend import get_client_config
from neptune_api.api.data_ingestion import submit_operation
from neptune_api.auth_helpers import exchange_api_key
from neptune_api.credentials import Credentials
from neptune_api.models import (
ClientConfig,
Error,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.components.abstract import Resource


class ApiClient(Resource):
def __init__(self, api_token: str) -> None:
credentials = Credentials.from_api_key(api_key=api_token)
config, token_urls = get_config_and_token_urls(credentials=credentials)
self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls)

def submit(self, operation: RunOperation) -> None:
_ = submit_operation.sync(client=self._backend, body=operation)

def cleanup(self) -> None:
pass

def close(self) -> None:
self._backend.__exit__()


@dataclass
class TokenRefreshingURLs:
authorization_endpoint: str
token_endpoint: str

@classmethod
def from_dict(cls, data: dict) -> TokenRefreshingURLs:
return TokenRefreshingURLs(
authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"]
)


def get_config_and_token_urls(*, credentials: Credentials) -> tuple[ClientConfig, TokenRefreshingURLs]:
with Client(base_url=credentials.base_url) as client:
config = get_client_config.sync(client=client)
if config is None or isinstance(config, Error):
raise RuntimeError(f"Failed to get client config: {config}")
response = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(response.json())
return config, token_urls


def create_auth_api_client(
*, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs
) -> AuthenticatedClient:
return AuthenticatedClient(
base_url=credentials.base_url,
credentials=credentials,
client_id=config.security.client_id,
token_refreshing_endpoint=token_refreshing_urls.token_endpoint,
api_key_exchange_callback=exchange_api_key,
)
74 changes: 48 additions & 26 deletions tests/unit/test_run.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
import base64
import json
import uuid
from datetime import datetime
from unittest.mock import patch

import pytest
from freezegun import freeze_time

from neptune_scale import Run


def test_context_manager():
@pytest.fixture(scope="session")
def api_token():
return base64.b64encode(json.dumps({"api_address": "aa", "api_url": "bb"}).encode("utf-8")).decode("utf-8")


class MockedApiClient:
def __init__(self, *args, **kwargs) -> None:
pass

def submit(self, operation) -> None:
pass

def close(self) -> None:
pass

def cleanup(self) -> None:
pass


@patch("neptune_scale.ApiClient", MockedApiClient)
def test_context_manager(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -22,10 +44,10 @@ def test_context_manager():
assert True


def test_close():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_close(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -39,10 +61,10 @@ def test_close():
assert True


def test_family_too_long():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_family_too_long(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())

# and
Expand All @@ -54,10 +76,10 @@ def test_family_too_long():
...


def test_run_id_too_long():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_run_id_too_long(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
family = str(uuid.uuid4())

# and
Expand All @@ -69,9 +91,9 @@ def test_run_id_too_long():
...


def test_invalid_project_name():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_invalid_project_name(api_token):
# given
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -84,10 +106,10 @@ def test_invalid_project_name():
...


def test_metadata():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_metadata(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand Down Expand Up @@ -115,10 +137,10 @@ def test_metadata():
)


def test_log_without_step():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_log_without_step(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -132,10 +154,10 @@ def test_log_without_step():
)


def test_log_step_float():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_log_step_float(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -150,10 +172,10 @@ def test_log_step_float():
)


def test_log_no_timestamp():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_log_no_timestamp(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -167,10 +189,10 @@ def test_log_no_timestamp():
)


def test_resume():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_resume(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -182,11 +204,11 @@ def test_resume():
assert True


@patch("neptune_scale.ApiClient", MockedApiClient)
@freeze_time("2024-07-30 12:12:12.000022")
def test_creation_time():
def test_creation_time(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -198,10 +220,10 @@ def test_creation_time():
assert True


def test_assign_experiment():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_assign_experiment(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand All @@ -213,10 +235,10 @@ def test_assign_experiment():
assert True


def test_forking():
@patch("neptune_scale.ApiClient", MockedApiClient)
def test_forking(api_token):
# given
project = "workspace/project"
api_token = "API_TOKEN"
run_id = str(uuid.uuid4())
family = run_id

Expand Down

0 comments on commit 5c4f200

Please sign in to comment.