Skip to content

Commit

Permalink
Refactor UserAgent setup for extensibility
Browse files Browse the repository at this point in the history
  • Loading branch information
dnnanuti committed Feb 29, 2024
1 parent 7ac85ef commit 01dc22a
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 6 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

### New features

### Bug Fixes
### Bug Fixes / Improvements
* Fix deadlock when enabling CRT debug logs. Removed former experimental method _enable_debug_logging().

* Refactor User-Agent setup for extensibility.

## v1.1.4 (February 26, 2024)

Expand Down
11 changes: 8 additions & 3 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Optional, Any

from s3torchconnector import S3Reader, S3Writer
from s3torchconnector._version import user_agent_prefix

from s3torchconnectorclient._mountpoint_s3_client import (
MountpointS3Client,
Expand All @@ -16,6 +15,7 @@
GetObjectStream,
)

from s3torchconnector._user_agent import UserAgent

"""
_s3client.py
Expand All @@ -32,11 +32,12 @@ def _identity(obj: Any) -> Any:


class S3Client:
def __init__(self, region: str, endpoint: str = None):
def __init__(self, region: str, endpoint: str = None, user_agent: UserAgent = None):
self._region = region
self._endpoint = endpoint
self._real_client = None
self._client_pid = None
self._user_agent = user_agent or UserAgent()

@property
def _client(self) -> MountpointS3Client:
Expand All @@ -50,11 +51,15 @@ def _client(self) -> MountpointS3Client:
def region(self) -> str:
return self._region

@property
def user_agent(self) -> str:
return self._user_agent.prefix

def _client_builder(self) -> MountpointS3Client:
return MountpointS3Client(
region=self._region,
endpoint=self._endpoint,
user_agent_prefix=user_agent_prefix,
user_agent_prefix=self._user_agent.prefix,
)

def get_object(
Expand Down
24 changes: 24 additions & 0 deletions s3torchconnector/src/s3torchconnector/_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from typing import List

from ._version import __version__

# https://www.rfc-editor.org/rfc/rfc9110#name-user-agent


class UserAgent:
def __init__(self, comments: List[str] = None):
self._user_agent_prefix = f"{__package__}/{__version__}"
self._comments = comments or []
self._comments_separator = "; "

def add_comment(self, comment):
self._comments.append(comment)

@property
def prefix(self):
comments_str = f"{self._comments_separator.join(self._comments)}" or ""
if comments_str != "":
return f"{self._user_agent_prefix} ({comments_str})"
return self._user_agent_prefix
1 change: 0 additions & 1 deletion s3torchconnector/src/s3torchconnector/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@

# __package__ is 's3torchconnector'
__version__ = importlib.metadata.version(__package__)
user_agent_prefix = f"{__package__}/{__version__}"
17 changes: 17 additions & 0 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest

from s3torchconnector._user_agent import UserAgent
from s3torchconnector._version import __version__
from s3torchconnector._s3client import S3Client, MockS3Client

TEST_BUCKET = "test-bucket"
Expand Down Expand Up @@ -48,3 +50,18 @@ def test_list_objects_log(s3_client: S3Client, caplog):
with caplog.at_level(logging.DEBUG):
s3_client.list_objects(TEST_BUCKET, TEST_KEY)
assert f"ListObjects {S3_URI}" in caplog.messages


def test_s3_client_default_user_agent():
s3_client = S3Client(region=TEST_REGION)
assert s3_client.user_agent == f"s3torchconnector/{__version__}"


def test_s3_client_custom_user_agent():
s3_client = S3Client(
region=TEST_REGION, user_agent=UserAgent(["component/version", "metadata"])
)
assert (
s3_client.user_agent
== f"s3torchconnector/{__version__} (component/version; metadata)"
)
33 changes: 33 additions & 0 deletions s3torchconnector/tst/unit/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from s3torchconnector._version import __version__
from s3torchconnector._user_agent import UserAgent


def test_default_user_agent():
user_agent = UserAgent()
assert user_agent.prefix == f"s3torchconnector/{__version__}"


def test_user_agent_creation_with_empty_comments():
user_agent = UserAgent([])
assert user_agent.prefix == f"s3torchconnector/{__version__}"


def test_user_agent_creation_with_empty_str_comments():
user_agent = UserAgent([""])
assert user_agent.prefix == f"s3torchconnector/{__version__}"


def test_user_agent_creation_with_none_comments():
user_agent = UserAgent(None)
assert user_agent.prefix == f"s3torchconnector/{__version__}"


def test_user_agent_creation_with_comments():
user_agent = UserAgent(["component/version", "metadata"])
assert (
user_agent.prefix
== f"s3torchconnector/{__version__} (component/version; metadata)"
)

0 comments on commit 01dc22a

Please sign in to comment.