From 3b3342715d7618270bcc339a391b571f7c389fd1 Mon Sep 17 00:00:00 2001 From: dnnanuti Date: Wed, 27 Mar 2024 17:35:05 +0000 Subject: [PATCH] Allow usage of unsigned S3 client Update S3ClientConfig to pass in the configuration for allowing unsigned requests. --- .../_s3client/_mock_s3client.py | 1 + .../s3torchconnector/_s3client/_s3client.py | 1 + .../_s3client/s3client_config.py | 1 + .../tst/e2e/test_e2e_s3datasets.py | 17 ++++++++++++++- s3torchconnector/tst/unit/test_s3_client.py | 9 ++++++++ .../_mountpoint_s3_client.pyi | 4 ++-- .../tst/unit/test_mountpoint_s3_client.py | 6 +++--- .../rust/src/mock_client.rs | 10 ++++++--- .../rust/src/mountpoint_s3_client.rs | 21 ++++++++++--------- 9 files changed, 51 insertions(+), 19 deletions(-) diff --git a/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py b/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py index edd92909..e5e1c16b 100644 --- a/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py +++ b/s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py @@ -37,6 +37,7 @@ def __init__( throughput_target_gbps=self.s3client_config.throughput_target_gbps, part_size=self.s3client_config.part_size, user_agent_prefix=self.user_agent_prefix, + usigned=self.s3client_config.unsigned, ) def add_object(self, key: str, data: bytes) -> None: diff --git a/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py b/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py index d5f4dc8c..3466bafb 100644 --- a/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py +++ b/s3torchconnector/src/s3torchconnector/_s3client/_s3client.py @@ -77,6 +77,7 @@ def _client_builder(self) -> MountpointS3Client: user_agent_prefix=self._user_agent_prefix, throughput_target_gbps=self._s3client_config.throughput_target_gbps, part_size=self._s3client_config.part_size, + unsigned=self._s3client_config.unsigned, ) def get_object( diff --git a/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py b/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py index b12be1c6..d9f519b2 100644 --- a/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py +++ b/s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py @@ -19,3 +19,4 @@ class S3ClientConfig: throughput_target_gbps: float = 10.0 part_size: int = 8 * 1024 * 1024 + unsigned: bool = False diff --git a/s3torchconnector/tst/e2e/test_e2e_s3datasets.py b/s3torchconnector/tst/e2e/test_e2e_s3datasets.py index 61976fea..f0a78c85 100644 --- a/s3torchconnector/tst/e2e/test_e2e_s3datasets.py +++ b/s3torchconnector/tst/e2e/test_e2e_s3datasets.py @@ -9,7 +9,7 @@ from torch.utils.data.datapipes.datapipe import MapDataPipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe -from s3torchconnector import S3IterableDataset, S3MapDataset +from s3torchconnector import S3IterableDataset, S3MapDataset, S3ClientConfig def test_s3iterable_dataset_images_10_from_prefix(image_directory): @@ -100,6 +100,21 @@ def test_dataset_unpickled_iterates(image_directory): assert expected == actual +def test_unsigned_client(): + s3_uri = "s3://s3torchconnector-demo/geonet/images/" + region = "us-east-1" + s3_dataset = S3MapDataset.from_prefix( + s3_uri=s3_uri, + region=region, + transform=lambda obj: obj.read(), + s3client_config=S3ClientConfig(unsigned=True), + ) + s3_dataloader = _pytorch_dataloader(s3_dataset) + assert s3_dataloader is not None + assert isinstance(s3_dataloader.dataset, S3MapDataset) + assert len(s3_dataloader) >= 1296 + + def _compare_dataloaders( local_dataloader: DataLoader, s3_dataloader: DataLoader, expected_batch_count: int ): diff --git a/s3torchconnector/tst/unit/test_s3_client.py b/s3torchconnector/tst/unit/test_s3_client.py index 9712b67c..76a78b98 100644 --- a/s3torchconnector/tst/unit/test_s3_client.py +++ b/s3torchconnector/tst/unit/test_s3_client.py @@ -107,6 +107,7 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float): ) assert s3_client._client.part_size == part_size assert s3_client._client.throughput_target_gbps == throughput_target_gbps + assert s3_client._client.unsigned is False @pytest.mark.parametrize( @@ -130,3 +131,11 @@ def test_s3_client_invalid_part_size_config(part_size: int): ) # The client is lazily initialized assert s3_client._client.part_size == part_size + + +def test_unsigned_s3_client(): + s3_client = S3Client( + region=TEST_REGION, + s3client_config=S3ClientConfig(unsigned=True), + ) + assert s3_client._client.unsigned is True diff --git a/s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi b/s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi index c8d1c83d..7ed425ac 100644 --- a/s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi +++ b/s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi @@ -10,7 +10,7 @@ class MountpointS3Client: region: str part_size: int profile: Optional[str] - no_sign_request: bool + unsigned: bool user_agent_prefix: str endpoint: str @@ -21,7 +21,7 @@ class MountpointS3Client: throughput_target_gbps: float = 10.0, part_size: int = 8 * 1024 * 1024, profile: Optional[str] = None, - no_sign_request: bool = False, + unsigned: bool = False, endpoint: Optional[str] = None, ): ... def get_object(self, bucket: str, key: str) -> GetObjectStream: ... diff --git a/s3torchconnectorclient/python/tst/unit/test_mountpoint_s3_client.py b/s3torchconnectorclient/python/tst/unit/test_mountpoint_s3_client.py index 450b69dc..3a0a581e 100644 --- a/s3torchconnectorclient/python/tst/unit/test_mountpoint_s3_client.py +++ b/s3torchconnectorclient/python/tst/unit/test_mountpoint_s3_client.py @@ -243,7 +243,7 @@ def test_put_object_with_storage_class(): # TODO: Add hypothesis setup after aligning on limits def test_mountpoint_client_pickles(): expected_profile = None - expected_no_sign_request = False + expected_unsigned = False expected_region = REGION expected_part_size = 5 * 2**20 expected_throughput_target_gbps = 3.5 @@ -254,7 +254,7 @@ def test_mountpoint_client_pickles(): part_size=expected_part_size, throughput_target_gbps=expected_throughput_target_gbps, profile=expected_profile, - no_sign_request=expected_no_sign_request, + unsigned=expected_unsigned, ) dumped = pickle.dumps(client) loaded = pickle.loads(dumped) @@ -271,7 +271,7 @@ def test_mountpoint_client_pickles(): == expected_throughput_target_gbps ) assert client.profile == loaded.profile == expected_profile - assert client.no_sign_request == loaded.no_sign_request == expected_no_sign_request + assert client.unsigned == loaded.unsigned == expected_unsigned @pytest.mark.parametrize( diff --git a/s3torchconnectorclient/rust/src/mock_client.rs b/s3torchconnectorclient/rust/src/mock_client.rs index cdb2d434..0a13c252 100644 --- a/s3torchconnectorclient/rust/src/mock_client.rs +++ b/s3torchconnectorclient/rust/src/mock_client.rs @@ -26,18 +26,21 @@ pub struct PyMockClient { pub(crate) part_size: usize, #[pyo3(get)] pub(crate) user_agent_prefix: String, + #[pyo3(get)] + pub(crate) unsigned: bool, } #[pymethods] impl PyMockClient { #[new] - #[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))] + #[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string(), unsigned=false))] pub fn new( region: String, bucket: String, throughput_target_gbps: f64, part_size: usize, user_agent_prefix: String, + unsigned: bool, ) -> PyMockClient { let unordered_list_seed: Option = None; let config = MockClientConfig { bucket, part_size, unordered_list_seed }; @@ -48,7 +51,8 @@ impl PyMockClient { region, throughput_target_gbps, part_size, - user_agent_prefix + user_agent_prefix, + unsigned } } @@ -59,7 +63,7 @@ impl PyMockClient { self.throughput_target_gbps, self.part_size, None, - false, + self.unsigned, self.mock_client.clone(), None, ) diff --git a/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs b/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs index 4a8a8ac0..094b7bdc 100644 --- a/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs +++ b/s3torchconnectorclient/rust/src/mountpoint_s3_client.rs @@ -41,7 +41,7 @@ pub struct MountpointS3Client { #[pyo3(get)] profile: Option, #[pyo3(get)] - no_sign_request: bool, + unsigned: bool, #[pyo3(get)] user_agent_prefix: String, #[pyo3(get)] @@ -53,14 +53,14 @@ pub struct MountpointS3Client { #[pymethods] impl MountpointS3Client { #[new] - #[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, no_sign_request=false, endpoint=None))] + #[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None))] pub fn new_s3_client( region: String, user_agent_prefix: String, throughput_target_gbps: f64, part_size: usize, profile: Option, - no_sign_request: bool, + unsigned: bool, endpoint: Option, ) -> PyResult { // TODO: Mountpoint has logic for guessing based on instance type. It may be worth having @@ -72,7 +72,7 @@ impl MountpointS3Client { } else { EndpointConfig::new(®ion).endpoint(Uri::new_from_str(&Allocator::default(), endpoint_str).unwrap()) }; - let auth_config = auth_config(profile.as_deref(), no_sign_request); + let auth_config = auth_config(profile.as_deref(), unsigned); let user_agent_suffix = &format!("{}/{}", build_info::PACKAGE_NAME, build_info::FULL_VERSION); @@ -96,7 +96,7 @@ impl MountpointS3Client { throughput_target_gbps, part_size, profile, - no_sign_request, + unsigned, crt_client, endpoint, )) @@ -154,7 +154,7 @@ impl MountpointS3Client { slf.throughput_target_gbps.to_object(py), slf.part_size.to_object(py), slf.profile.to_object(py), - slf.no_sign_request.to_object(py), + slf.unsigned.to_object(py), slf.endpoint.to_object(py), ]; Ok(PyTuple::new(py, state)) @@ -169,7 +169,8 @@ impl MountpointS3Client { throughput_target_gbps: f64, part_size: usize, profile: Option, - no_sign_request: bool, + // no_sign_request on mountpoint-s3-client + unsigned: bool, client: Arc, endpoint: Option, ) -> Self @@ -183,7 +184,7 @@ impl MountpointS3Client { part_size, region, profile, - no_sign_request, + unsigned, client: Arc::new(MountpointS3ClientInnerImpl::new(client)), user_agent_prefix, endpoint, @@ -192,8 +193,8 @@ impl MountpointS3Client { } } -fn auth_config(profile: Option<&str>, no_sign_request: bool) -> S3ClientAuthConfig { - if no_sign_request { +fn auth_config(profile: Option<&str>, unsigned: bool) -> S3ClientAuthConfig { + if unsigned { S3ClientAuthConfig::NoSigning } else if let Some(profile_name) = profile { S3ClientAuthConfig::Profile(profile_name.to_string())