Skip to content

Commit

Permalink
Merge pull request #115 from neptune-ai/kg/fix-get-client-config
Browse files Browse the repository at this point in the history
Make `get_client_config()` raise errors that are more meaningful to the user
  • Loading branch information
kgodlewski authored Jan 14, 2025
2 parents 5d9fcbd + 8fd053a commit 667d7e1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 26 deletions.
31 changes: 25 additions & 6 deletions src/neptune_scale/net/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@
from neptune_scale.exceptions import (
NeptuneConnectionLostError,
NeptuneInvalidCredentialsError,
NeptuneScaleError,
NeptuneUnableToAuthenticateError,
)
from neptune_scale.net.util import raise_for_http_status
from neptune_scale.sync.parameters import REQUEST_TIMEOUT
from neptune_scale.util.abstract import Resource
from neptune_scale.util.envs import ALLOW_SELF_SIGNED_CERTIFICATE
Expand Down Expand Up @@ -106,11 +108,25 @@ def get_config_and_token_urls(
verify_ssl=verify_ssl,
timeout=Timeout(timeout=REQUEST_TIMEOUT),
) as client:
config = get_client_config.sync(client=client)
if config is None or isinstance(config, Error):
raise RuntimeError(f"Failed to get client config: {config}")
response = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(response.json())
response = get_client_config.sync_detailed(client=client)
if response.parsed is None:
raise NeptuneScaleError(
message="Failed to initialize API client: invalid response from server. "
f"Status code={response.status_code}"
)

if response.status_code != 200 or not isinstance(response.parsed, ClientConfig):
error = response.parsed if isinstance(response.parsed, Error) else None

if response.status_code == 400 and error and isinstance(error.message, str):
if "X-Neptune-Api-Token" in error.message:
raise NeptuneInvalidCredentialsError()

raise_for_http_status(response.status_code)

config = cast(ClientConfig, response.parsed)
token_data = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(token_data.json())
return config, token_urls


Expand Down Expand Up @@ -144,7 +160,10 @@ def search_entries(

class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None:
credentials = Credentials.from_api_key(api_key=api_token)
try:
credentials = Credentials.from_api_key(api_key=api_token)
except InvalidApiTokenException:
raise NeptuneInvalidCredentialsError()

verify_ssl: bool = os.environ.get(ALLOW_SELF_SIGNED_CERTIFICATE, "False").lower() in ("false", "0")

Expand Down
27 changes: 27 additions & 0 deletions src/neptune_scale/net/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from neptune_scale.exceptions import (
NeptuneConnectionLostError,
NeptuneInternalServerError,
NeptuneTooManyRequestsResponseError,
NeptuneUnauthorizedError,
NeptuneUnexpectedResponseError,
)
from neptune_scale.util import get_logger

logger = get_logger()


def escape_nql_criterion(criterion: str) -> str:
"""
Escape backslash and (double-)quotes in the string, to match what the NQL engine expects.
"""

return criterion.replace("\\", r"\\").replace('"', r"\"")


def raise_for_http_status(status_code: int) -> None:
assert status_code >= 400, f"Status code {status_code} is not an error"

logger.error("HTTP response error: %s", status_code)
if status_code == 403:
raise NeptuneUnauthorizedError()
elif status_code == 408:
raise NeptuneConnectionLostError()
elif status_code == 429:
raise NeptuneTooManyRequestsResponseError()
elif status_code // 100 == 5:
raise NeptuneInternalServerError()
else:
raise NeptuneUnexpectedResponseError()
23 changes: 3 additions & 20 deletions src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
NeptuneAttributeTypeUnsupported,
NeptuneConnectionLostError,
NeptuneFloatValueNanInfUnsupported,
NeptuneInternalServerError,
NeptuneProjectInvalidName,
NeptuneProjectNotFound,
NeptuneRetryableError,
Expand All @@ -70,16 +69,14 @@
NeptuneStringSetExceedsSizeLimit,
NeptuneStringValueExceedsSizeLimit,
NeptuneSynchronizationStopped,
NeptuneTooManyRequestsResponseError,
NeptuneUnauthorizedError,
NeptuneUnexpectedError,
NeptuneUnexpectedResponseError,
)
from neptune_scale.net.api_client import (
ApiClient,
backend_factory,
with_api_errors_handling,
)
from neptune_scale.net.util import raise_for_http_status
from neptune_scale.sync.aggregating_queue import AggregatingQueue
from neptune_scale.sync.errors_tracking import ErrorsQueue
from neptune_scale.sync.parameters import (
Expand Down Expand Up @@ -438,7 +435,7 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:

status_code = response.status_code
if status_code != 200:
_raise_exception(status_code)
raise_for_http_status(status_code)

return response.parsed

Expand Down Expand Up @@ -482,20 +479,6 @@ def work(self) -> None:
raise NeptuneSynchronizationStopped() from e


def _raise_exception(status_code: int) -> None:
logger.error("HTTP response error: %s", status_code)
if status_code == 403:
raise NeptuneUnauthorizedError()
elif status_code == 408:
raise NeptuneConnectionLostError()
elif status_code == 429:
raise NeptuneTooManyRequestsResponseError()
elif status_code // 100 == 5:
raise NeptuneInternalServerError()
else:
raise NeptuneUnexpectedResponseError()


class StatusTrackingThread(Daemon, WithResources):
def __init__(
self,
Expand Down Expand Up @@ -542,7 +525,7 @@ def check_batch(self, *, request_ids: list[str]) -> Optional[BulkRequestStatus]:
status_code = response.status_code

if status_code != 200:
_raise_exception(status_code)
raise_for_http_status(status_code)

return response.parsed

Expand Down

0 comments on commit 667d7e1

Please sign in to comment.