diff --git a/medusa/storage/s3_base_storage.py b/medusa/storage/s3_base_storage.py index 4ee96de2..8e5b51e1 100644 --- a/medusa/storage/s3_base_storage.py +++ b/medusa/storage/s3_base_storage.py @@ -54,7 +54,7 @@ def __init__(self, access_key_id, secret_access_key, region): self.region = region def __repr__(self): - if len(self.access_key_id) > 0: + if self.access_key_id and len(self.access_key_id) > 0: key = f"{self.access_key_id[0]}..{self.access_key_id[-1]}" else: key = "None" @@ -136,13 +136,20 @@ def connect(self): tcp_keepalive=True, max_pool_connections=max_pool_size ) - self.s3_client = boto3.client( - 's3', - config=boto_config, - aws_access_key_id=self.credentials.access_key_id, - aws_secret_access_key=self.credentials.secret_access_key, - **self.connection_extra_args - ) + if self.credentials.access_key_id is not None: + self.s3_client = boto3.client( + 's3', + config=boto_config, + aws_access_key_id=self.credentials.access_key_id, + aws_secret_access_key=self.credentials.secret_access_key, + **self.connection_extra_args + ) + else: + self.s3_client = boto3.client( + 's3', + config=boto_config, + **self.connection_extra_args + ) def disconnect(self): logging.debug('Disconnecting from S3...') @@ -209,12 +216,18 @@ def _consolidate_credentials(config) -> CensoredCredentials: )) session.set_config_variable('credentials_file', config.key_file) - boto_credentials = session.get_credentials() - return CensoredCredentials( - access_key_id=boto_credentials.access_key, - secret_access_key=boto_credentials.secret_key, - region=session.get_config_variable('region'), - ) + boto_credentials = session.get_credentials() + return CensoredCredentials( + access_key_id=boto_credentials.access_key, + secret_access_key=boto_credentials.secret_key, + region=session.get_config_variable('region'), + ) + else: + return CensoredCredentials( + access_key_id=None, + secret_access_key=None, + region=session.get_config_variable('region'), + ) @staticmethod def _region_from_provider_name(provider_name: str) -> str: diff --git a/tests/storage/s3_storage_test.py b/tests/storage/s3_storage_test.py index 11a5daaa..0ae07198 100644 --- a/tests/storage/s3_storage_test.py +++ b/tests/storage/s3_storage_test.py @@ -18,13 +18,21 @@ import unittest import tempfile -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, mock_open +import botocore.utils from medusa.storage.s3_base_storage import S3BaseStorage from tests.storage.abstract_storage_test import AttributeDict class S3StorageTest(unittest.TestCase): + original_call = None + + def setUp(self): + self.original_call = botocore.utils.FileWebIdentityTokenLoader.__call__ + + def tearDown(self): + botocore.utils.FileWebIdentityTokenLoader.__call__ = self.original_call def test_legacy_provider_region_replacement(self): assert S3BaseStorage._region_from_provider_name("s3_us_west_oregon") == "us-west-2" @@ -41,6 +49,16 @@ def test_credentials_from_metadata(self): del(os.environ['AWS_SECRET_ACCESS_KEY']) if os.environ.get('AWS_PROFILE', None): del(os.environ['AWS_PROFILE']) + if os.environ.get('AWS_STS_REGIONAL_ENDPOINTS', None): + del(os.environ['AWS_STS_REGIONAL_ENDPOINTS']) + if os.environ.get('AWS_DEFAULT_REGION', None): + del(os.environ['AWS_DEFAULT_REGION']) + if os.environ.get('AWS_REGION', None): + del(os.environ['AWS_REGION']) + if os.environ.get('AWS_ROLE_ARN', None): + del(os.environ['AWS_ROLE_ARN']) + if os.environ.get('AWS_WEB_IDENTITY_TOKEN_FILE', None): + del(os.environ['AWS_WEB_IDENTITY_TOKEN_FILE']) self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None)) self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None)) @@ -131,8 +149,8 @@ def test_credentials_from_everything(self): credentials = S3BaseStorage._consolidate_credentials(config) self.assertEqual('key-from-file', credentials.access_key_id) - del (os.environ['AWS_ACCESS_KEY_ID']) - del (os.environ['AWS_SECRET_ACCESS_KEY']) + del(os.environ['AWS_ACCESS_KEY_ID']) + del(os.environ['AWS_SECRET_ACCESS_KEY']) def test_credentials_with_default_region(self): credentials_file_content = """ @@ -278,6 +296,50 @@ def test_make_s3_compatible_url_without_secure(self): s3_storage.connection_extra_args['endpoint_url'] ) + def test_assume_role_authentication(self): + with patch('botocore.httpsession.URLLib3Session', new=_make_assume_role_with_web_identity_mock()): + if os.environ.get('AWS_ACCESS_KEY_ID', None): + del(os.environ['AWS_ACCESS_KEY_ID']) + if os.environ.get('AWS_SECRET_ACCESS_KEY', None): + del(os.environ['AWS_SECRET_ACCESS_KEY']) + if os.environ.get('AWS_PROFILE', None): + del(os.environ['AWS_PROFILE']) + + self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None)) + self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None)) + self.assertIsNone(os.environ.get('AWS_PROFILE', None)) + + os.environ['AWS_STS_REGIONAL_ENDPOINTS'] = 'regional' + os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' + os.environ['AWS_REGION'] = 'us-east-1' + os.environ['AWS_ROLE_ARN'] = 'arn:aws:iam::123456789012:role/testRole' + + os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'] = '/var/run/secrets/token' + + # Create a mock file with the token + mock_file_content = 'eyJh...' + mock_call = mock_open(read_data=mock_file_content) + config = AttributeDict({ + 'storage_provider': 's3_us_west_oregon', + 'region': 'default', + 'key_file': '', + 'api_profile': None, + 'kms_id': None, + 'transfer_max_bandwidth': None, + 'bucket_name': 'whatever-bucket', + 'secure': 'True', + 'host': None, + 'port': None, + 'concurrent_transfers': '1' + }) + + # Replace the open function with the mock + with patch('builtins.open', mock_call): + credentials = S3BaseStorage._consolidate_credentials(config) + + self.assertEqual(None, credentials.access_key_id) + self.assertEqual(None, credentials.secret_access_key) + def _make_instance_metadata_mock(): # mock a call to the metadata service @@ -294,3 +356,17 @@ def _make_instance_metadata_mock(): mock_session = MagicMock() mock_session.send = mock_send return mock_session + + +def _make_assume_role_with_web_identity_mock(): + # mock a call to the AssumeRoleWithWebIdentity endpoint + mock_response = MagicMock() + mock_response.status_code = 200 + in_one_hour = datetime.datetime.utcnow() + datetime.timedelta(hours=1) + mock_response.text = json.dumps({ + "AccessKeyId": 'key-from-assume-role', + "SecretAccessKey": 'secret-from-assume-role', + "SessionToken": 'token-from-assume-role', + "Expiration": in_one_hour.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' # -3 to remove microseconds + }) + return MagicMock(return_value=mock_response)