Skip to content

Commit

Permalink
Update comments and address PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
dnnanuti committed Mar 21, 2024
1 parent c85a0b9 commit c1b809d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Unreleased

### New features
* Expose `throughput_target_gbps` and `part_size` configurations of the inner S3 client.
* Expose a new class, S3ClientConfig, with `throughput_target_gbps` and `part_size` parameters of the inner S3 client.

## v1.2.1 (March 14, 2024)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
@dataclass(frozen=True)
class S3ClientConfig:
"""A dataclass exposing configurable parameters for the S3 client.
Returns a config wrapper object.
Args:
throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps) that we are trying to reach.
You can also use get_recommended_throughput_target_gbps() to get recommended value for your system.
10.0 Gbps by default (may change in future).
part_size(int): Size, in bytes, of parts that files will be downloaded or uploaded in.
Note: for saving checkpoints, the inner client will adjust the part size to meet the service limits.
Expand Down
18 changes: 12 additions & 6 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import pytest

from hypothesis import given
from hypothesis import given, example
from hypothesis.strategies import lists, text, integers, floats
from unittest.mock import MagicMock

Expand All @@ -18,6 +18,10 @@
TEST_REGION = "us-east-1"
S3_URI = f"s3://{TEST_BUCKET}/{TEST_KEY}"

KiB = 1 << 10
MiB = 1 << 20
GiB = 1 << 30


@pytest.fixture
def s3_client() -> S3Client:
Expand Down Expand Up @@ -87,12 +91,13 @@ def test_user_agent_always_starts_with_package_version(comments):


@given(
part_size=integers(min_value=5 * 1024, max_value=5 * 1024 * 1024),
part_size=integers(min_value=5 * MiB, max_value=5 * GiB),
throughput_target_gbps=floats(min_value=10.0, max_value=100.0),
)
@example(part_size=5 * MiB, throughput_target_gbps=10.0)
@example(part_size=5 * GiB, throughput_target_gbps=15.0)
def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
# Part size must have values between 5MiB and 5GiB
part_size = part_size * 1024
s3_client = S3Client(
region=TEST_REGION,
s3client_config=S3ClientConfig(
Expand All @@ -101,17 +106,18 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
),
)
assert s3_client._client.part_size == part_size
assert abs(s3_client._client.throughput_target_gbps - throughput_target_gbps) < 1e-9
assert s3_client._client.throughput_target_gbps == throughput_target_gbps


def test_s3_client_invalid_part_size_config():
@pytest.mark.parametrize("part_size", [1, 2 * KiB, 6 * GiB])
def test_s3_client_invalid_part_size_config(part_size: int):
with pytest.raises(
S3Exception,
match="invalid configuration: part size must be at between 5MiB and 5GiB",
):
s3_client = S3Client(
region=TEST_REGION,
s3client_config=S3ClientConfig(part_size=1),
s3client_config=S3ClientConfig(part_size=part_size),
)
# The client is lazily initialized
assert s3_client._client.part_size is not None
23 changes: 11 additions & 12 deletions s3torchconnector/tst/unit/test_s3_client_config.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from hypothesis import given
from hypothesis import given, example
from hypothesis.strategies import integers, floats

from s3torchconnector import S3ClientConfig
from .test_s3_client import MiB, GiB


def test_default():
config = S3ClientConfig()
assert config is not None
assert config.part_size == 8 * 1024 * 1024
assert abs(config.throughput_target_gbps - 10.0) < 1e-9
assert config.part_size == 8 * MiB
assert config.throughput_target_gbps == 10.0


@given(part_size=integers(min_value=1, max_value=1e12))
@given(part_size=integers(min_value=5 * MiB, max_value=5 * GiB))
def test_part_size_setup(part_size: int):
config = S3ClientConfig(part_size=part_size)
assert config is not None
assert config.part_size == part_size
assert abs(config.throughput_target_gbps - 10.0) < 1e-9
assert config.throughput_target_gbps == 10.0


@given(throughput_target_gbps=floats(min_value=1.0, max_value=100.0))
def test_throughput_target_gbps_setup(throughput_target_gbps: float):
config = S3ClientConfig(throughput_target_gbps=throughput_target_gbps)
assert config is not None
assert config.part_size == 8 * 1024 * 1024
assert abs(config.throughput_target_gbps - throughput_target_gbps) < 1e-9
assert config.throughput_target_gbps == throughput_target_gbps


@given(
part_size=integers(min_value=1, max_value=1e12),
part_size=integers(min_value=5 * MiB, max_value=5 * GiB),
throughput_target_gbps=floats(min_value=1.0, max_value=100.0),
)
@example(part_size=5 * MiB, throughput_target_gbps=10.0)
@example(part_size=5 * GiB, throughput_target_gbps=15.0)
def test_custom_setup(part_size: int, throughput_target_gbps: float):
config = S3ClientConfig(
part_size=part_size, throughput_target_gbps=throughput_target_gbps
)
assert config is not None
assert config.part_size == part_size
assert abs(config.throughput_target_gbps - throughput_target_gbps) < 1e-9
assert config.throughput_target_gbps == throughput_target_gbps

0 comments on commit c1b809d

Please sign in to comment.