Skip to content

Commit

Permalink
feat(auth): introduce OtfCognito class to handle device key in auth p…
Browse files Browse the repository at this point in the history
…arams

Refactor OtfAuth to use OtfCognito instead of Cognito for better handling
of device keys during token refresh. This change ensures that the
renew_access_token method sets a new access token using cached refresh
token and device metadata, preventing NOT_AUTHORIZED errors.
  • Loading branch information
NodeJSmith committed Jan 7, 2025
1 parent 708e6a3 commit bc8428c
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions src/otf_api/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from otf_api.auth.utils import CognitoTokens

if typing.TYPE_CHECKING:
from mypy_boto3_cognito_idp import CognitoIdentityProviderClient
from mypy_boto3_cognito_idp.type_defs import AuthenticationResultTypeTypeDef

LOGGER = getLogger(__name__)
Expand All @@ -23,10 +24,31 @@
BOTO_CONFIG = Config(region_name=REGION, signature_version=UNSIGNED)


class OtfCognito(Cognito):
"""A subclass of the pycognito Cognito class that adds the device_key to the auth_params. Without this
being set the renew_access_token call will always fail with NOT_AUTHORIZED."""

auth: "OTF_AUTH_TYPE"
client: "CognitoIdentityProviderClient"

def renew_access_token(self) -> None:
"""Sets a new access token on the User using the cached refresh token and device metadata."""
auth_params = {"REFRESH_TOKEN": self.refresh_token}
self._add_secret_hash(auth_params, "SECRET_HASH")

if dd := self.auth.config.dd_cache.get_cached_data():
auth_params["DEVICE_KEY"] = dd["device_key"]

refresh_response = self.client.initiate_auth(
ClientId=self.client_id, AuthFlow="REFRESH_TOKEN_AUTH", AuthParameters=auth_params
)
self._set_tokens(refresh_response)


class OtfAuth:
auth_type: ClassVar[Literal["basic", "token", "cognito"]]

cognito: Cognito
cognito: OtfCognito
config: OtfAuthConfig

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -78,7 +100,7 @@ def create(
id_token: str | None = None,
access_token: str | None = None,
refresh_token: str | None = None,
cognito: Cognito | None = None,
cognito: OtfCognito | None = None,
config: OtfAuthConfig | None = None,
) -> "OTF_AUTH_TYPE":
"""Create an authentication object.
Expand Down Expand Up @@ -117,7 +139,7 @@ def authenticate(self) -> None:
raise NotImplementedError

def setup_cognito(self, tokens: CognitoTokens) -> None:
self.cognito = Cognito(
self.cognito = OtfCognito(
USER_POOL_ID,
CLIENT_ID,
access_token=tokens.access_token,
Expand All @@ -141,6 +163,10 @@ def validate_cognito_tokens(self) -> None:
}
self.config.token_cache.write_to_cache(tokens)

# ensure the cognito instance has the auth object
# we'll need this for the device key during refresh
self.cognito.auth = self


@attrs.define
class OtfBasicAuth(OtfAuth):
Expand All @@ -161,7 +187,7 @@ def get_awssrp(self) -> AWSSRP:
"client": boto3.client("cognito-idp", config=BOTO_CONFIG),
}

dd = self.config.dd_cache.get_cached_data() if self.config.cache_device_data else {}
dd = self.config.dd_cache.get_cached_data()

kwargs = kwargs | dd | {"username": self.username, "password": self.password}

Expand Down Expand Up @@ -191,7 +217,7 @@ def handle_device_setup(self, tokens: CognitoTokens) -> None:
Args:
tokens (dict): The tokens from the AWS SRP instance.
"""
if not self.config.cache_device_data or self.config.dd_cache.get_cached_data():
if self.config.dd_cache.get_cached_data():
LOGGER.debug("Skipping device setup")

try:
Expand Down Expand Up @@ -258,7 +284,7 @@ class OtfTokenAuth(OtfAuth):
auth_config: OtfAuthConfig = attrs.field(factory=OtfAuthConfig)

def authenticate(self) -> None:
dd = self.auth_config.dd_cache.get_cached_data() if self.auth_config.cache_device_data else {}
dd = self.auth_config.dd_cache.get_cached_data()
dd.pop("device_password", None) # remove device password, not attribute of CognitoTokens
tokens = CognitoTokens(
access_token=self.access_token, id_token=self.id_token, refresh_token=self.refresh_token, **dd
Expand All @@ -270,7 +296,7 @@ def authenticate(self) -> None:
class OtfCognitoAuth(OtfAuth):
auth_type: ClassVar[Literal["basic", "token", "cognito"]] = "cognito"

cognito: Cognito
cognito: OtfCognito
auth_config: OtfAuthConfig = attrs.field(factory=OtfAuthConfig)

def authenticate(self) -> None:
Expand Down

0 comments on commit bc8428c

Please sign in to comment.