Skip to content

Commit

Permalink
Allow usage of unsigned S3 client
Browse files Browse the repository at this point in the history
Update S3ClientConfig to pass in the configuration
for allowing unsigned requests.
  • Loading branch information
dnnanuti committed Mar 27, 2024
1 parent 06c2312 commit 3b33427
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
part_size=self.s3client_config.part_size,
user_agent_prefix=self.user_agent_prefix,
usigned=self.s3client_config.unsigned,
)

def add_object(self, key: str, data: bytes) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _client_builder(self) -> MountpointS3Client:
user_agent_prefix=self._user_agent_prefix,
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
part_size=self._s3client_config.part_size,
unsigned=self._s3client_config.unsigned,
)

def get_object(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ class S3ClientConfig:

throughput_target_gbps: float = 10.0
part_size: int = 8 * 1024 * 1024
unsigned: bool = False
17 changes: 16 additions & 1 deletion s3torchconnector/tst/e2e/test_e2e_s3datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data.datapipes.datapipe import MapDataPipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe

from s3torchconnector import S3IterableDataset, S3MapDataset
from s3torchconnector import S3IterableDataset, S3MapDataset, S3ClientConfig


def test_s3iterable_dataset_images_10_from_prefix(image_directory):
Expand Down Expand Up @@ -100,6 +100,21 @@ def test_dataset_unpickled_iterates(image_directory):
assert expected == actual


def test_unsigned_client():
s3_uri = "s3://s3torchconnector-demo/geonet/images/"
region = "us-east-1"
s3_dataset = S3MapDataset.from_prefix(
s3_uri=s3_uri,
region=region,
transform=lambda obj: obj.read(),
s3client_config=S3ClientConfig(unsigned=True),
)
s3_dataloader = _pytorch_dataloader(s3_dataset)
assert s3_dataloader is not None
assert isinstance(s3_dataloader.dataset, S3MapDataset)
assert len(s3_dataloader) >= 1296


def _compare_dataloaders(
local_dataloader: DataLoader, s3_dataloader: DataLoader, expected_batch_count: int
):
Expand Down
9 changes: 9 additions & 0 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
)
assert s3_client._client.part_size == part_size
assert s3_client._client.throughput_target_gbps == throughput_target_gbps
assert s3_client._client.unsigned is False


@pytest.mark.parametrize(
Expand All @@ -130,3 +131,11 @@ def test_s3_client_invalid_part_size_config(part_size: int):
)
# The client is lazily initialized
assert s3_client._client.part_size == part_size


def test_unsigned_s3_client():
s3_client = S3Client(
region=TEST_REGION,
s3client_config=S3ClientConfig(unsigned=True),
)
assert s3_client._client.unsigned is True
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MountpointS3Client:
region: str
part_size: int
profile: Optional[str]
no_sign_request: bool
unsigned: bool
user_agent_prefix: str
endpoint: str

Expand All @@ -21,7 +21,7 @@ class MountpointS3Client:
throughput_target_gbps: float = 10.0,
part_size: int = 8 * 1024 * 1024,
profile: Optional[str] = None,
no_sign_request: bool = False,
unsigned: bool = False,
endpoint: Optional[str] = None,
): ...
def get_object(self, bucket: str, key: str) -> GetObjectStream: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_put_object_with_storage_class():
# TODO: Add hypothesis setup after aligning on limits
def test_mountpoint_client_pickles():
expected_profile = None
expected_no_sign_request = False
expected_unsigned = False
expected_region = REGION
expected_part_size = 5 * 2**20
expected_throughput_target_gbps = 3.5
Expand All @@ -254,7 +254,7 @@ def test_mountpoint_client_pickles():
part_size=expected_part_size,
throughput_target_gbps=expected_throughput_target_gbps,
profile=expected_profile,
no_sign_request=expected_no_sign_request,
unsigned=expected_unsigned,
)
dumped = pickle.dumps(client)
loaded = pickle.loads(dumped)
Expand All @@ -271,7 +271,7 @@ def test_mountpoint_client_pickles():
== expected_throughput_target_gbps
)
assert client.profile == loaded.profile == expected_profile
assert client.no_sign_request == loaded.no_sign_request == expected_no_sign_request
assert client.unsigned == loaded.unsigned == expected_unsigned


@pytest.mark.parametrize(
Expand Down
10 changes: 7 additions & 3 deletions s3torchconnectorclient/rust/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,21 @@ pub struct PyMockClient {
pub(crate) part_size: usize,
#[pyo3(get)]
pub(crate) user_agent_prefix: String,
#[pyo3(get)]
pub(crate) unsigned: bool,
}

#[pymethods]
impl PyMockClient {
#[new]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string(), unsigned=false))]
pub fn new(
region: String,
bucket: String,
throughput_target_gbps: f64,
part_size: usize,
user_agent_prefix: String,
unsigned: bool,
) -> PyMockClient {
let unordered_list_seed: Option<u64> = None;
let config = MockClientConfig { bucket, part_size, unordered_list_seed };
Expand All @@ -48,7 +51,8 @@ impl PyMockClient {
region,
throughput_target_gbps,
part_size,
user_agent_prefix
user_agent_prefix,
unsigned
}
}

Expand All @@ -59,7 +63,7 @@ impl PyMockClient {
self.throughput_target_gbps,
self.part_size,
None,
false,
self.unsigned,
self.mock_client.clone(),
None,
)
Expand Down
21 changes: 11 additions & 10 deletions s3torchconnectorclient/rust/src/mountpoint_s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct MountpointS3Client {
#[pyo3(get)]
profile: Option<String>,
#[pyo3(get)]
no_sign_request: bool,
unsigned: bool,
#[pyo3(get)]
user_agent_prefix: String,
#[pyo3(get)]
Expand All @@ -53,14 +53,14 @@ pub struct MountpointS3Client {
#[pymethods]
impl MountpointS3Client {
#[new]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, no_sign_request=false, endpoint=None))]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None))]
pub fn new_s3_client(
region: String,
user_agent_prefix: String,
throughput_target_gbps: f64,
part_size: usize,
profile: Option<String>,
no_sign_request: bool,
unsigned: bool,
endpoint: Option<String>,
) -> PyResult<Self> {
// TODO: Mountpoint has logic for guessing based on instance type. It may be worth having
Expand All @@ -72,7 +72,7 @@ impl MountpointS3Client {
} else {
EndpointConfig::new(&region).endpoint(Uri::new_from_str(&Allocator::default(), endpoint_str).unwrap())
};
let auth_config = auth_config(profile.as_deref(), no_sign_request);
let auth_config = auth_config(profile.as_deref(), unsigned);

let user_agent_suffix =
&format!("{}/{}", build_info::PACKAGE_NAME, build_info::FULL_VERSION);
Expand All @@ -96,7 +96,7 @@ impl MountpointS3Client {
throughput_target_gbps,
part_size,
profile,
no_sign_request,
unsigned,
crt_client,
endpoint,
))
Expand Down Expand Up @@ -154,7 +154,7 @@ impl MountpointS3Client {
slf.throughput_target_gbps.to_object(py),
slf.part_size.to_object(py),
slf.profile.to_object(py),
slf.no_sign_request.to_object(py),
slf.unsigned.to_object(py),
slf.endpoint.to_object(py),
];
Ok(PyTuple::new(py, state))
Expand All @@ -169,7 +169,8 @@ impl MountpointS3Client {
throughput_target_gbps: f64,
part_size: usize,
profile: Option<String>,
no_sign_request: bool,
// no_sign_request on mountpoint-s3-client
unsigned: bool,
client: Arc<Client>,
endpoint: Option<String>,
) -> Self
Expand All @@ -183,7 +184,7 @@ impl MountpointS3Client {
part_size,
region,
profile,
no_sign_request,
unsigned,
client: Arc::new(MountpointS3ClientInnerImpl::new(client)),
user_agent_prefix,
endpoint,
Expand All @@ -192,8 +193,8 @@ impl MountpointS3Client {
}
}

fn auth_config(profile: Option<&str>, no_sign_request: bool) -> S3ClientAuthConfig {
if no_sign_request {
fn auth_config(profile: Option<&str>, unsigned: bool) -> S3ClientAuthConfig {
if unsigned {
S3ClientAuthConfig::NoSigning
} else if let Some(profile_name) = profile {
S3ClientAuthConfig::Profile(profile_name.to_string())
Expand Down

0 comments on commit 3b33427

Please sign in to comment.