diff --git a/src/neptune_scale/net/api_client.py b/src/neptune_scale/net/api_client.py index 1bbcccf..bd0bcc8 100644 --- a/src/neptune_scale/net/api_client.py +++ b/src/neptune_scale/net/api_client.py @@ -75,8 +75,10 @@ from neptune_scale.exceptions import ( NeptuneConnectionLostError, NeptuneInvalidCredentialsError, + NeptuneScaleError, NeptuneUnableToAuthenticateError, ) +from neptune_scale.net.util import raise_for_http_status from neptune_scale.sync.parameters import REQUEST_TIMEOUT from neptune_scale.util.abstract import Resource from neptune_scale.util.envs import ALLOW_SELF_SIGNED_CERTIFICATE @@ -106,11 +108,25 @@ def get_config_and_token_urls( verify_ssl=verify_ssl, timeout=Timeout(timeout=REQUEST_TIMEOUT), ) as client: - config = get_client_config.sync(client=client) - if config is None or isinstance(config, Error): - raise RuntimeError(f"Failed to get client config: {config}") - response = client.get_httpx_client().get(config.security.open_id_discovery) - token_urls = TokenRefreshingURLs.from_dict(response.json()) + response = get_client_config.sync_detailed(client=client) + if response.parsed is None: + raise NeptuneScaleError( + message="Failed to initialize API client: invalid response from server. " + f"Status code={response.status_code}" + ) + + if response.status_code != 200 or not isinstance(response.parsed, ClientConfig): + error = response.parsed if isinstance(response.parsed, Error) else None + + if response.status_code == 400 and error and isinstance(error.message, str): + if "X-Neptune-Api-Token" in error.message: + raise NeptuneInvalidCredentialsError() + + raise_for_http_status(response.status_code) + + config = cast(ClientConfig, response.parsed) + token_data = client.get_httpx_client().get(config.security.open_id_discovery) + token_urls = TokenRefreshingURLs.from_dict(token_data.json()) return config, token_urls @@ -144,7 +160,10 @@ def search_entries( class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: - credentials = Credentials.from_api_key(api_key=api_token) + try: + credentials = Credentials.from_api_key(api_key=api_token) + except InvalidApiTokenException: + raise NeptuneInvalidCredentialsError() verify_ssl: bool = os.environ.get(ALLOW_SELF_SIGNED_CERTIFICATE, "False").lower() in ("false", "0") diff --git a/src/neptune_scale/net/util.py b/src/neptune_scale/net/util.py index beca024..21ed111 100644 --- a/src/neptune_scale/net/util.py +++ b/src/neptune_scale/net/util.py @@ -13,6 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from neptune_scale.exceptions import ( + NeptuneConnectionLostError, + NeptuneInternalServerError, + NeptuneTooManyRequestsResponseError, + NeptuneUnauthorizedError, + NeptuneUnexpectedResponseError, +) +from neptune_scale.util import get_logger + +logger = get_logger() + def escape_nql_criterion(criterion: str) -> str: """ @@ -20,3 +31,19 @@ def escape_nql_criterion(criterion: str) -> str: """ return criterion.replace("\\", r"\\").replace('"', r"\"") + + +def raise_for_http_status(status_code: int) -> None: + assert status_code >= 400, f"Status code {status_code} is not an error" + + logger.error("HTTP response error: %s", status_code) + if status_code == 403: + raise NeptuneUnauthorizedError() + elif status_code == 408: + raise NeptuneConnectionLostError() + elif status_code == 429: + raise NeptuneTooManyRequestsResponseError() + elif status_code // 100 == 5: + raise NeptuneInternalServerError() + else: + raise NeptuneUnexpectedResponseError() diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py index e68c7ab..f7c165f 100644 --- a/src/neptune_scale/sync/sync_process.py +++ b/src/neptune_scale/sync/sync_process.py @@ -54,7 +54,6 @@ NeptuneAttributeTypeUnsupported, NeptuneConnectionLostError, NeptuneFloatValueNanInfUnsupported, - NeptuneInternalServerError, NeptuneProjectInvalidName, NeptuneProjectNotFound, NeptuneRetryableError, @@ -70,16 +69,14 @@ NeptuneStringSetExceedsSizeLimit, NeptuneStringValueExceedsSizeLimit, NeptuneSynchronizationStopped, - NeptuneTooManyRequestsResponseError, - NeptuneUnauthorizedError, NeptuneUnexpectedError, - NeptuneUnexpectedResponseError, ) from neptune_scale.net.api_client import ( ApiClient, backend_factory, with_api_errors_handling, ) +from neptune_scale.net.util import raise_for_http_status from neptune_scale.sync.aggregating_queue import AggregatingQueue from neptune_scale.sync.errors_tracking import ErrorsQueue from neptune_scale.sync.parameters import ( @@ -438,7 +435,7 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]: status_code = response.status_code if status_code != 200: - _raise_exception(status_code) + raise_for_http_status(status_code) return response.parsed @@ -482,20 +479,6 @@ def work(self) -> None: raise NeptuneSynchronizationStopped() from e -def _raise_exception(status_code: int) -> None: - logger.error("HTTP response error: %s", status_code) - if status_code == 403: - raise NeptuneUnauthorizedError() - elif status_code == 408: - raise NeptuneConnectionLostError() - elif status_code == 429: - raise NeptuneTooManyRequestsResponseError() - elif status_code // 100 == 5: - raise NeptuneInternalServerError() - else: - raise NeptuneUnexpectedResponseError() - - class StatusTrackingThread(Daemon, WithResources): def __init__( self, @@ -542,7 +525,7 @@ def check_batch(self, *, request_ids: list[str]) -> Optional[BulkRequestStatus]: status_code = response.status_code if status_code != 200: - _raise_exception(status_code) + raise_for_http_status(status_code) return response.parsed