Skip to content

Commit

Permalink
Support AWS IAM roles as authentication method (#691)
Browse files Browse the repository at this point in the history
  • Loading branch information
JBOClara authored Jan 4, 2024
1 parent 7a57204 commit 50e1dad
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 17 deletions.
41 changes: 27 additions & 14 deletions medusa/storage/s3_base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, access_key_id, secret_access_key, region):
self.region = region

def __repr__(self):
if len(self.access_key_id) > 0:
if self.access_key_id and len(self.access_key_id) > 0:
key = f"{self.access_key_id[0]}..{self.access_key_id[-1]}"
else:
key = "None"
Expand Down Expand Up @@ -136,13 +136,20 @@ def connect(self):
tcp_keepalive=True,
max_pool_connections=max_pool_size
)
self.s3_client = boto3.client(
's3',
config=boto_config,
aws_access_key_id=self.credentials.access_key_id,
aws_secret_access_key=self.credentials.secret_access_key,
**self.connection_extra_args
)
if self.credentials.access_key_id is not None:
self.s3_client = boto3.client(
's3',
config=boto_config,
aws_access_key_id=self.credentials.access_key_id,
aws_secret_access_key=self.credentials.secret_access_key,
**self.connection_extra_args
)
else:
self.s3_client = boto3.client(
's3',
config=boto_config,
**self.connection_extra_args
)

def disconnect(self):
logging.debug('Disconnecting from S3...')
Expand Down Expand Up @@ -209,12 +216,18 @@ def _consolidate_credentials(config) -> CensoredCredentials:
))
session.set_config_variable('credentials_file', config.key_file)

boto_credentials = session.get_credentials()
return CensoredCredentials(
access_key_id=boto_credentials.access_key,
secret_access_key=boto_credentials.secret_key,
region=session.get_config_variable('region'),
)
boto_credentials = session.get_credentials()
return CensoredCredentials(
access_key_id=boto_credentials.access_key,
secret_access_key=boto_credentials.secret_key,
region=session.get_config_variable('region'),
)
else:
return CensoredCredentials(
access_key_id=None,
secret_access_key=None,
region=session.get_config_variable('region'),
)

@staticmethod
def _region_from_provider_name(provider_name: str) -> str:
Expand Down
82 changes: 79 additions & 3 deletions tests/storage/s3_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,21 @@
import unittest
import tempfile

from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, mock_open
import botocore.utils

from medusa.storage.s3_base_storage import S3BaseStorage
from tests.storage.abstract_storage_test import AttributeDict


class S3StorageTest(unittest.TestCase):
original_call = None

def setUp(self):
self.original_call = botocore.utils.FileWebIdentityTokenLoader.__call__

def tearDown(self):
botocore.utils.FileWebIdentityTokenLoader.__call__ = self.original_call

def test_legacy_provider_region_replacement(self):
assert S3BaseStorage._region_from_provider_name("s3_us_west_oregon") == "us-west-2"
Expand All @@ -41,6 +49,16 @@ def test_credentials_from_metadata(self):
del(os.environ['AWS_SECRET_ACCESS_KEY'])
if os.environ.get('AWS_PROFILE', None):
del(os.environ['AWS_PROFILE'])
if os.environ.get('AWS_STS_REGIONAL_ENDPOINTS', None):
del(os.environ['AWS_STS_REGIONAL_ENDPOINTS'])
if os.environ.get('AWS_DEFAULT_REGION', None):
del(os.environ['AWS_DEFAULT_REGION'])
if os.environ.get('AWS_REGION', None):
del(os.environ['AWS_REGION'])
if os.environ.get('AWS_ROLE_ARN', None):
del(os.environ['AWS_ROLE_ARN'])
if os.environ.get('AWS_WEB_IDENTITY_TOKEN_FILE', None):
del(os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'])

self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None))
self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None))
Expand Down Expand Up @@ -131,8 +149,8 @@ def test_credentials_from_everything(self):
credentials = S3BaseStorage._consolidate_credentials(config)
self.assertEqual('key-from-file', credentials.access_key_id)

del (os.environ['AWS_ACCESS_KEY_ID'])
del (os.environ['AWS_SECRET_ACCESS_KEY'])
del(os.environ['AWS_ACCESS_KEY_ID'])
del(os.environ['AWS_SECRET_ACCESS_KEY'])

def test_credentials_with_default_region(self):
credentials_file_content = """
Expand Down Expand Up @@ -278,6 +296,50 @@ def test_make_s3_compatible_url_without_secure(self):
s3_storage.connection_extra_args['endpoint_url']
)

def test_assume_role_authentication(self):
with patch('botocore.httpsession.URLLib3Session', new=_make_assume_role_with_web_identity_mock()):
if os.environ.get('AWS_ACCESS_KEY_ID', None):
del(os.environ['AWS_ACCESS_KEY_ID'])
if os.environ.get('AWS_SECRET_ACCESS_KEY', None):
del(os.environ['AWS_SECRET_ACCESS_KEY'])
if os.environ.get('AWS_PROFILE', None):
del(os.environ['AWS_PROFILE'])

self.assertIsNone(os.environ.get('AWS_ACCESS_KEY_ID', None))
self.assertIsNone(os.environ.get('AWS_SECRET_ACCESS_KEY', None))
self.assertIsNone(os.environ.get('AWS_PROFILE', None))

os.environ['AWS_STS_REGIONAL_ENDPOINTS'] = 'regional'
os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'
os.environ['AWS_REGION'] = 'us-east-1'
os.environ['AWS_ROLE_ARN'] = 'arn:aws:iam::123456789012:role/testRole'

os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'] = '/var/run/secrets/token'

# Create a mock file with the token
mock_file_content = 'eyJh...'
mock_call = mock_open(read_data=mock_file_content)
config = AttributeDict({
'storage_provider': 's3_us_west_oregon',
'region': 'default',
'key_file': '',
'api_profile': None,
'kms_id': None,
'transfer_max_bandwidth': None,
'bucket_name': 'whatever-bucket',
'secure': 'True',
'host': None,
'port': None,
'concurrent_transfers': '1'
})

# Replace the open function with the mock
with patch('builtins.open', mock_call):
credentials = S3BaseStorage._consolidate_credentials(config)

self.assertEqual(None, credentials.access_key_id)
self.assertEqual(None, credentials.secret_access_key)


def _make_instance_metadata_mock():
# mock a call to the metadata service
Expand All @@ -294,3 +356,17 @@ def _make_instance_metadata_mock():
mock_session = MagicMock()
mock_session.send = mock_send
return mock_session


def _make_assume_role_with_web_identity_mock():
# mock a call to the AssumeRoleWithWebIdentity endpoint
mock_response = MagicMock()
mock_response.status_code = 200
in_one_hour = datetime.datetime.utcnow() + datetime.timedelta(hours=1)
mock_response.text = json.dumps({
"AccessKeyId": 'key-from-assume-role',
"SecretAccessKey": 'secret-from-assume-role',
"SessionToken": 'token-from-assume-role',
"Expiration": in_one_hour.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' # -3 to remove microseconds
})
return MagicMock(return_value=mock_response)

0 comments on commit 50e1dad

Please sign in to comment.