diff --git a/CHANGELOG.md b/CHANGELOG.md
index cead92cf..b37242a4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+## Unreleased
+
+### New features
+* Expose `throughput_target_gbps` and `part_size` configurations of the inner S3 client.
+
 ## v1.2.1 (March 14, 2024)
 
 ### Breaking changes
diff --git a/s3torchbenchmarking/src/s3torchbenchmarking/benchmark_utils.py b/s3torchbenchmarking/src/s3torchbenchmarking/benchmark_utils.py
index 46057b3c..85ac6d9a 100644
--- a/s3torchbenchmarking/src/s3torchbenchmarking/benchmark_utils.py
+++ b/s3torchbenchmarking/src/s3torchbenchmarking/benchmark_utils.py
@@ -75,7 +75,6 @@ def throughput(self):
 
 
 class ExperimentResultJsonEncoder(JSONEncoder):
-
     def default(self, o: Any) -> Any:
         if isinstance(o, ExperimentResult):
             o: ExperimentResult = o
diff --git a/s3torchconnector/src/s3torchconnector/__init__.py b/s3torchconnector/src/s3torchconnector/__init__.py
index fec26aaf..d46ecb02 100644
--- a/s3torchconnector/src/s3torchconnector/__init__.py
+++ b/s3torchconnector/src/s3torchconnector/__init__.py
@@ -11,6 +11,7 @@
 from .s3map_dataset import S3MapDataset
 from .s3checkpoint import S3Checkpoint
 from ._version import __version__
+from ._s3client import S3ClientConfig
 
 __all__ = [
     "S3IterableDataset",
@@ -19,5 +20,6 @@
     "S3Reader",
     "S3Writer",
     "S3Exception",
+    "S3ClientConfig",
     "__version__",
 ]
diff --git a/s3torchconnector/src/s3torchconnector/_s3client/__init__.py b/s3torchconnector/src/s3torchconnector/_s3client/__init__.py
index b8f09a1e..552004a8 100644
--- a/s3torchconnector/src/s3torchconnector/_s3client/__init__.py
+++ b/s3torchconnector/src/s3torchconnector/_s3client/__init__.py
@@ -1,7 +1,12 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # // SPDX-License-Identifier: BSD
 
+from .s3client_config import S3ClientConfig
 from ._s3client import S3Client
 from ._mock_s3client import MockS3Client
 
-__all__ = ["S3Client", "MockS3Client"]
+__all__ = [
+    "S3ClientConfig",
+    "S3Client",
+    "MockS3Client",
+]
diff --git a/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py b/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py
index ceaf8ffb..edd92909 100644
--- a/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py
+++ b/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py
@@ -10,6 +10,7 @@
 
 from . import S3Client
 from .._user_agent import UserAgent
+from .s3client_config import S3ClientConfig
 
 """
 _mock_s3client.py
@@ -22,14 +23,19 @@ def __init__(
         self,
         region: str,
         bucket: str,
-        part_size: int = 8 * 1024 * 1024,
         user_agent: Optional[UserAgent] = None,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
-        super().__init__(region, user_agent=user_agent)
+        super().__init__(
+            region,
+            user_agent=user_agent,
+            s3client_config=s3client_config,
+        )
         self._mock_client = MockMountpointS3Client(
             region,
             bucket,
-            part_size=part_size,
+            throughput_target_gbps=self.s3client_config.throughput_target_gbps,
+            part_size=self.s3client_config.part_size,
             user_agent_prefix=self.user_agent_prefix,
         )
 
diff --git a/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py b/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
index 398065ae..d5f4dc8c 100644
--- a/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
+++ b/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
@@ -7,6 +7,7 @@
 from typing import Optional, Any
 
 from s3torchconnector import S3Reader, S3Writer
+from .s3client_config import S3ClientConfig
 
 from s3torchconnectorclient._mountpoint_s3_client import (
     MountpointS3Client,
@@ -35,8 +36,10 @@ class S3Client:
     def __init__(
         self,
         region: str,
+        *,
         endpoint: Optional[str] = None,
         user_agent: Optional[UserAgent] = None,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         self._region = region
         self._endpoint = endpoint
@@ -44,6 +47,7 @@ def __init__(
         self._client_pid: Optional[int] = None
         user_agent = user_agent or UserAgent()
         self._user_agent_prefix = user_agent.prefix
+        self._s3client_config = s3client_config or S3ClientConfig()
 
     @property
     def _client(self) -> MountpointS3Client:
@@ -58,6 +62,10 @@ def _client(self) -> MountpointS3Client:
     def region(self) -> str:
         return self._region
 
+    @property
+    def s3client_config(self) -> S3ClientConfig:
+        return self._s3client_config
+
     @property
     def user_agent_prefix(self) -> str:
         return self._user_agent_prefix
@@ -67,6 +75,8 @@ def _client_builder(self) -> MountpointS3Client:
             region=self._region,
             endpoint=self._endpoint,
             user_agent_prefix=self._user_agent_prefix,
+            throughput_target_gbps=self._s3client_config.throughput_target_gbps,
+            part_size=self._s3client_config.part_size,
         )
 
     def get_object(
diff --git a/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py b/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py
new file mode 100644
index 00000000..fa712ef1
--- /dev/null
+++ b/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py
@@ -0,0 +1,23 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class S3ClientConfig:
+    """A dataclass exposing configurable parameters for the S3 client.
+    Acts as an immutable (frozen) wrapper around the inner client's configuration.
+
+    Args:
+        throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps) that we are trying to reach.
+            You can also use get_recommended_throughput_target_gbps() to get a recommended value for your system.
+            10.0 Gbps by default (may change in future).
+        part_size(int): Size, in bytes, of the parts that files will be downloaded or uploaded in.
+            Note: for saving checkpoints, the inner client will adjust the part size to meet the service limits
+            (the maximum number of parts per upload is 10,000 and the minimum upload part size is 5 MiB).
+            Part size must be between 5MiB and 5GiB.
+            8MiB by default (may change in future).
+    """
+
+    throughput_target_gbps: float = 10.0
+    part_size: int = 8 * 1024 * 1024
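
A minimal usage sketch for the new dataclass above; the values shown are arbitrary and only illustrate the two fields and their documented defaults and limits:

from s3torchconnector import S3ClientConfig

# Defaults: throughput_target_gbps=10.0, part_size=8 MiB.
default_config = S3ClientConfig()

# Override either field; the dataclass is frozen, so instances are immutable.
custom_config = S3ClientConfig(
    throughput_target_gbps=25.0,
    part_size=16 * 1024 * 1024,  # 16 MiB, within the 5 MiB to 5 GiB limits noted above
)
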
+ """ + + throughput_target_gbps: float = 10.0 + part_size: int = 8 * 1024 * 1024 diff --git a/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py b/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py index 09b6bf7e..b68937c8 100644 --- a/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py +++ b/s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py @@ -8,7 +8,7 @@ from lightning.pytorch.plugins.io import CheckpointIO -from .._s3client import S3Client +from .._s3client import S3Client, S3ClientConfig from .._s3dataset_common import parse_s3_uri from .._user_agent import UserAgent @@ -16,10 +16,18 @@ class S3LightningCheckpoint(CheckpointIO): """A checkpoint manager for S3 using the :class:`CheckpointIO` interface.""" - def __init__(self, region: str): + def __init__( + self, + region: str, + s3client_config: Optional[S3ClientConfig] = None, + ): self.region = region user_agent = UserAgent(["lightning", lightning.__version__]) - self._client = S3Client(region, user_agent=user_agent) + self._client = S3Client( + region, + user_agent=user_agent, + s3client_config=s3client_config, + ) def save_checkpoint( self, diff --git a/s3torchconnector/src/s3torchconnector/s3checkpoint.py b/s3torchconnector/src/s3torchconnector/s3checkpoint.py index 38ab9a27..d4ec9b27 100644 --- a/s3torchconnector/src/s3torchconnector/s3checkpoint.py +++ b/s3torchconnector/src/s3torchconnector/s3checkpoint.py @@ -3,7 +3,7 @@ from typing import Optional from ._s3dataset_common import parse_s3_uri -from ._s3client import S3Client +from ._s3client import S3Client, S3ClientConfig from . import S3Reader, S3Writer @@ -17,10 +17,17 @@ class S3Checkpoint: torch.load, and torch.save. """ - def __init__(self, region: str, endpoint: Optional[str] = None): + def __init__( + self, + region: str, + endpoint: Optional[str] = None, + s3client_config: Optional[S3ClientConfig] = None, + ): self.region = region self.endpoint = endpoint - self._client = S3Client(region, endpoint) + self._client = S3Client( + region, endpoint=endpoint, s3client_config=s3client_config + ) def reader(self, s3_uri: str) -> S3Reader: """Creates an S3Reader from a given s3_uri. diff --git a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py index 6df2e27e..d61487a8 100644 --- a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py +++ b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py @@ -8,7 +8,7 @@ from . import S3Reader from ._s3bucket_key_data import S3BucketKeyData -from ._s3client import S3Client +from ._s3client import S3Client, S3ClientConfig from ._s3dataset_common import ( identity, get_objects_from_uris, @@ -31,11 +31,13 @@ def __init__( get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]], endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, + s3client_config: Optional[S3ClientConfig] = None, ): self._get_dataset_objects = get_dataset_objects self._transform = transform self._region = region self._endpoint = endpoint + self._s3client_config = s3client_config self._client = None @property @@ -54,6 +56,7 @@ def from_objects( region: str, endpoint: Optional[str] = None, transform: Callable[[S3Reader], Any] = identity, + s3client_config: Optional[S3ClientConfig] = None, ): """Returns an instance of S3IterableDataset using the S3 URI(s) provided. 
diff --git a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
index 6df2e27e..d61487a8 100644
--- a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
+++ b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
@@ -8,7 +8,7 @@
 
 from . import S3Reader
 from ._s3bucket_key_data import S3BucketKeyData
-from ._s3client import S3Client
+from ._s3client import S3Client, S3ClientConfig
 from ._s3dataset_common import (
     identity,
     get_objects_from_uris,
@@ -31,11 +31,13 @@ def __init__(
         get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
         endpoint: Optional[str] = None,
         transform: Callable[[S3Reader], Any] = identity,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         self._get_dataset_objects = get_dataset_objects
         self._transform = transform
         self._region = region
         self._endpoint = endpoint
+        self._s3client_config = s3client_config
         self._client = None
 
     @property
@@ -54,6 +56,7 @@ def from_objects(
         region: str,
         endpoint: Optional[str] = None,
         transform: Callable[[S3Reader], Any] = identity,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         """Returns an instance of S3IterableDataset using the S3 URI(s) provided.
 
@@ -62,6 +65,7 @@ def from_objects(
             region(str): AWS region of the S3 bucket where the objects are stored.
             endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
             transform: Optional callable which is used to transform an S3Reader into the desired type.
+            s3client_config: Optional S3ClientConfig with parameters for the S3 client.
 
         Returns:
             S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -75,6 +79,7 @@ def from_objects(
             partial(get_objects_from_uris, object_uris),
             endpoint,
             transform=transform,
+            s3client_config=s3client_config,
         )
 
     @classmethod
@@ -85,6 +90,7 @@ def from_prefix(
         region: str,
         endpoint: Optional[str] = None,
         transform: Callable[[S3Reader], Any] = identity,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         """Returns an instance of S3IterableDataset using the S3 URI provided.
 
@@ -93,6 +99,7 @@ def from_prefix(
             region(str): AWS region of the S3 bucket where the objects are stored.
             endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
             transform: Optional callable which is used to transform an S3Reader into the desired type.
+            s3client_config: Optional S3ClientConfig with parameters for the S3 client.
 
         Returns:
             S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -106,11 +113,16 @@ def from_prefix(
             partial(get_objects_from_prefix, s3_uri),
             endpoint,
             transform=transform,
+            s3client_config=s3client_config,
         )
 
     def _get_client(self):
         if self._client is None:
-            self._client = S3Client(self.region, self.endpoint)
+            self._client = S3Client(
+                self.region,
+                endpoint=self.endpoint,
+                s3client_config=self._s3client_config,
+            )
         return self._client
 
     def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
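
A sketch of how the new keyword argument reaches the iterable dataset; the bucket, prefix, and region are placeholders:

from s3torchconnector import S3IterableDataset, S3ClientConfig

dataset = S3IterableDataset.from_prefix(
    "s3://my-bucket/training-data/",
    region="us-east-1",
    s3client_config=S3ClientConfig(throughput_target_gbps=15.0),
)

# Each item is an S3Reader; the config is only applied when the lazy client is built.
for item in dataset:
    data = item.read()
    print(item.key, len(data))
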
diff --git a/s3torchconnector/src/s3torchconnector/s3map_dataset.py b/s3torchconnector/src/s3torchconnector/s3map_dataset.py
index 163da2b0..6d6b837f 100644
--- a/s3torchconnector/src/s3torchconnector/s3map_dataset.py
+++ b/s3torchconnector/src/s3torchconnector/s3map_dataset.py
@@ -7,7 +7,7 @@
 import torch.utils.data
 
 from s3torchconnector._s3bucket_key_data import S3BucketKeyData
-from ._s3client import S3Client
+from ._s3client import S3Client, S3ClientConfig
 from . import S3Reader
 
 from ._s3dataset_common import (
@@ -32,11 +32,13 @@ def __init__(
         get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
         endpoint: Optional[str] = None,
         transform: Callable[[S3Reader], Any] = identity,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         self._get_dataset_objects = get_dataset_objects
         self._transform = transform
         self._region = region
         self._endpoint = endpoint
+        self._s3client_config = s3client_config
         self._client = None
         self._bucket_key_pairs: Optional[List[S3BucketKeyData]] = None
 
@@ -63,6 +65,7 @@ def from_objects(
         region: str,
         endpoint: Optional[str] = None,
         transform: Callable[[S3Reader], Any] = identity,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         """Returns an instance of S3MapDataset using the S3 URI(s) provided.
 
@@ -71,6 +74,7 @@ def from_objects(
             region(str): AWS region of the S3 bucket where the objects are stored.
             endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
             transform: Optional callable which is used to transform an S3Reader into the desired type.
+            s3client_config: Optional S3ClientConfig with parameters for the S3 client.
 
         Returns:
             S3MapDataset: A Map-Style dataset created from S3 objects.
@@ -84,6 +88,7 @@ def from_objects(
             partial(get_objects_from_uris, object_uris),
             endpoint,
             transform=transform,
+            s3client_config=s3client_config,
         )
 
     @classmethod
@@ -94,6 +99,7 @@ def from_prefix(
         region: str,
         endpoint: Optional[str] = None,
         transform: Callable[[S3Reader], Any] = identity,
+        s3client_config: Optional[S3ClientConfig] = None,
     ):
         """Returns an instance of S3MapDataset using the S3 URI provided.
 
@@ -102,6 +108,7 @@ def from_prefix(
             region(str): AWS region of the S3 bucket where the objects are stored.
             endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
             transform: Optional callable which is used to transform an S3Reader into the desired type.
+            s3client_config: Optional S3ClientConfig with parameters for the S3 client.
 
         Returns:
             S3MapDataset: A Map-Style dataset created from S3 objects.
@@ -115,11 +122,16 @@ def from_prefix(
             partial(get_objects_from_prefix, s3_uri),
             endpoint,
             transform=transform,
+            s3client_config=s3client_config,
        )
 
     def _get_client(self):
         if self._client is None:
-            self._client = S3Client(self.region, self.endpoint)
+            self._client = S3Client(
+                self.region,
+                endpoint=self.endpoint,
+                s3client_config=self._s3client_config,
+            )
         return self._client
 
     def _get_object(self, i: int) -> S3Reader:
diff --git a/s3torchconnector/src/s3torchconnector/s3reader.py b/s3torchconnector/src/s3torchconnector/s3reader.py
index b7873270..b9deddc6 100644
--- a/s3torchconnector/src/s3torchconnector/s3reader.py
+++ b/s3torchconnector/src/s3torchconnector/s3reader.py
@@ -4,7 +4,7 @@
 import io
 from functools import cached_property
 from io import SEEK_CUR, SEEK_END, SEEK_SET
-from typing import Callable, Optional, Iterable, Iterator
+from typing import Callable, Optional, Iterator
 
 from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo, GetObjectStream
 
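
The map-style dataset takes the same keyword; a short sketch with hypothetical object URIs:

from s3torchconnector import S3MapDataset, S3ClientConfig

dataset = S3MapDataset.from_objects(
    ["s3://my-bucket/images/img0.jpg", "s3://my-bucket/images/img1.jpg"],
    region="us-east-1",
    s3client_config=S3ClientConfig(part_size=16 * 1024 * 1024),
)

# Map-style datasets support len() and random access by index.
assert len(dataset) == 2
first = dataset[0]
print(first.bucket, first.key, len(first.read()))
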
match="invalid configuration: part size must be at between 5MiB and 5GiB", + ): + s3_client = S3Client( + region=TEST_REGION, + s3client_config=S3ClientConfig(part_size=1), + ) + # The client is lazily initialized + assert s3_client._client.part_size is not None diff --git a/s3torchconnector/tst/unit/test_s3_client_config.py b/s3torchconnector/tst/unit/test_s3_client_config.py new file mode 100644 index 00000000..dfce2e43 --- /dev/null +++ b/s3torchconnector/tst/unit/test_s3_client_config.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# // SPDX-License-Identifier: BSD +from hypothesis import given +from hypothesis.strategies import integers, floats + +from s3torchconnector import S3ClientConfig + + +def test_default(): + config = S3ClientConfig() + assert config is not None + assert config.part_size == 8 * 1024 * 1024 + assert abs(config.throughput_target_gbps - 10.0) < 1e-9 + + +@given(part_size=integers(min_value=1, max_value=1e12)) +def test_part_size_setup(part_size: int): + config = S3ClientConfig(part_size=part_size) + assert config is not None + assert config.part_size == part_size + assert abs(config.throughput_target_gbps - 10.0) < 1e-9 + + +@given(throughput_target_gbps=floats(min_value=1.0, max_value=100.0)) +def test_throughput_target_gbps_setup(throughput_target_gbps: float): + config = S3ClientConfig(throughput_target_gbps=throughput_target_gbps) + assert config is not None + assert config.part_size == 8 * 1024 * 1024 + assert abs(config.throughput_target_gbps - throughput_target_gbps) < 1e-9 + +@given(part_size=integers(min_value=1, max_value=1e12), + throughput_target_gbps=floats(min_value=1.0, max_value=100.0)) +def test_custom_setup(part_size: int, throughput_target_gbps: float): + config = S3ClientConfig(part_size=part_size, throughput_target_gbps=throughput_target_gbps) + assert config is not None + assert config.part_size == part_size + assert abs(config.throughput_target_gbps - throughput_target_gbps) < 1e-9