Skip to content

Commit

Permalink
Expose S3ClientConfig
Browse files Browse the repository at this point in the history
We expose the following configuration flags with performance impact:
throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps).
part_size(int): Size, in bytes, of parts that files will be downloaded or uploaded in.
  • Loading branch information
dnnanuti committed Mar 21, 2024
1 parent 2da3601 commit 55766a1
Show file tree
Hide file tree
Showing 14 changed files with 177 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## Unreleased

### New features
* Expose `throughput_target_gbps` and `part_size` configurations of the inner S3 client.

## v1.2.1 (March 14, 2024)

### Breaking changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def throughput(self):


class ExperimentResultJsonEncoder(JSONEncoder):

def default(self, o: Any) -> Any:
if isinstance(o, ExperimentResult):
o: ExperimentResult = o
Expand Down
2 changes: 2 additions & 0 deletions s3torchconnector/src/s3torchconnector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .s3map_dataset import S3MapDataset
from .s3checkpoint import S3Checkpoint
from ._version import __version__
from ._s3client import S3ClientConfig

__all__ = [
"S3IterableDataset",
Expand All @@ -19,5 +20,6 @@
"S3Reader",
"S3Writer",
"S3Exception",
"S3ClientConfig",
"__version__",
]
7 changes: 6 additions & 1 deletion s3torchconnector/src/s3torchconnector/_s3client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from .s3client_config import S3ClientConfig
from ._s3client import S3Client
from ._mock_s3client import MockS3Client

__all__ = ["S3Client", "MockS3Client"]
__all__ = [
"S3ClientConfig",
"S3Client",
"MockS3Client",
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from . import S3Client
from .._user_agent import UserAgent
from .s3client_config import S3ClientConfig

"""
_mock_s3client.py
Expand All @@ -22,14 +23,19 @@ def __init__(
self,
region: str,
bucket: str,
part_size: int = 8 * 1024 * 1024,
user_agent: Optional[UserAgent] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
super().__init__(region, user_agent=user_agent)
super().__init__(
region,
user_agent=user_agent,
s3client_config=s3client_config,
)
self._mock_client = MockMountpointS3Client(
region,
bucket,
part_size=part_size,
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
part_size=self.s3client_config.part_size,
user_agent_prefix=self.user_agent_prefix,
)

Expand Down
10 changes: 10 additions & 0 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Any

from s3torchconnector import S3Reader, S3Writer
from .s3client_config import S3ClientConfig

from s3torchconnectorclient._mountpoint_s3_client import (
MountpointS3Client,
Expand Down Expand Up @@ -35,15 +36,18 @@ class S3Client:
def __init__(
self,
region: str,
*,
endpoint: Optional[str] = None,
user_agent: Optional[UserAgent] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
self._region = region
self._endpoint = endpoint
self._real_client: Optional[MountpointS3Client] = None
self._client_pid: Optional[int] = None
user_agent = user_agent or UserAgent()
self._user_agent_prefix = user_agent.prefix
self._s3client_config = s3client_config or S3ClientConfig()

@property
def _client(self) -> MountpointS3Client:
Expand All @@ -58,6 +62,10 @@ def _client(self) -> MountpointS3Client:
def region(self) -> str:
return self._region

@property
def s3client_config(self) -> S3ClientConfig:
return self._s3client_config

@property
def user_agent_prefix(self) -> str:
return self._user_agent_prefix
Expand All @@ -67,6 +75,8 @@ def _client_builder(self) -> MountpointS3Client:
region=self._region,
endpoint=self._endpoint,
user_agent_prefix=self._user_agent_prefix,
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
part_size=self._s3client_config.part_size,
)

def get_object(
Expand Down
23 changes: 23 additions & 0 deletions s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from dataclasses import dataclass


@dataclass(frozen=True)
class S3ClientConfig:
"""A dataclass exposing configurable parameters for the S3 client.
Returns a config wrapper object.
Args:
throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps) that we are trying to reach.
You can also use get_recommended_throughput_target_gbps() to get recommended value for your system.
10.0 Gbps by default (may change in future).
part_size(int): Size, in bytes, of parts that files will be downloaded or uploaded in.
Note: for saving checkpoints, the inner client will adjust the part size to meet the service limits.
(max number of parts per upload is 10,000, minimum upload part size is 5 MiB).
Part size must have values between 5MiB and 5GiB.
8MB by default (may change in future).
"""

throughput_target_gbps: float = 10.0
part_size: int = 8 * 1024 * 1024
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,26 @@

from lightning.pytorch.plugins.io import CheckpointIO

from .._s3client import S3Client
from .._s3client import S3Client, S3ClientConfig
from .._s3dataset_common import parse_s3_uri
from .._user_agent import UserAgent


class S3LightningCheckpoint(CheckpointIO):
"""A checkpoint manager for S3 using the :class:`CheckpointIO` interface."""

def __init__(self, region: str):
def __init__(
self,
region: str,
s3client_config: Optional[S3ClientConfig] = None,
):
self.region = region
user_agent = UserAgent(["lightning", lightning.__version__])
self._client = S3Client(region, user_agent=user_agent)
self._client = S3Client(
region,
user_agent=user_agent,
s3client_config=s3client_config,
)

def save_checkpoint(
self,
Expand Down
13 changes: 10 additions & 3 deletions s3torchconnector/src/s3torchconnector/s3checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

from ._s3dataset_common import parse_s3_uri
from ._s3client import S3Client
from ._s3client import S3Client, S3ClientConfig
from . import S3Reader, S3Writer


Expand All @@ -17,10 +17,17 @@ class S3Checkpoint:
torch.load, and torch.save.
"""

def __init__(self, region: str, endpoint: Optional[str] = None):
def __init__(
self,
region: str,
endpoint: Optional[str] = None,
s3client_config: Optional[S3ClientConfig] = None,
):
self.region = region
self.endpoint = endpoint
self._client = S3Client(region, endpoint)
self._client = S3Client(
region, endpoint=endpoint, s3client_config=s3client_config
)

def reader(self, s3_uri: str) -> S3Reader:
"""Creates an S3Reader from a given s3_uri.
Expand Down
16 changes: 14 additions & 2 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from . import S3Reader
from ._s3bucket_key_data import S3BucketKeyData
from ._s3client import S3Client
from ._s3client import S3Client, S3ClientConfig
from ._s3dataset_common import (
identity,
get_objects_from_uris,
Expand All @@ -31,11 +31,13 @@ def __init__(
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._s3client_config = s3client_config
self._client = None

@property
Expand All @@ -54,6 +56,7 @@ def from_objects(
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
):
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
Expand All @@ -62,6 +65,7 @@ def from_objects(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand All @@ -75,6 +79,7 @@ def from_objects(
partial(get_objects_from_uris, object_uris),
endpoint,
transform=transform,
s3client_config=s3client_config,
)

@classmethod
Expand All @@ -85,6 +90,7 @@ def from_prefix(
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
):
"""Returns an instance of S3IterableDataset using the S3 URI provided.
Expand All @@ -93,6 +99,7 @@ def from_prefix(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand All @@ -106,11 +113,16 @@ def from_prefix(
partial(get_objects_from_prefix, s3_uri),
endpoint,
transform=transform,
s3client_config=s3client_config,
)

def _get_client(self):
if self._client is None:
self._client = S3Client(self.region, self.endpoint)
self._client = S3Client(
self.region,
endpoint=self.endpoint,
s3client_config=self._s3client_config,
)
return self._client

def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
Expand Down
16 changes: 14 additions & 2 deletions s3torchconnector/src/s3torchconnector/s3map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.utils.data
from s3torchconnector._s3bucket_key_data import S3BucketKeyData

from ._s3client import S3Client
from ._s3client import S3Client, S3ClientConfig
from . import S3Reader

from ._s3dataset_common import (
Expand All @@ -32,11 +32,13 @@ def __init__(
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._s3client_config = s3client_config
self._client = None
self._bucket_key_pairs: Optional[List[S3BucketKeyData]] = None

Expand All @@ -63,6 +65,7 @@ def from_objects(
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
):
"""Returns an instance of S3MapDataset using the S3 URI(s) provided.
Expand All @@ -71,6 +74,7 @@ def from_objects(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
Returns:
S3MapDataset: A Map-Style dataset created from S3 objects.
Expand All @@ -84,6 +88,7 @@ def from_objects(
partial(get_objects_from_uris, object_uris),
endpoint,
transform=transform,
s3client_config=s3client_config,
)

@classmethod
Expand All @@ -94,6 +99,7 @@ def from_prefix(
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
):
"""Returns an instance of S3MapDataset using the S3 URI provided.
Expand All @@ -102,6 +108,7 @@ def from_prefix(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
Returns:
S3MapDataset: A Map-Style dataset created from S3 objects.
Expand All @@ -115,11 +122,16 @@ def from_prefix(
partial(get_objects_from_prefix, s3_uri),
endpoint,
transform=transform,
s3client_config=s3client_config,
)

def _get_client(self):
if self._client is None:
self._client = S3Client(self.region, self.endpoint)
self._client = S3Client(
self.region,
endpoint=self.endpoint,
s3client_config=self._s3client_config,
)
return self._client

def _get_object(self, i: int) -> S3Reader:
Expand Down
2 changes: 1 addition & 1 deletion s3torchconnector/src/s3torchconnector/s3reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import io
from functools import cached_property
from io import SEEK_CUR, SEEK_END, SEEK_SET
from typing import Callable, Optional, Iterable, Iterator
from typing import Callable, Optional, Iterator

from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo, GetObjectStream

Expand Down
Loading

0 comments on commit 55766a1

Please sign in to comment.