Skip to content

Commit

Permalink
Direct sync
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Jul 31, 2024
1 parent 40c66a8 commit 6a5b911
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 10 deletions.
34 changes: 24 additions & 10 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@

__all__ = ["Run"]

import threading
from contextlib import AbstractContextManager
from datetime import datetime
from typing import Callable

from neptune_api.api.data_ingestion import submit_operation
from neptune_api.credentials import Credentials
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.api_client import (
create_auth_api_client,
get_config_and_token_urls,
)
from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.message_builder import MessageBuilder
from neptune_scale.core.proto_utils import (
datetime_to_proto,
Expand Down Expand Up @@ -115,14 +119,18 @@ def __init__(
verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH)

self._project: str = project
self._api_token: str = api_token
self._family: str = family
self._run_id: str = run_id

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
)
# TODO: Bring back the operations queue
# self._lock = threading.RLock()
# self._operations_queue: OperationsQueue = OperationsQueue(
# lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
# )

credentials = Credentials.from_api_key(api_key=api_token)
config, token_urls = get_config_and_token_urls(credentials=credentials)
self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls)

if not resume:
self._create_run(
Expand All @@ -137,13 +145,15 @@ def __enter__(self) -> Run:

@property
def resources(self) -> tuple[Resource, ...]:
return (self._operations_queue,)
# return (self._operations_queue,)
return ()

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()
self._backend.__exit__()

def _create_run(
self,
Expand All @@ -168,7 +178,9 @@ def _create_run(
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
)
self._operations_queue.enqueue(operation=operation)
_ = submit_operation.sync(client=self._backend, body=operation)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)

def log(
self,
Expand Down Expand Up @@ -229,4 +241,6 @@ def log(
add_tags=add_tags,
remove_tags=remove_tags,
):
self._operations_queue.enqueue(operation=operation)
_ = submit_operation.sync(client=self._backend, body=operation)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
Empty file.
67 changes: 67 additions & 0 deletions src/neptune_scale/api/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

__all__ = ["TokenRefreshingURLs", "get_config_and_token_urls", "create_auth_api_client"]


from dataclasses import dataclass
from typing import Tuple

from neptune_api import (
AuthenticatedClient,
Client,
)
from neptune_api.api.backend import get_client_config
from neptune_api.auth_helpers import exchange_api_key
from neptune_api.credentials import Credentials
from neptune_api.models import (
ClientConfig,
Error,
)


@dataclass
class TokenRefreshingURLs:
authorization_endpoint: str
token_endpoint: str

@classmethod
def from_dict(cls, data: dict) -> "TokenRefreshingURLs":
return TokenRefreshingURLs(
authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"]
)


def get_config_and_token_urls(*, credentials: Credentials) -> Tuple[ClientConfig, TokenRefreshingURLs]:
with Client(base_url=credentials.base_url) as client:
config = get_client_config.sync(client=client)
if config is None or isinstance(config, Error):
raise RuntimeError(f"Failed to get client config: {config}")
response = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(response.json())
return config, token_urls


def create_auth_api_client(
*, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs
) -> AuthenticatedClient:
return AuthenticatedClient(
base_url=credentials.base_url,
credentials=credentials,
client_id=config.security.client_id,
token_refreshing_endpoint=token_refreshing_urls.token_endpoint,
api_key_exchange_callback=exchange_api_key,
)

0 comments on commit 6a5b911

Please sign in to comment.