Skip to content

Commit

Permalink
Allow usage of unsigned S3 client
Browse files Browse the repository at this point in the history
Update S3ClientConfig to pass in the configuration
for allowing unsigned requests.
  • Loading branch information
dnnanuti committed Mar 27, 2024
1 parent 06c2312 commit a9ac828
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 19 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## Unreleased

### New features
* Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.


## v1.2.2 (March 22, 2024)

### New features
Expand Down
6 changes: 6 additions & 0 deletions doc/DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ Using S3ClientConfig you can set up the following parameters for the underlying
(max number of parts per upload is 10,000, minimum upload part size is 5 MiB).
Part size must have **values between 5MiB and 5GiB.** Is set by default to **8MiB** (may change in future).

* `unsigned(bool)`: Allows the usage of unsigned clients when accessing public datasets or when other mechanisms are
in place to grant access.

For example this can be passed in like:
```py
from s3torchconnector import S3MapDataset, S3ClientConfig
Expand All @@ -165,6 +168,9 @@ s3_map_dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION, s3client_c
s3_checkpoint = S3Checkpoint(region=REGION, s3client_config=config)
# Works similarly for Lightning checkpoints.
s3_lightning_checkpoint = S3LightningCheckpoint(region=REGION, s3client_config=config)

# Use an unsigned S3 client
s3_client = S3Client(region=REGION, s3client_config=S3ClientConfig(unsigned=True))
```

**When modifying the default values for these flags, we strongly recommend to run benchmarking to ensure you are not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
part_size=self.s3client_config.part_size,
user_agent_prefix=self.user_agent_prefix,
usigned=self.s3client_config.unsigned,
)

def add_object(self, key: str, data: bytes) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _client_builder(self) -> MountpointS3Client:
user_agent_prefix=self._user_agent_prefix,
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
part_size=self._s3client_config.part_size,
unsigned=self._s3client_config.unsigned,
)

def get_object(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ class S3ClientConfig:

throughput_target_gbps: float = 10.0
part_size: int = 8 * 1024 * 1024
unsigned: bool = False
17 changes: 16 additions & 1 deletion s3torchconnector/tst/e2e/test_e2e_s3datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data.datapipes.datapipe import MapDataPipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe

from s3torchconnector import S3IterableDataset, S3MapDataset
from s3torchconnector import S3IterableDataset, S3MapDataset, S3ClientConfig


def test_s3iterable_dataset_images_10_from_prefix(image_directory):
Expand Down Expand Up @@ -100,6 +100,21 @@ def test_dataset_unpickled_iterates(image_directory):
assert expected == actual


def test_unsigned_client():
s3_uri = "s3://s3torchconnector-demo/geonet/images/"
region = "us-east-1"
s3_dataset = S3MapDataset.from_prefix(
s3_uri=s3_uri,
region=region,
transform=lambda obj: obj.read(),
s3client_config=S3ClientConfig(unsigned=True),
)
s3_dataloader = _pytorch_dataloader(s3_dataset)
assert s3_dataloader is not None
assert isinstance(s3_dataloader.dataset, S3MapDataset)
assert len(s3_dataloader) >= 1296


def _compare_dataloaders(
local_dataloader: DataLoader, s3_dataloader: DataLoader, expected_batch_count: int
):
Expand Down
9 changes: 9 additions & 0 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
)
assert s3_client._client.part_size == part_size
assert s3_client._client.throughput_target_gbps == throughput_target_gbps
assert s3_client._client.unsigned is False


@pytest.mark.parametrize(
Expand All @@ -130,3 +131,11 @@ def test_s3_client_invalid_part_size_config(part_size: int):
)
# The client is lazily initialized
assert s3_client._client.part_size == part_size


def test_unsigned_s3_client():
s3_client = S3Client(
region=TEST_REGION,
s3client_config=S3ClientConfig(unsigned=True),
)
assert s3_client._client.unsigned is True
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MountpointS3Client:
region: str
part_size: int
profile: Optional[str]
no_sign_request: bool
unsigned: bool
user_agent_prefix: str
endpoint: str

Expand All @@ -21,7 +21,7 @@ class MountpointS3Client:
throughput_target_gbps: float = 10.0,
part_size: int = 8 * 1024 * 1024,
profile: Optional[str] = None,
no_sign_request: bool = False,
unsigned: bool = False,
endpoint: Optional[str] = None,
): ...
def get_object(self, bucket: str, key: str) -> GetObjectStream: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_put_object_with_storage_class():
# TODO: Add hypothesis setup after aligning on limits
def test_mountpoint_client_pickles():
expected_profile = None
expected_no_sign_request = False
expected_unsigned = False
expected_region = REGION
expected_part_size = 5 * 2**20
expected_throughput_target_gbps = 3.5
Expand All @@ -254,7 +254,7 @@ def test_mountpoint_client_pickles():
part_size=expected_part_size,
throughput_target_gbps=expected_throughput_target_gbps,
profile=expected_profile,
no_sign_request=expected_no_sign_request,
unsigned=expected_unsigned,
)
dumped = pickle.dumps(client)
loaded = pickle.loads(dumped)
Expand All @@ -271,7 +271,7 @@ def test_mountpoint_client_pickles():
== expected_throughput_target_gbps
)
assert client.profile == loaded.profile == expected_profile
assert client.no_sign_request == loaded.no_sign_request == expected_no_sign_request
assert client.unsigned == loaded.unsigned == expected_unsigned


@pytest.mark.parametrize(
Expand Down
10 changes: 7 additions & 3 deletions s3torchconnectorclient/rust/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,21 @@ pub struct PyMockClient {
pub(crate) part_size: usize,
#[pyo3(get)]
pub(crate) user_agent_prefix: String,
#[pyo3(get)]
pub(crate) unsigned: bool,
}

#[pymethods]
impl PyMockClient {
#[new]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string(), unsigned=false))]
pub fn new(
region: String,
bucket: String,
throughput_target_gbps: f64,
part_size: usize,
user_agent_prefix: String,
unsigned: bool,
) -> PyMockClient {
let unordered_list_seed: Option<u64> = None;
let config = MockClientConfig { bucket, part_size, unordered_list_seed };
Expand All @@ -48,7 +51,8 @@ impl PyMockClient {
region,
throughput_target_gbps,
part_size,
user_agent_prefix
user_agent_prefix,
unsigned
}
}

Expand All @@ -59,7 +63,7 @@ impl PyMockClient {
self.throughput_target_gbps,
self.part_size,
None,
false,
self.unsigned,
self.mock_client.clone(),
None,
)
Expand Down
21 changes: 11 additions & 10 deletions s3torchconnectorclient/rust/src/mountpoint_s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct MountpointS3Client {
#[pyo3(get)]
profile: Option<String>,
#[pyo3(get)]
no_sign_request: bool,
unsigned: bool,
#[pyo3(get)]
user_agent_prefix: String,
#[pyo3(get)]
Expand All @@ -53,14 +53,14 @@ pub struct MountpointS3Client {
#[pymethods]
impl MountpointS3Client {
#[new]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, no_sign_request=false, endpoint=None))]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None))]
pub fn new_s3_client(
region: String,
user_agent_prefix: String,
throughput_target_gbps: f64,
part_size: usize,
profile: Option<String>,
no_sign_request: bool,
unsigned: bool,
endpoint: Option<String>,
) -> PyResult<Self> {
// TODO: Mountpoint has logic for guessing based on instance type. It may be worth having
Expand All @@ -72,7 +72,7 @@ impl MountpointS3Client {
} else {
EndpointConfig::new(&region).endpoint(Uri::new_from_str(&Allocator::default(), endpoint_str).unwrap())
};
let auth_config = auth_config(profile.as_deref(), no_sign_request);
let auth_config = auth_config(profile.as_deref(), unsigned);

let user_agent_suffix =
&format!("{}/{}", build_info::PACKAGE_NAME, build_info::FULL_VERSION);
Expand All @@ -96,7 +96,7 @@ impl MountpointS3Client {
throughput_target_gbps,
part_size,
profile,
no_sign_request,
unsigned,
crt_client,
endpoint,
))
Expand Down Expand Up @@ -154,7 +154,7 @@ impl MountpointS3Client {
slf.throughput_target_gbps.to_object(py),
slf.part_size.to_object(py),
slf.profile.to_object(py),
slf.no_sign_request.to_object(py),
slf.unsigned.to_object(py),
slf.endpoint.to_object(py),
];
Ok(PyTuple::new(py, state))
Expand All @@ -169,7 +169,8 @@ impl MountpointS3Client {
throughput_target_gbps: f64,
part_size: usize,
profile: Option<String>,
no_sign_request: bool,
// no_sign_request on mountpoint-s3-client
unsigned: bool,
client: Arc<Client>,
endpoint: Option<String>,
) -> Self
Expand All @@ -183,7 +184,7 @@ impl MountpointS3Client {
part_size,
region,
profile,
no_sign_request,
unsigned,
client: Arc::new(MountpointS3ClientInnerImpl::new(client)),
user_agent_prefix,
endpoint,
Expand All @@ -192,8 +193,8 @@ impl MountpointS3Client {
}
}

fn auth_config(profile: Option<&str>, no_sign_request: bool) -> S3ClientAuthConfig {
if no_sign_request {
fn auth_config(profile: Option<&str>, unsigned: bool) -> S3ClientAuthConfig {
if unsigned {
S3ClientAuthConfig::NoSigning
} else if let Some(profile_name) = profile {
S3ClientAuthConfig::Profile(profile_name.to_string())
Expand Down

0 comments on commit a9ac828

Please sign in to comment.