Skip to content

Commit

Permalink
Rename config class and move to s3torchconnector
Browse files Browse the repository at this point in the history
  • Loading branch information
dnnanuti committed Mar 21, 2024
1 parent 5ee1e6b commit a093ad0
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 87 deletions.
5 changes: 3 additions & 2 deletions s3torchconnector/src/s3torchconnector/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from s3torchconnectorclient import S3Exception, S3ConnectorClientConfig
from s3torchconnectorclient import S3Exception

# The order of these imports is the same in which they will be rendered
# in the API docs generated with Sphinx.
Expand All @@ -11,6 +11,7 @@
from .s3map_dataset import S3MapDataset
from .s3checkpoint import S3Checkpoint
from ._version import __version__
from ._s3client import S3ClientConfig

__all__ = [
"S3IterableDataset",
Expand All @@ -19,6 +20,6 @@
"S3Reader",
"S3Writer",
"S3Exception",
"S3ConnectorClientConfig",
"S3ClientConfig",
"__version__",
]
7 changes: 6 additions & 1 deletion s3torchconnector/src/s3torchconnector/_s3client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from .s3client_config import S3ClientConfig
from ._s3client import S3Client
from ._mock_s3client import MockS3Client

__all__ = ["S3Client", "MockS3Client"]
__all__ = [
"S3ClientConfig",
"S3Client",
"MockS3Client",
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
MountpointS3Client,
)

from s3torchconnectorclient import S3ConnectorClientConfig
from . import S3Client
from .._user_agent import UserAgent
from .s3client_config import S3ClientConfig

"""
_mock_s3client.py
Expand All @@ -24,18 +24,18 @@ def __init__(
region: str,
bucket: str,
user_agent: Optional[UserAgent] = None,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
super().__init__(
region,
user_agent=user_agent,
s3_connector_client_config=s3_connector_client_config,
s3client_config=s3client_config,
)
self._mock_client = MockMountpointS3Client(
region,
bucket,
throughput_target_gbps=self.s3_connector_client_config.throughput_target_gbps,
part_size=self.s3_connector_client_config.part_size,
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
part_size=self.s3client_config.part_size,
user_agent_prefix=self.user_agent_prefix,
)

Expand Down
17 changes: 8 additions & 9 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Any

from s3torchconnector import S3Reader, S3Writer
from .s3client_config import S3ClientConfig

from s3torchconnectorclient._mountpoint_s3_client import (
MountpointS3Client,
Expand All @@ -16,7 +17,6 @@
)

from s3torchconnector._user_agent import UserAgent
from s3torchconnectorclient.s3_connector_client_config import S3ConnectorClientConfig

"""
_s3client.py
Expand All @@ -36,19 +36,18 @@ class S3Client:
def __init__(
self,
region: str,
*,
endpoint: Optional[str] = None,
user_agent: Optional[UserAgent] = None,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
self._region = region
self._endpoint = endpoint
self._real_client: Optional[MountpointS3Client] = None
self._client_pid: Optional[int] = None
user_agent = user_agent or UserAgent()
self._user_agent_prefix = user_agent.prefix
self._s3_connector_client_config = (
s3_connector_client_config or S3ConnectorClientConfig()
)
self._s3client_config = s3client_config or S3ClientConfig()

@property
def _client(self) -> MountpointS3Client:
Expand All @@ -64,8 +63,8 @@ def region(self) -> str:
return self._region

@property
def s3_connector_client_config(self) -> S3ConnectorClientConfig:
return self._s3_connector_client_config
def s3client_config(self) -> S3ClientConfig:
return self._s3client_config

@property
def user_agent_prefix(self) -> str:
Expand All @@ -76,8 +75,8 @@ def _client_builder(self) -> MountpointS3Client:
region=self._region,
endpoint=self._endpoint,
user_agent_prefix=self._user_agent_prefix,
throughput_target_gbps=self._s3_connector_client_config.throughput_target_gbps,
part_size=self._s3_connector_client_config.part_size,
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
part_size=self._s3client_config.part_size,
)

def get_object(
Expand Down
22 changes: 22 additions & 0 deletions s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from dataclasses import dataclass


@dataclass(frozen=True)
class S3ClientConfig:
    """Immutable set of tunable parameters for the S3 client.

    Instances are frozen: fields cannot be reassigned after construction.

    Args:
        throughput_target_gbps(float): Desired throughput, expressed in Gigabits per second (Gbps).
            get_recommended_throughput_target_gbps() can be used to obtain a recommended value
            for your system.
            Defaults to 10.0 Gbps (may change in future).
        part_size(int): Size, in bytes, of the parts in which files are downloaded or uploaded.
            Note: when saving checkpoints, the inner client adjusts the part size to satisfy
            the service limits (at most 10,000 parts per upload, minimum upload part size of 5 MiB).
            Defaults to 8 MiB (may change in future).
    """

    throughput_target_gbps: float = 10.0
    part_size: int = 8 * 1024**2
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from lightning.pytorch.plugins.io import CheckpointIO

from .._s3client import S3Client
from .._s3client import S3Client, S3ClientConfig
from .._s3dataset_common import parse_s3_uri
from .._user_agent import UserAgent
from s3torchconnectorclient.s3_connector_client_config import S3ConnectorClientConfig


class S3LightningCheckpoint(CheckpointIO):
Expand All @@ -20,14 +19,14 @@ class S3LightningCheckpoint(CheckpointIO):
def __init__(
self,
region: str,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
self.region = region
user_agent = UserAgent(["lightning", lightning.__version__])
self._client = S3Client(
region,
user_agent=user_agent,
s3_connector_client_config=s3_connector_client_config,
s3client_config=s3client_config,
)

def save_checkpoint(
Expand Down
9 changes: 5 additions & 4 deletions s3torchconnector/src/s3torchconnector/s3checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from typing import Optional

from ._s3dataset_common import parse_s3_uri
from ._s3client import S3Client
from ._s3client import S3Client, S3ClientConfig
from . import S3Reader, S3Writer
from s3torchconnectorclient.s3_connector_client_config import S3ConnectorClientConfig


class S3Checkpoint:
Expand All @@ -22,11 +21,13 @@ def __init__(
self,
region: str,
endpoint: Optional[str] = None,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
self.region = region
self.endpoint = endpoint
self._client = S3Client(region, endpoint, s3_connector_client_config)
self._client = S3Client(
region, endpoint=endpoint, s3client_config=s3client_config
)

def reader(self, s3_uri: str) -> S3Reader:
"""Creates an S3Reader from a given s3_uri.
Expand Down
23 changes: 12 additions & 11 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

from . import S3Reader
from ._s3bucket_key_data import S3BucketKeyData
from ._s3client import S3Client
from ._s3client import S3Client, S3ClientConfig
from ._s3dataset_common import (
identity,
get_objects_from_uris,
get_objects_from_prefix,
)
from s3torchconnectorclient.s3_connector_client_config import S3ConnectorClientConfig

log = logging.getLogger(__name__)

Expand All @@ -32,13 +31,13 @@ def __init__(
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._s3_connector_client_config = s3_connector_client_config
self._s3client_config = s3client_config
self._client = None

@property
Expand All @@ -57,7 +56,7 @@ def from_objects(
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
Expand All @@ -66,7 +65,7 @@ def from_objects(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3_connector_client_config: Optional S3ConnectorClientConfig with parameters for S3 client.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand All @@ -80,7 +79,7 @@ def from_objects(
partial(get_objects_from_uris, object_uris),
endpoint,
transform=transform,
s3_connector_client_config=s3_connector_client_config,
s3client_config=s3client_config,
)

@classmethod
Expand All @@ -91,7 +90,7 @@ def from_prefix(
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3_connector_client_config: Optional[S3ConnectorClientConfig] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
"""Returns an instance of S3IterableDataset using the S3 URI provided.
Expand All @@ -100,7 +99,7 @@ def from_prefix(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3_connector_client_config: Optional S3ConnectorClientConfig with parameters for S3 client.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand All @@ -114,13 +113,15 @@ def from_prefix(
partial(get_objects_from_prefix, s3_uri),
endpoint,
transform=transform,
s3_connector_client_config=s3_connector_client_config,
s3client_config=s3client_config,
)

def _get_client(self):
if self._client is None:
self._client = S3Client(
self.region, self.endpoint, self._s3_connector_client_config
self.region,
endpoint=self.endpoint,
s3client_config=self._s3client_config,
)
return self._client

Expand Down
Loading

0 comments on commit a093ad0

Please sign in to comment.