Skip to content

Commit

Permalink
Delegated all requests to pyoidc library (#132)
Browse files Browse the repository at this point in the history
All requests to OP/IdP are now entrusted to base library.
  • Loading branch information
infohash authored Jun 5, 2022
1 parent 5e66a38 commit 7d126dd
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 192 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
author_email='[email protected]',
description='Flask extension for OpenID Connect authentication.',
install_requires=[
'oic>=1.2.1',
'oic>=1.4.0',
'Flask',
'requests',
'importlib_resources'
Expand Down
3 changes: 2 additions & 1 deletion src/flask_pyoidc/auth_response_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def process_auth_response(self, auth_response, auth_request):
refresh_token = None # but never refresh token

if 'code' in auth_response:
token_resp = self._client.exchange_authorization_code(auth_response['code'])
token_resp = self._client.exchange_authorization_code(auth_response['code'],
auth_response['state'])
if token_resp:
if 'error' in token_resp:
raise AuthResponseErrorResponseError(token_resp.to_dict())
Expand Down
6 changes: 6 additions & 0 deletions src/flask_pyoidc/message_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from oic.oauth2.message import AccessTokenResponse, CCAccessTokenRequest, MessageTuple, OauthMessageFactory


class CCMessageFactory(OauthMessageFactory):
"""Client Credential Request Factory."""
token_endpoint = MessageTuple(CCAccessTokenRequest, AccessTokenResponse)
20 changes: 10 additions & 10 deletions src/flask_pyoidc/provider_configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import collections.abc
import logging

from oic.oic import Client
import requests
from oic.oic import Client
from oic.utils.settings import ClientSettings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -169,17 +170,16 @@ def __init__(self,
self.userinfo_endpoint_method = userinfo_http_method
self.auth_request_params = auth_request_params or {}
self.session_refresh_interval_seconds = session_refresh_interval_seconds
# For session persistence
self.client_settings = ClientSettings(timeout=self.DEFAULT_REQUEST_TIMEOUT,
requests_session=requests_session or requests.Session())

self.requests_session = requests_session or requests.Session()

def ensure_provider_metadata(self):
def ensure_provider_metadata(self, client: Client):
if not self._provider_metadata:
resp = self.requests_session \
.get(self._issuer + '/.well-known/openid-configuration',
timeout=self.DEFAULT_REQUEST_TIMEOUT)
logger.debug('Received discovery response: ' + resp.text)
resp = client.provider_config(self._issuer)
logger.debug(f'Received discovery response: {resp.to_dict()}')

self._provider_metadata = ProviderMetadata(**resp.json())
self._provider_metadata = ProviderMetadata(**resp.to_dict())

return self._provider_metadata

Expand All @@ -200,8 +200,8 @@ def register_client(self, client: Client):
registration_response = client.register(
url=self._provider_metadata['registration_endpoint'],
**registration_request)
logger.info('Received registration response.')
self._client_metadata = ClientMetadata(
**registration_response.to_dict())
logger.debug('Received registration response: client_id=' + self._client_metadata['client_id'])

return self._client_metadata
Loading

0 comments on commit 7d126dd

Please sign in to comment.