Skip to content

Commit

Permalink
Update tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
dnnanuti committed Mar 21, 2024
1 parent 69e95ae commit 5ee1e6b
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 3 deletions.
2 changes: 2 additions & 0 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def from_objects(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3_connector_client_config: Optional S3ConnectorClientConfig with parameters for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand Down Expand Up @@ -99,6 +100,7 @@ def from_prefix(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3_connector_client_config: Optional S3ConnectorClientConfig with parameters for S3 client.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
Expand Down
2 changes: 2 additions & 0 deletions s3torchconnector/src/s3torchconnector/s3map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def from_objects(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3_connector_client_config: Optional S3ConnectorClientConfig with parameters for S3 client.
Returns:
S3MapDataset: A Map-Style dataset created from S3 objects.
Expand Down Expand Up @@ -108,6 +109,7 @@ def from_prefix(
region(str): AWS region of the S3 bucket where the objects are stored.
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3_connector_client_config: Optional S3ConnectorClientConfig with parameters for S3 client.
Returns:
S3MapDataset: A Map-Style dataset created from S3 objects.
Expand Down
36 changes: 35 additions & 1 deletion s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import pytest

from hypothesis import given
from hypothesis.strategies import lists, text
from hypothesis.strategies import lists, text, integers, floats
from unittest.mock import MagicMock

from s3torchconnectorclient._mountpoint_s3_client import S3Exception

from s3torchconnector._user_agent import UserAgent
from s3torchconnector._version import __version__
from s3torchconnector._s3client import S3Client, MockS3Client
from s3torchconnectorclient import S3ConnectorClientConfig

TEST_BUCKET = "test-bucket"
TEST_KEY = "test-key"
Expand Down Expand Up @@ -82,3 +85,34 @@ def test_user_agent_always_starts_with_package_version(comments):
if comments_str:
assert comments_str in s3_client.user_agent_prefix
assert comments_str in s3_client._client.user_agent_prefix


@given(
    part_size=integers(min_value=5 * 1024, max_value=5 * 1024 * 1024),
    throughput_target_gbps=floats(min_value=10.0, max_value=100.0),
)
def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
    """Custom config values are forwarded verbatim to the inner client."""
    # Scale the generated value from KiB to bytes so part_size always lands
    # in the 5 MiB - 5 GiB range the client accepts.
    part_size = part_size * 1024
    config = S3ConnectorClientConfig(
        part_size=part_size,
        throughput_target_gbps=throughput_target_gbps,
    )
    s3_client = S3Client(region=TEST_REGION, s3_connector_client_config=config)
    assert s3_client._client.part_size == part_size
    assert abs(s3_client._client.throughput_target_gbps - throughput_target_gbps) < 1e-9


def test_s3_client_invalid_part_size_config():
    """An out-of-range part size surfaces as an S3Exception from the client."""
    # Building the config itself never validates; the client does.
    invalid_config = S3ConnectorClientConfig(part_size=1)
    with pytest.raises(
        S3Exception,
        match="invalid configuration: part size must be at between 5MiB and 5GiB",
    ):
        s3_client = S3Client(region=TEST_REGION, s3_connector_client_config=invalid_config)
        # The inner client is lazily initialized, so the invalid part size
        # only surfaces when _client is first accessed.
        assert s3_client._client.part_size is not None
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ._logger_patch import TRACE as LOG_TRACE
from ._logger_patch import _install_trace_logging
from ._mountpoint_s3_client import S3Exception, __version__
from .s3_connector_client_config import S3ConnectorClientConfig
from s3torchconnectorclient.s3_connector_client_config import S3ConnectorClientConfig

_install_trace_logging()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from dataclasses import dataclass


@dataclass
class S3ConnectorClientConfig:
    """A dataclass exposing configurable parameters for the S3 client.

    Attributes:
        throughput_target_gbps (float): Throughput target in Gigabits per second (Gbps)
            that we are trying to reach. You can also use
            get_recommended_throughput_target_gbps() to get a recommended value for your
            system. 10.0 Gbps by default (may change in future).
        part_size (int): Size, in bytes, of parts that files will be downloaded or
            uploaded in. Note: for saving checkpoints, the inner client will adjust the
            part size to meet the service limits (max number of parts per upload is
            10,000, minimum upload part size is 5 MiB). 8 MiB by default (may change in
            future).
    """

    # Declared as real dataclass fields (rather than a hand-written __init__) so
    # @dataclass generates a correct __init__, __repr__ and __eq__. Previously the
    # class mixed @dataclass with a manual __init__ and declared no fields, so the
    # generated __eq__/__repr__ ignored both attributes and every instance compared
    # equal. Parameter names, order and defaults are unchanged, so callers are
    # unaffected.
    throughput_target_gbps: float = 10.0
    part_size: int = 8 * 1024 * 1024
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from hypothesis import given
from hypothesis.strategies import integers, floats

from s3torchconnectorclient import S3ConnectorClientConfig


def test_default():
    """A freshly constructed config carries the documented defaults."""
    default_config = S3ConnectorClientConfig()
    assert default_config is not None
    # 8 MiB part size and 10 Gbps throughput are the documented defaults.
    assert default_config.part_size == 8 * 1024 * 1024
    assert abs(default_config.throughput_target_gbps - 10.0) < 1e-9


# max_value was previously the float literal 1e12; hypothesis' integers() strategy
# requires exact int bounds and raises InvalidArgument for floats, which would make
# this test error at collection time. Use the equivalent int literal instead.
@given(part_size=integers(min_value=1, max_value=10**12))
def test_part_size_setup(part_size: int):
    """part_size passed to the config is stored unchanged; throughput keeps its default."""
    config = S3ConnectorClientConfig(part_size=part_size)
    assert config is not None
    assert config.part_size == part_size
    assert abs(config.throughput_target_gbps - 10.0) < 1e-9


@given(throughput_target_gbps_setup=floats(min_value=1.0, max_value=100.0))
def test_throughput_target_gbps_setup(throughput_target_gbps_setup: float):
    """throughput_target_gbps is stored as given; part_size keeps its default."""
    config = S3ConnectorClientConfig(throughput_target_gbps=throughput_target_gbps_setup)
    assert config is not None
    # Only throughput was customized, so part_size must remain the 8 MiB default.
    assert config.part_size == 8 * 1024 * 1024
    assert abs(config.throughput_target_gbps - throughput_target_gbps_setup) < 1e-9

0 comments on commit 5ee1e6b

Please sign in to comment.