chore: add test assume role

JBOClara committed Jan 4, 2024
1 parent 8090889 commit 9f761c8

Showing 2 changed files with 116 additions and 149 deletions.
1 change: 0 additions & 1 deletion medusa/storage/s3_base_storage.py

@@ -229,7 +229,6 @@ def _consolidate_credentials(config) -> CensoredCredentials:
             region=session.get_config_variable('region'),
         )

-
     @staticmethod
     def _region_from_provider_name(provider_name: str) -> str:
         if provider_name.upper() in LIBCLOUD_REGION_NAME_MAP.keys():
264 changes: 116 additions & 148 deletions tests/storage/s3_storage_test.py

@@ -17,9 +17,8 @@
 import os
 import unittest
 import tempfile
-import pytest

-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, mock_open
 import botocore.utils

 from medusa.storage.s3_base_storage import S3BaseStorage
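Note: `mock_open` builds a mock whose file handles return canned data, which is how the reworked test below fakes the web-identity token file. A minimal sketch of the pattern, with an illustrative token value and path:

    from unittest.mock import patch, mock_open

    fake_token = 'eyJh...'  # illustrative JWT fragment, as in the test below
    with patch('builtins.open', mock_open(read_data=fake_token)):
        # every open() inside this block yields the canned content
        with open('/var/run/secrets/token') as f:
            assert f.read() == fake_token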
@@ -29,6 +28,14 @@
 class S3StorageTest(unittest.TestCase):
     original_call = None

+    def setUp(self):
+        print("setting up botocore mock")
+        self.original_call = botocore.utils.FileWebIdentityTokenLoader.__call__
+
+    def tearDown(self):
+        print("tearing down botocore mock")
+        botocore.utils.FileWebIdentityTokenLoader.__call__ = self.original_call
+
     def test_legacy_provider_region_replacement(self):
         assert (
             S3BaseStorage._region_from_provider_name("s3_us_west_oregon") == "us-west-2"
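The setUp/tearDown pair added above saves and restores `FileWebIdentityTokenLoader.__call__` by hand so a leaked patch cannot bleed into other tests. For comparison, a sketch of the same guard using `patch.object` as a context manager, which restores the attribute automatically; `fake_call` is illustrative:

    from unittest.mock import patch
    import botocore.utils

    def fake_call(self):
        return 'eyJh...'  # canned token instead of reading the file from disk

    with patch.object(botocore.utils.FileWebIdentityTokenLoader, '__call__', new=fake_call):
        pass  # code under test sees the patched loader here
    # the original __call__ is restored on exit, even if the block raises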
@@ -163,15 +170,13 @@ def test_credentials_with_default_region(self):
         self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None))
         self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None))

-        config = AttributeDict(
-            {
-                "api_profile": "default",
-                "region": "default",
-                "storage_provider": "s3_us_west_oregon",
-                "key_file": credentials_file.name,
-                "concurrent_transfers": "1",
-            }
-        )
+        config = AttributeDict({
+            "api_profile": "default",
+            "region": "default",
+            "storage_provider": "s3_us_west_oregon",
+            "key_file": credentials_file.name,
+            "concurrent_transfers": "1",
+        })

         credentials = S3BaseStorage._consolidate_credentials(config)
         self.assertEqual("key-from-file", credentials.access_key_id)
@@ -191,15 +196,13 @@ def test_credentials_with_default_region_and_s3_compatible_storage(self):
         self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None))
         self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None))

-        config = AttributeDict(
-            {
-                "api_profile": "default",
-                "region": "default",
-                "storage_provider": "s3_compatible",
-                "key_file": credentials_file.name,
-                "concurrent_transfers": "1",
-            }
-        )
+        config = AttributeDict({
+            "api_profile": "default",
+            "region": "default",
+            "storage_provider": "s3_compatible",
+            "key_file": credentials_file.name,
+            "concurrent_transfers": "1",
+        })

         credentials = S3BaseStorage._consolidate_credentials(config)
         self.assertEqual("key-from-file", credentials.access_key_id)
@@ -209,43 +212,39 @@ def test_credentials_with_default_region_and_s3_compatible_storage(self):
     def test_make_s3_url(self):
         with patch('botocore.httpsession.URLLib3Session', return_value=_make_instance_metadata_mock()):
             with tempfile.NamedTemporaryFile() as empty_file:
-                config = AttributeDict(
-                    {
-                        "storage_provider": "s3_us_west_oregon",
-                        "region": "default",
-                        "key_file": empty_file.name,
-                        "api_profile": None,
-                        "kms_id": None,
-                        "transfer_max_bandwidth": None,
-                        "bucket_name": "whatever-bucket",
-                        "secure": "True",
-                        "host": None,
-                        "port": None,
-                        "concurrent_transfers": "1",
-                    }
-                )
+                config = AttributeDict({
+                    "storage_provider": "s3_us_west_oregon",
+                    "region": "default",
+                    "key_file": empty_file.name,
+                    "api_profile": None,
+                    "kms_id": None,
+                    "transfer_max_bandwidth": None,
+                    "bucket_name": "whatever-bucket",
+                    "secure": "True",
+                    "host": None,
+                    "port": None,
+                    "concurrent_transfers": "1",
+                })
                 s3_storage = S3BaseStorage(config)
                 # there are no extra connection args when connecting to regular S3
                 self.assertEqual(dict(), s3_storage.connection_extra_args)

     def test_make_s3_url_without_secure(self):
         with patch('botocore.httpsession.URLLib3Session', return_value=_make_instance_metadata_mock()):
             with tempfile.NamedTemporaryFile() as empty_file:
-                config = AttributeDict(
-                    {
-                        "storage_provider": "s3_us_west_oregon",
-                        "region": "default",
-                        "key_file": empty_file.name,
-                        "api_profile": None,
-                        "kms_id": None,
-                        "transfer_max_bandwidth": None,
-                        "bucket_name": "whatever-bucket",
-                        "secure": "False",
-                        "host": None,
-                        "port": None,
-                        "concurrent_transfers": "1",
-                    }
-                )
+                config = AttributeDict({
+                    "storage_provider": "s3_us_west_oregon",
+                    "region": "default",
+                    "key_file": empty_file.name,
+                    "api_profile": None,
+                    "kms_id": None,
+                    "transfer_max_bandwidth": None,
+                    "bucket_name": "whatever-bucket",
+                    "secure": "False",
+                    "host": None,
+                    "port": None,
+                    "concurrent_transfers": "1",
+                })
                 s3_storage = S3BaseStorage(config)
                 # again, no extra connection args when connecting to regular S3
                 # we can't even disable HTTPS
Expand All @@ -254,21 +253,19 @@ def test_make_s3_url_without_secure(self):
     def test_make_s3_compatible_url(self):
         with patch('botocore.httpsession.URLLib3Session', return_value=_make_instance_metadata_mock()):
             with tempfile.NamedTemporaryFile() as empty_file:
-                config = AttributeDict(
-                    {
-                        "storage_provider": "s3_compatible",
-                        "region": "default",
-                        "key_file": empty_file.name,
-                        "api_profile": None,
-                        "kms_id": None,
-                        "transfer_max_bandwidth": None,
-                        "bucket_name": "whatever-bucket",
-                        "secure": "True",
-                        "host": "s3.example.com",
-                        "port": "443",
-                        "concurrent_transfers": "1",
-                    }
-                )
+                config = AttributeDict({
+                    "storage_provider": "s3_compatible",
+                    "region": "default",
+                    "key_file": empty_file.name,
+                    "api_profile": None,
+                    "kms_id": None,
+                    "transfer_max_bandwidth": None,
+                    "bucket_name": "whatever-bucket",
+                    "secure": "True",
+                    "host": "s3.example.com",
+                    "port": "443",
+                    "concurrent_transfers": "1",
+                })
                 s3_storage = S3BaseStorage(config)
                 self.assertEqual(
                     "https://s3.example.com:443",
                     s3_storage.connection_extra_args["endpoint_url"],
                 )
@@ -278,108 +275,81 @@ def test_make_s3_compatible_url_without_secure(self):
     def test_make_s3_compatible_url_without_secure(self):
         with patch('botocore.httpsession.URLLib3Session', return_value=_make_instance_metadata_mock()):
             with tempfile.NamedTemporaryFile() as empty_file:
-                config = AttributeDict(
-                    {
-                        "storage_provider": "s3_compatible",
-                        "region": "default",
-                        "key_file": empty_file.name,
-                        "api_profile": None,
-                        "kms_id": None,
-                        "transfer_max_bandwidth": None,
-                        "bucket_name": "whatever-bucket",
-                        "secure": "False",
-                        "host": "s3.example.com",
-                        "port": "8080",
-                        "concurrent_transfers": "1",
-                    }
-                )
+                config = AttributeDict({
+                    "storage_provider": "s3_compatible",
+                    "region": "default",
+                    "key_file": empty_file.name,
+                    "api_profile": None,
+                    "kms_id": None,
+                    "transfer_max_bandwidth": None,
+                    "bucket_name": "whatever-bucket",
+                    "secure": "False",
+                    "host": "s3.example.com",
+                    "port": "8080",
+                    "concurrent_transfers": "1",
+                })
                 s3_storage = S3BaseStorage(config)
                 self.assertEqual(
                     "http://s3.example.com:8080",
                     s3_storage.connection_extra_args["endpoint_url"],
                 )

     def test_assume_role_authentication(self):
-        with patch('botocore.httpsession.URLLib3Session.send', new=_make_assume_role_with_web_identity_mock()):
-            with tempfile.NamedTemporaryFile() as empty_file:
-                if os.environ.get('AWS_ACCESS_KEY_ID', None):
-                    del(os.environ['AWS_ACCESS_KEY_ID'])
-                if os.environ.get('AWS_SECRET_ACCESS_KEY', None):
-                    del(os.environ['AWS_SECRET_ACCESS_KEY'])
-                if os.environ.get('AWS_PROFILE', None):
-                    del(os.environ['AWS_PROFILE'])
+        with patch('botocore.httpsession.URLLib3Session', new=_make_assume_role_with_web_identity_mock()):
+            if os.environ.get('AWS_ACCESS_KEY_ID', None):
+                del(os.environ['AWS_ACCESS_KEY_ID'])
+            if os.environ.get('AWS_SECRET_ACCESS_KEY', None):
+                del(os.environ['AWS_SECRET_ACCESS_KEY'])
+            if os.environ.get('AWS_PROFILE', None):
+                del(os.environ['AWS_PROFILE'])

-                self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None))
-                self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None))
-                self.assertIsNone(os.environ.get('AWS_PROFILE', None))
+            self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None))
+            self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None))
+            self.assertIsNone(os.environ.get('AWS_PROFILE', None))

-                os.environ['AWS_STS_REGIONAL_ENDPOINTS'] = 'regional'
-                os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'
-                os.environ['AWS_REGION'] = 'us-east-1'
-                os.environ['AWS_ROLE_ARN'] = 'arn:aws:iam::123456789012:role/testRole'
-
-                # Set AWS_CONFIG_FILE to an empty temporary file
-                os.environ['AWS_CONFIG_FILE'] = empty_file.name
-
-                os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'] = '/var/run/secrets/token'
-
-                # Create a mock file with the token
-                mock_file_content = 'eyJh...'
-                def mock_call(self):
-                    if self._web_identity_token_path == "/var/run/secrets/token":
-                        return mock_file_content
-                    else:
-                        return self.original_call(self)
-
-                config = AttributeDict(
-                    {
-                        "storage_provider": "s3_us_west_oregon",
-                        "region": "default",
-                        "key_file": empty_file.name,
-                        "api_profile": None,
-                        "kms_id": None,
-                        "transfer_max_bandwidth": None,
-                        "bucket_name": "whatever-bucket",
-                        "secure": "True",
-                        "host": None,
-                        "port": None,
-                        "concurrent_transfers": "1"
-                    }
-                )
-
-                # Replace the open function with the mock
-                with patch.object(botocore.utils.FileWebIdentityTokenLoader, '__call__', new=mock_call):
-                    credentials = S3BaseStorage._consolidate_credentials(config)
-
-                self.assertEqual("key-from-assume-role", credentials.access_key_id)
-                self.assertEqual(
-                    "secret-from-assume-role", credentials.secret_access_key
-                )
-                self.assertEqual("token-from-assume-role", credentials.session_token)
-
-    @pytest.fixture(autouse=True)
-    def run_around_tests():
-        print("setting up AAAAAAAAAAAAAAAAAA")
-        self.original_call = botocore.utils.FileWebIdentityTokenLoader.__call__
-        yield
-        botocore.utils.FileWebIdentityTokenLoader.__call__ = self.original_call
+            os.environ['AWS_STS_REGIONAL_ENDPOINTS'] = 'regional'
+            os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'
+            os.environ['AWS_REGION'] = 'us-east-1'
+            os.environ['AWS_ROLE_ARN'] = 'arn:aws:iam::123456789012:role/testRole'
+
+            os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'] = '/var/run/secrets/token'
+
+            # Create a mock file with the token
+            mock_file_content = 'eyJh...'
+            mock_call = mock_open(read_data=mock_file_content)
+            config = AttributeDict({
+                "storage_provider": "s3_us_west_oregon",
+                "region": "default",
+                "key_file": "",
+                "api_profile": None,
+                "kms_id": None,
+                "transfer_max_bandwidth": None,
+                "bucket_name": "whatever-bucket",
+                "secure": "True",
+                "host": None,
+                "port": None,
+                "concurrent_transfers": "1"
+            })
+
+            # Replace the open function with the mock
+            with patch('builtins.open', mock_call):
+                credentials = S3BaseStorage._consolidate_credentials(config)
+
+            del(os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'])
+            self.assertEqual(None, credentials.access_key_id)
+            self.assertEqual(None, credentials.secret_access_key)

 def _make_instance_metadata_mock():
     # mock a call to the metadata service
     mock_response = MagicMock()
     mock_response.status_code = 200
     in_one_hour = datetime.datetime.utcnow() + datetime.timedelta(hours=1)
-    mock_response.text = json.dumps(
-        {
+    mock_response.text = json.dumps({
         "AccessKeyId": "key-from-instance-metadata",
         "SecretAccessKey": "secret-from-instance-metadata",
         "Token": "token-from-metadata",
         "Expiration": in_one_hour.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
         + "Z",  # -3 to remove microseconds
-        }
-    )
+    })
     mock_send = MagicMock(return_value=mock_response)
     mock_session = MagicMock()
     mock_session.send = mock_send
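This helper stands in for botocore's HTTP layer: the earlier tests pass it via `return_value=`, so any `URLLib3Session` constructed inside the patch is a session whose `send()` returns the canned metadata JSON. A condensed sketch of the wiring, mirroring the tests above:

    from unittest.mock import patch

    with patch('botocore.httpsession.URLLib3Session',
               return_value=_make_instance_metadata_mock()):
        # botocore's instance-metadata fetcher now sees the fake JSON
        # instead of calling the real metadata endpoint
        s3_storage = S3BaseStorage(config)  # config as in the tests above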
@@ -391,15 +361,13 @@ def _make_assume_role_with_web_identity_mock():
     mock_response = MagicMock()
     mock_response.status_code = 200
     in_one_hour = datetime.datetime.utcnow() + datetime.timedelta(hours=1)
-    mock_response.text = json.dumps(
-        {
+    mock_response.text = json.dumps({
         "Credentials": {
             "AccessKeyId": "key-from-assume-role",
             "SecretAccessKey": "secret-from-assume-role",
             "SessionToken": "token-from-assume-role",
             "Expiration": in_one_hour.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
             + "Z",  # -3 to remove microseconds
         }
-    }
-    )
+    })
     return MagicMock(return_value=mock_response)
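Outside these mocks, the environment variables the assume-role test sets are the standard inputs to botocore's web-identity credential provider. A sketch of the real flow, assuming the role and token file actually exist:

    import os
    import boto3  # resolves credentials through botocore's provider chain

    os.environ['AWS_ROLE_ARN'] = 'arn:aws:iam::123456789012:role/testRole'
    os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'] = '/var/run/secrets/token'
    os.environ['AWS_REGION'] = 'us-east-1'

    # On first use the provider chain reads the two variables above and
    # calls sts:AssumeRoleWithWebIdentity for temporary role credentials.
    session = boto3.Session()
    credentials = session.get_credentials()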
