Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make get_client_config() raise errors that are more meaningful to the user #115

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions src/neptune_scale/net/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@
from neptune_scale.exceptions import (
NeptuneConnectionLostError,
NeptuneInvalidCredentialsError,
NeptuneScaleError,
NeptuneUnableToAuthenticateError,
)
from neptune_scale.net.util import raise_for_http_status
from neptune_scale.sync.parameters import REQUEST_TIMEOUT
from neptune_scale.util.abstract import Resource
from neptune_scale.util.envs import ALLOW_SELF_SIGNED_CERTIFICATE
Expand Down Expand Up @@ -106,11 +108,25 @@ def get_config_and_token_urls(
verify_ssl=verify_ssl,
timeout=Timeout(timeout=REQUEST_TIMEOUT),
) as client:
config = get_client_config.sync(client=client)
if config is None or isinstance(config, Error):
raise RuntimeError(f"Failed to get client config: {config}")
response = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(response.json())
response = get_client_config.sync_detailed(client=client)
if response.parsed is None:
raise NeptuneScaleError(
message="Failed to initialize API client: invalid response from server. "
f"Status code={response.status_code}"
)

if response.status_code != 200 or not isinstance(response.parsed, ClientConfig):
error = response.parsed if isinstance(response.parsed, Error) else None

if response.status_code == 400 and error and isinstance(error.message, str):
if "X-Neptune-Api-Token" in error.message:
raise NeptuneInvalidCredentialsError()

raise_for_http_status(response.status_code)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unable to parse this suggestion, there seems to be a problem with indentation, and it looks like it's not replacing all the lines it should replace. Can you double check?


config = cast(ClientConfig, response.parsed)
token_data = client.get_httpx_client().get(config.security.open_id_discovery)
token_urls = TokenRefreshingURLs.from_dict(token_data.json())
return config, token_urls


Expand Down Expand Up @@ -144,7 +160,10 @@ def search_entries(

class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None:
credentials = Credentials.from_api_key(api_key=api_token)
try:
credentials = Credentials.from_api_key(api_key=api_token)
except InvalidApiTokenException:
raise NeptuneInvalidCredentialsError()

verify_ssl: bool = os.environ.get(ALLOW_SELF_SIGNED_CERTIFICATE, "False").lower() in ("false", "0")

Expand Down
27 changes: 27 additions & 0 deletions src/neptune_scale/net/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from neptune_scale.exceptions import (
NeptuneConnectionLostError,
NeptuneInternalServerError,
NeptuneTooManyRequestsResponseError,
NeptuneUnauthorizedError,
NeptuneUnexpectedResponseError,
)
from neptune_scale.util import get_logger

logger = get_logger()


def escape_nql_criterion(criterion: str) -> str:
"""
Escape backslash and (double-)quotes in the string, to match what the NQL engine expects.
"""

return criterion.replace("\\", r"\\").replace('"', r"\"")


def raise_for_http_status(status_code: int) -> None:
assert status_code >= 400, f"Status code {status_code} is not an error"

logger.error("HTTP response error: %s", status_code)
if status_code == 403:
raise NeptuneUnauthorizedError()
elif status_code == 408:
raise NeptuneConnectionLostError()
elif status_code == 429:
raise NeptuneTooManyRequestsResponseError()
elif status_code // 100 == 5:
raise NeptuneInternalServerError()
else:
raise NeptuneUnexpectedResponseError()
23 changes: 3 additions & 20 deletions src/neptune_scale/sync/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
NeptuneAttributeTypeUnsupported,
NeptuneConnectionLostError,
NeptuneFloatValueNanInfUnsupported,
NeptuneInternalServerError,
NeptuneProjectInvalidName,
NeptuneProjectNotFound,
NeptuneRetryableError,
Expand All @@ -70,16 +69,14 @@
NeptuneStringSetExceedsSizeLimit,
NeptuneStringValueExceedsSizeLimit,
NeptuneSynchronizationStopped,
NeptuneTooManyRequestsResponseError,
NeptuneUnauthorizedError,
NeptuneUnexpectedError,
NeptuneUnexpectedResponseError,
)
from neptune_scale.net.api_client import (
ApiClient,
backend_factory,
with_api_errors_handling,
)
from neptune_scale.net.util import raise_for_http_status
from neptune_scale.sync.aggregating_queue import AggregatingQueue
from neptune_scale.sync.errors_tracking import ErrorsQueue
from neptune_scale.sync.parameters import (
Expand Down Expand Up @@ -438,7 +435,7 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:

status_code = response.status_code
if status_code != 200:
_raise_exception(status_code)
raise_for_http_status(status_code)

return response.parsed

Expand Down Expand Up @@ -482,20 +479,6 @@ def work(self) -> None:
raise NeptuneSynchronizationStopped() from e


def _raise_exception(status_code: int) -> None:
logger.error("HTTP response error: %s", status_code)
if status_code == 403:
raise NeptuneUnauthorizedError()
elif status_code == 408:
raise NeptuneConnectionLostError()
elif status_code == 429:
raise NeptuneTooManyRequestsResponseError()
elif status_code // 100 == 5:
raise NeptuneInternalServerError()
else:
raise NeptuneUnexpectedResponseError()


class StatusTrackingThread(Daemon, WithResources):
def __init__(
self,
Expand Down Expand Up @@ -542,7 +525,7 @@ def check_batch(self, *, request_ids: list[str]) -> Optional[BulkRequestStatus]:
status_code = response.status_code

if status_code != 200:
_raise_exception(status_code)
raise_for_http_status(status_code)

return response.parsed

Expand Down
Loading