diff --git a/CHANGELOG.md b/CHANGELOG.md index b37242a4..1b574190 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ ## Unreleased ### New features -* Expose `throughput_target_gbps` and `part_size` configurations of the inner S3 client. +* Expose a new class, S3ClientConfig, with `throughput_target_gbps` and `part_size` parameters of the inner S3 client. ## v1.2.1 (March 14, 2024) diff --git a/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py b/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py index fa712ef1..723ac63e 100644 --- a/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py +++ b/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py @@ -6,11 +6,9 @@ @dataclass(frozen=True) class S3ClientConfig: """A dataclass exposing configurable parameters for the S3 client. - Returns a config wrapper object. Args: throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps) that we are trying to reach. - You can also use get_recommended_throughput_target_gbps() to get recommended value for your system. 10.0 Gbps by default (may change in future). part_size(int): Size, in bytes, of parts that files will be downloaded or uploaded in. Note: for saving checkpoints, the inner client will adjust the part size to meet the service limits. diff --git a/s3torchconnector/tst/unit/test_s3_client.py b/s3torchconnector/tst/unit/test_s3_client.py index 504d3759..446eb8be 100644 --- a/s3torchconnector/tst/unit/test_s3_client.py +++ b/s3torchconnector/tst/unit/test_s3_client.py @@ -3,7 +3,7 @@ import logging import pytest -from hypothesis import given +from hypothesis import given, example from hypothesis.strategies import lists, text, integers, floats from unittest.mock import MagicMock @@ -18,6 +18,10 @@ TEST_REGION = "us-east-1" S3_URI = f"s3://{TEST_BUCKET}/{TEST_KEY}" +KiB = 1 << 10 +MiB = 1 << 20 +GiB = 1 << 30 + @pytest.fixture def s3_client() -> S3Client: @@ -87,12 +91,13 @@ def test_user_agent_always_starts_with_package_version(comments): @given( - part_size=integers(min_value=5 * 1024, max_value=5 * 1024 * 1024), + part_size=integers(min_value=5 * MiB, max_value=5 * GiB), throughput_target_gbps=floats(min_value=10.0, max_value=100.0), ) +@example(part_size=5 * MiB, throughput_target_gbps=10.0) +@example(part_size=5 * GiB, throughput_target_gbps=15.0) def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float): # Part size must have values between 5MiB and 5GiB - part_size = part_size * 1024 s3_client = S3Client( region=TEST_REGION, s3client_config=S3ClientConfig( @@ -101,17 +106,18 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float): ), ) assert s3_client._client.part_size == part_size - assert abs(s3_client._client.throughput_target_gbps - throughput_target_gbps) < 1e-9 + assert s3_client._client.throughput_target_gbps == throughput_target_gbps -def test_s3_client_invalid_part_size_config(): +@pytest.mark.parametrize("part_size", [1, 2 * KiB, 6 * GiB]) +def test_s3_client_invalid_part_size_config(part_size: int): with pytest.raises( S3Exception, match="invalid configuration: part size must be at between 5MiB and 5GiB", ): s3_client = S3Client( region=TEST_REGION, - s3client_config=S3ClientConfig(part_size=1), + s3client_config=S3ClientConfig(part_size=part_size), ) # The client is lazily initialized assert s3_client._client.part_size is not None diff --git a/s3torchconnector/tst/unit/test_s3_client_config.py b/s3torchconnector/tst/unit/test_s3_client_config.py index 6c2ff71e..037628d3 100644 --- a/s3torchconnector/tst/unit/test_s3_client_config.py +++ b/s3torchconnector/tst/unit/test_s3_client_config.py @@ -1,42 +1,41 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD -from hypothesis import given +from hypothesis import given, example from hypothesis.strategies import integers, floats from s3torchconnector import S3ClientConfig +from .test_s3_client import MiB, GiB def test_default(): config = S3ClientConfig() - assert config is not None - assert config.part_size == 8 * 1024 * 1024 - assert abs(config.throughput_target_gbps - 10.0) < 1e-9 + assert config.part_size == 8 * MiB + assert config.throughput_target_gbps == 10.0 -@given(part_size=integers(min_value=1, max_value=1e12)) +@given(part_size=integers(min_value=5 * MiB, max_value=5 * GiB)) def test_part_size_setup(part_size: int): config = S3ClientConfig(part_size=part_size) - assert config is not None assert config.part_size == part_size - assert abs(config.throughput_target_gbps - 10.0) < 1e-9 + assert config.throughput_target_gbps == 10.0 @given(throughput_target_gbps=floats(min_value=1.0, max_value=100.0)) def test_throughput_target_gbps_setup(throughput_target_gbps: float): config = S3ClientConfig(throughput_target_gbps=throughput_target_gbps) - assert config is not None assert config.part_size == 8 * 1024 * 1024 - assert abs(config.throughput_target_gbps - throughput_target_gbps) < 1e-9 + assert config.throughput_target_gbps == throughput_target_gbps @given( - part_size=integers(min_value=1, max_value=1e12), + part_size=integers(min_value=5 * MiB, max_value=5 * GiB), throughput_target_gbps=floats(min_value=1.0, max_value=100.0), ) +@example(part_size=5 * MiB, throughput_target_gbps=10.0) +@example(part_size=5 * GiB, throughput_target_gbps=15.0) def test_custom_setup(part_size: int, throughput_target_gbps: float): config = S3ClientConfig( part_size=part_size, throughput_target_gbps=throughput_target_gbps ) - assert config is not None assert config.part_size == part_size - assert abs(config.throughput_target_gbps - throughput_target_gbps) < 1e-9 + assert config.throughput_target_gbps == throughput_target_gbps