diff --git a/.github/workflows/test.yaml b/.github/workflows/push.yaml similarity index 87% rename from .github/workflows/test.yaml rename to .github/workflows/push.yaml index d6f44510..58124858 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/push.yaml @@ -20,5 +20,6 @@ jobs: python-version: ${{ matrix.python-version }} - run: just deps - run: just test - - run: just cov + - if: ${{ matrix.python-version }} == 3.12 + run: just cov - run: just lint diff --git a/src/posit/connect/__init__.py b/src/posit/connect/__init__.py index 6eb66c5a..79add3a0 100644 --- a/src/posit/connect/__init__.py +++ b/src/posit/connect/__init__.py @@ -1,10 +1 @@ -from typing import Optional - -from .client import Client - - -def make_client( - api_key: Optional[str] = None, endpoint: Optional[str] = None -) -> Client: - client = Client(api_key=api_key, endpoint=endpoint) - return client +from .client import create_client # noqa diff --git a/src/posit/connect/auth.py b/src/posit/connect/auth.py index 44bdc00e..d4a72d04 100644 --- a/src/posit/connect/auth.py +++ b/src/posit/connect/auth.py @@ -1,11 +1,13 @@ from requests import PreparedRequest from requests.auth import AuthBase +from .config import Config + class Auth(AuthBase): - def __init__(self, key: str) -> None: - self.key = key + def __init__(self, config: Config) -> None: + self._config = config def __call__(self, r: PreparedRequest) -> PreparedRequest: - r.headers["Authorization"] = f"Key {self.key}" + r.headers["Authorization"] = f"Key {self._config.api_key}" return r diff --git a/src/posit/connect/auth_test.py b/src/posit/connect/auth_test.py index 66d4ea8a..59ff231f 100644 --- a/src/posit/connect/auth_test.py +++ b/src/posit/connect/auth_test.py @@ -1,13 +1,15 @@ -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch from .auth import Auth class TestAuth: - def test_auth_headers(self): - key = "foobar" - auth = Auth(key=key) + @patch("posit.connect.auth.Config") + def test_auth_headers(self, Config: MagicMock): + config = Config.return_value + config.api_key = "foobar" + auth = Auth(config=config) r = Mock() r.headers = {} auth(r) - assert r.headers == {"Authorization": f"Key {key}"} + assert r.headers == {"Authorization": f"Key {config.api_key}"} diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index 63b20f9b..e0fa2474 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -1,48 +1,34 @@ -import os +from __future__ import annotations +from contextlib import contextmanager from requests import Session -from typing import Optional +from typing import Generator, Optional from . import hooks from .auth import Auth +from .config import Config from .users import Users -def _get_api_key() -> str: - """Gets the API key from the environment variable 'CONNECT_API_KEY'. +@contextmanager +def create_client( + api_key: Optional[str] = None, endpoint: Optional[str] = None +) -> Generator[Client, None, None]: + """Creates a new :class:`Client` instance - Raises: - ValueError: if CONNECT_API_KEY is not set or invalid + Keyword Arguments: + api_key -- an api_key for authentication (default: {None}) + endpoint -- a base api endpoint (url) (default: {None}) Returns: - The API key + A :class:`Client` instance """ - value = os.environ.get("CONNECT_API_KEY") - if value is None or value == "": - raise ValueError( - "Invalid value for 'CONNECT_API_KEY': Must be a non-empty string." - ) - return value - - -def _get_endpoint() -> str: - """Gets the endpoint from the environment variable 'CONNECT_SERVER'. - - The `requests` library uses 'endpoint' instead of 'server'. We will use 'endpoint' from here forward for consistency. - - Raises: - ValueError: if CONNECT_SERVER is not set or invalid. - - Returns: - The endpoint. - """ - value = os.environ.get("CONNECT_SERVER") - if value is None or value == "": - raise ValueError( - "Invalid value for 'CONNECT_SERVER': Must be a non-empty string." - ) - return value + client = Client(api_key=api_key, endpoint=endpoint) + try: + yield client + finally: + del client class Client: @@ -53,9 +39,29 @@ def __init__( api_key: Optional[str] = None, endpoint: Optional[str] = None, ) -> None: - self._api_key = api_key or _get_api_key() - self._endpoint = endpoint or _get_endpoint() - self._session = Session() - self._session.hooks["response"].append(hooks.handle_errors) - self._session.auth = Auth(self._api_key) - self.users = Users(self._endpoint, self._session) + """ + Initialize the Client instance. + + Args: + api_key (str, optional): API key for authentication. Defaults to None. + endpoint (str, optional): API endpoint URL. Defaults to None. + """ + # Create a Config object. + config = Config(api_key=api_key, endpoint=endpoint) + # Create a Session object for making HTTP requests. + session = Session() + # Authenticate the session using the provided Config. + session.auth = Auth(config=config) + # Add error handling hooks to the session. + session.hooks["response"].append(hooks.handle_errors) + + # Initialize the Users instance. + self.users = Users(config=config, session=session) + # Store the Session object. + self._session = session + + def __del__(self): + """ + Close the session when the Client instance is deleted. + """ + self._session.close() diff --git a/src/posit/connect/client_test.py b/src/posit/connect/client_test.py index ba64e366..49ee14f4 100644 --- a/src/posit/connect/client_test.py +++ b/src/posit/connect/client_test.py @@ -1,52 +1,40 @@ -import pytest - from unittest.mock import MagicMock, patch -from .client import Client, _get_api_key, _get_endpoint +from .client import Client, create_client + + +class TestCreateClient: + @patch("posit.connect.client.Client") + def test(self, Client: MagicMock): + api_key = "foobar" + endpoint = "http://foo.bar" + with create_client(api_key=api_key, endpoint=endpoint) as client: + assert client == Client.return_value class TestClient: @patch("posit.connect.client.Users") @patch("posit.connect.client.Session") + @patch("posit.connect.client.Config") @patch("posit.connect.client.Auth") - def test_init(self, Auth: MagicMock, Session: MagicMock, Users: MagicMock): + def test_init( + self, Auth: MagicMock, Config: MagicMock, Session: MagicMock, Users: MagicMock + ): api_key = "foobar" endpoint = "http://foo.bar" - client = Client(api_key=api_key, endpoint=endpoint) - assert client._api_key == api_key - assert client._endpoint == endpoint + Client(api_key=api_key, endpoint=endpoint) + config = Config.return_value + Auth.assert_called_once_with(config=config) + Config.assert_called_once_with(api_key=api_key, endpoint=endpoint) Session.assert_called_once() - Auth.assert_called_once_with(api_key) - Users.assert_called_once_with(endpoint, Session.return_value) - - -class TestGetApiKey: - @patch.dict("os.environ", {"CONNECT_API_KEY": "foobar"}) - def test_get_api_key(self): - api_key = _get_api_key() - assert api_key == "foobar" - - @patch.dict("os.environ", {"CONNECT_API_KEY": ""}) - def test_get_api_key_empty(self): - with pytest.raises(ValueError): - _get_api_key() - - def test_get_api_key_miss(self): - with pytest.raises(ValueError): - _get_api_key() + Users.assert_called_once_with(config=config, session=Session.return_value) - -class TestGetEndpoint: - @patch.dict("os.environ", {"CONNECT_SERVER": "http://foo.bar"}) - def test_get_endpoint(self): - endpoint = _get_endpoint() - assert endpoint == "http://foo.bar" - - @patch.dict("os.environ", {"CONNECT_SERVER": ""}) - def test_get_endpoint_empty(self): - with pytest.raises(ValueError): - _get_endpoint() - - def test_get_endpoint_miss(self): - with pytest.raises(ValueError): - _get_endpoint() + @patch("posit.connect.client.Users") + @patch("posit.connect.client.Session") + @patch("posit.connect.client.Auth") + def test_del(self, Auth: MagicMock, Session: MagicMock, Users: MagicMock): + api_key = "foobar" + endpoint = "http://foo.bar" + client = Client(api_key=api_key, endpoint=endpoint) + del client + Session.return_value.close.assert_called_once() diff --git a/src/posit/connect/config.py b/src/posit/connect/config.py new file mode 100644 index 00000000..858a80b6 --- /dev/null +++ b/src/posit/connect/config.py @@ -0,0 +1,57 @@ +import os + +from typing import Optional + + +def _get_api_key() -> str: + """Gets the API key from the environment variable 'CONNECT_API_KEY'. + + Raises: + ValueError: if CONNECT_API_KEY is not set or invalid + + Returns: + The API key + """ + value = os.environ.get("CONNECT_API_KEY") + if value is None or value == "": + raise ValueError( + "Invalid value for 'CONNECT_API_KEY': Must be a non-empty string." + ) + return value + + +def _get_endpoint() -> str: + """Gets the endpoint from the environment variable 'CONNECT_SERVER'. + + The `requests` library uses 'endpoint' instead of 'server'. We will use 'endpoint' from here forward for consistency. + + Raises: + ValueError: if CONNECT_SERVER is not set or invalid. + + Returns: + The endpoint. + """ + value = os.environ.get("CONNECT_SERVER") + if value is None or value == "": + raise ValueError( + "Invalid value for 'CONNECT_SERVER': Must be a non-empty string." + ) + return value + + +def _format_endpoint(endpoint: str) -> str: + # todo - format endpoint url and ake sure it ends with __api__ + return endpoint + + +class Config: + """Derived configuration properties""" + + api_key: str + endpoint: str + + def __init__( + self, api_key: Optional[str] = None, endpoint: Optional[str] = None + ) -> None: + self.api_key = api_key or _get_api_key() + self.endpoint = _format_endpoint(endpoint or _get_endpoint()) diff --git a/src/posit/connect/config_test.py b/src/posit/connect/config_test.py new file mode 100644 index 00000000..c95a951a --- /dev/null +++ b/src/posit/connect/config_test.py @@ -0,0 +1,46 @@ +import pytest + +from unittest.mock import patch + +from .config import Config, _get_api_key, _get_endpoint + + +class TestGetApiKey: + @patch.dict("os.environ", {"CONNECT_API_KEY": "foobar"}) + def test_get_api_key(self): + api_key = _get_api_key() + assert api_key == "foobar" + + @patch.dict("os.environ", {"CONNECT_API_KEY": ""}) + def test_get_api_key_empty(self): + with pytest.raises(ValueError): + _get_api_key() + + def test_get_api_key_miss(self): + with pytest.raises(ValueError): + _get_api_key() + + +class TestGetEndpoint: + @patch.dict("os.environ", {"CONNECT_SERVER": "http://foo.bar"}) + def test_get_endpoint(self): + endpoint = _get_endpoint() + assert endpoint == "http://foo.bar" + + @patch.dict("os.environ", {"CONNECT_SERVER": ""}) + def test_get_endpoint_empty(self): + with pytest.raises(ValueError): + _get_endpoint() + + def test_get_endpoint_miss(self): + with pytest.raises(ValueError): + _get_endpoint() + + +class TestConfig: + def test_init(self): + api_key = "foobar" + endpoint = "http://foo.bar" + config = Config(api_key=api_key, endpoint=endpoint) + assert config.api_key == api_key + assert config.endpoint == endpoint diff --git a/src/posit/connect/endpoints.py b/src/posit/connect/endpoints.py new file mode 100644 index 00000000..33dcbc96 --- /dev/null +++ b/src/posit/connect/endpoints.py @@ -0,0 +1,36 @@ +import os +import requests + +_MAX_PAGE_SIZE = 500 + + +def get_users( + endpoint: str, + session: requests.Session, + /, + page_number: int, + *, + page_size: int = 500, +): + """ + Fetches the current page of users. + + Returns: + List[User]: A list of User objects representing the fetched users. + """ + # Construct the endpoint URL. + endpoint = os.path.join(endpoint, "v1/users") + # Redefine the page number using 1-based indexing. + page_number = page_number + 1 + # Define query parameters for pagination. + params = {"page_number": page_number, "page_size": page_size} + # Send a GET request to the endpoint with the specified parameters. + response = session.get(endpoint, params=params) + # Convert response to dict + json = response.json() + # Parse the JSON response and extract the results. + results = json["results"] + # Mark exhausted if the result size is less than the maximum page size. + exhausted = len(results) < page_size + # Create User objects from the results and return them as a list. + return (results, exhausted) diff --git a/src/posit/connect/endpoints_test.py b/src/posit/connect/endpoints_test.py new file mode 100644 index 00000000..bf043e0f --- /dev/null +++ b/src/posit/connect/endpoints_test.py @@ -0,0 +1,16 @@ +from unittest.mock import MagicMock, Mock, patch + +from .endpoints import get_users + + +class TestGetUsers: + @patch("posit.connect.users.Session") + def test(self, Session: MagicMock): + session = Session.return_value + get = session.get = Mock() + response = get.return_value = Mock() + json = response.json = Mock() + json.return_value = {"results": ["foo"]} + users, exhausted = get_users("http://foo.bar", session, page_number=0) + assert users == ["foo"] + assert exhausted diff --git a/src/posit/connect/users.py b/src/posit/connect/users.py index 917daff9..0bc8dc14 100644 --- a/src/posit/connect/users.py +++ b/src/posit/connect/users.py @@ -2,14 +2,18 @@ import os -from dataclasses import dataclass, asdict from datetime import datetime from requests import Session -from typing import Optional +from typing import Iterator, List, Optional, TypedDict +from .config import Config +from .endpoints import get_users +from .errors import ClientError -@dataclass -class User: +_MAX_PAGE_SIZE = 500 + + +class User(TypedDict, total=False): guid: str email: str username: str @@ -22,54 +26,133 @@ class User: confirmed: bool locked: bool - def to_dict(self) -> dict: - return asdict(self) +class Users(Iterator[User]): + def __init__( + self, config: Config, session: Session, *, users: Optional[List[User]] = None + ): + self._config = config + self._session = session -class Users(list[User]): - """An extension of :class:`list[User]` with additional fetch methods.""" + self._cached_users: List[User] = users or [] + self._exhausted: bool = users is not None + self._index: int = 0 + self._page_number: int = 0 - _endpoint: str - _session: Session + def __iter__(self) -> Iterator[User]: + """ + Initialize the iterator by resetting the index to the beginning of the cached user list. - def __init__(self, endpoint: str, session: Session): - self._endpoint = endpoint - self._session = session + Returns: + Iterator: The initialized iterator object. + """ + # Reset the index to the beginning of the cached user list. + self._index = 0 + # Return the iterator object. + return self - def find(self, params: dict = {}) -> Users: - """Finds any :class:`User` that matches the provided filter conditions + def __next__(self): + """Retrieve the next user in the list. If necessary, fetch a new page of users beforehand. - Keyword Arguments: - params -- filter conditions (default: {{}}) + Raises: + StopIteration: If the end of the user list is reached. + StopIteration: If no users are returned for the current page. Returns: - `self` + dict: Information about the next user. """ - self.clear() - endpoint = os.path.join(self._endpoint, "__api__/v1/users") - response = self._session.get(endpoint) - data = response.json() - for user in data["results"]: - if all(user.get(k) == v for k, v in params.items()): - self.append(User(**user)) - # todo - implement paging and caching - return self + # Check if the current index is greater than or equal to the length of the cached user list. + if self._index >= len(self._cached_users): + # Check if the endpoint was exhausted on the previous iteration + if self._exhausted: + # Stop iteration if the index is not aligned with page boundaries. + raise StopIteration + # Fetch the current page of users. + results, exhausted = get_users( + self._config.endpoint, self._session, self._page_number + ) + # Mark if the endpoint is exhausted for the next iteration + self._exhausted = exhausted + # Increment the page counter for the next iteration. + self._page_number += 1 + # Append the fetched users to the cached user list. + self._cached_users += [User(**result) for result in results] + # Check if the fetched results list is empty. + if not results: + # Stop iteration if no users are returned for the current page. + raise StopIteration + # Get the current user by index. + user = self._cached_users[self._index] + # Increment the index for the next iteration. + self._index += 1 + # Return the current user. + return user + + def find(self, params: User) -> Users: + """ + Finds users that match the provided filter conditions. + + Args: + params (User): Filter conditions. - def find_one(self, params: dict = {}) -> Optional[User]: - """Finds one :class:`User` + Returns: + Users: A list of users matching the filter conditions. + """ + found: List[User] = [] + for user in self: + # Check if the items in params are subset of user's items. + if params.items() <= user.items(): + # Append the user to the found list. + found.append(user) + return Users(self._config, self._session, users=found) + + def find_one(self, params: User) -> Optional[User]: + """ + Finds one User matching the provided parameters. Keyword Arguments: - params -- filter conditions (default: {{}}) + params -- Dictionary of filter conditions (default: {}). Returns: - A matching :class:`User`. + A matching User if found, otherwise None. + + Note: + This method first checks if 'guid' is present in params. If so, it attempts a direct lookup using self.get(). + If an error with code '4' is encountered (indicating no matching user), it logs a warning and returns None. + If 'guid' is not provided, it performs a normal search using self.find() and return the first value found. """ + # Check if 'guid' is provided in params if "guid" in params: - # Use the user details API if a 'guid' is provided. - # This is an example of how we can use different API endpoints to optimize execution time. - endpoint = os.path.join(self._endpoint, "__api__/v1/users", params["guid"]) - response = self._session.get(endpoint) - return User(**response.json()) - - # Otherwise, perform a normal search. + try: + # Attempt direct lookup + self.get(params["guid"]) + except ClientError as e: + # Check for error code '4' (no matching user) + if e.error_code == 4: + import logging + + logging.warning(e) + # Return None if user not found + return None + raise e + + # If 'guid' not provided perform a normal search return next(iter(self.find(params)), None) + + def get(self, guid: str) -> User: + """Gets a user by guid. + + Arguments: + guid -- the users guid. + + Returns: + A :class:`User`. + """ + endpoint = os.path.join(self._config.endpoint, "v1/users", guid) + response = self._session.get(endpoint) + return User(**response.json()) + + def to_pandas_data_frame(self): # noqa + import pandas as pd + + return pd.DataFrame((user for user in self)) diff --git a/src/posit/connect/users_test.py b/src/posit/connect/users_test.py index e69de29b..aa4f49c7 100644 --- a/src/posit/connect/users_test.py +++ b/src/posit/connect/users_test.py @@ -0,0 +1,150 @@ +import pytest + +from unittest.mock import MagicMock, patch + +from .users import Users, User + + +class TestUsers: + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_init(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + session = Session.return_value + users = Users(config, session) + assert users._config == config + assert users._session == session + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_iter(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + session = Session.return_value + users = Users(config, session) + iter(users) + assert users._index == 0 + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_next_with_empty_result_set(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + session = Session.return_value + users = Users(config, session) + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [], True + with pytest.raises(StopIteration): + next(users) + + assert users._cached_users == [] + assert users._exhausted is True + assert users._index == 0 + assert users._page_number == 1 + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_next_with_single_page(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + session = Session.return_value + users = Users(config, session) + user: User = {} + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [user], True + assert next(users) == user + get_users.assert_called_with(config.endpoint, session, 0) + + with pytest.raises(StopIteration): + next(users) + + assert users._cached_users == [user] + assert users._exhausted is True + assert users._index == 1 + assert users._page_number == 1 + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_next_with_multiple_pages(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + session = Session.return_value + users = Users(config, session) + user: User = {} + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [user], False + assert next(users) == user + get_users.assert_called_with(config.endpoint, session, 0) + + get_users.return_value = [user], True + assert next(users) == user + get_users.assert_called_with(config.endpoint, session, 1) + + assert users._cached_users == [user, user] + assert users._exhausted is True + assert users._index == 2 + assert users._page_number == 2 + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_find(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + session = Session.return_value + users = Users(config, session) + user = {"username": "foobar"} + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [user], True + found = users.find({"username": "foobar"}) + assert list(found) == [user] + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_find_miss(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + session = Session.return_value + users = Users(config, session) + user = {"username": "foo"} + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [user], True + assert list(users.find({"username": "bar"})) == [] + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_find_one(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + session = Session.return_value + users = Users(config, session) + user = {"username": "foobar"} + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [user], True + assert users.find_one({"username": "foobar"}) == user + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_find_one_miss(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + session = Session.return_value + users = Users(config, session) + user = {"username": "foo"} + with patch("posit.connect.users.get_users") as get_users: + get_users.return_value = [user], True + assert users.find_one({"username": "bar"}) is None + + @patch("posit.connect.users.Session") + @patch("posit.connect.users.Config") + def test_get(self, Config: MagicMock, Session: MagicMock): + config = Config.return_value + config.endpoint = "http://foo.bar" + + user = {"guid": "foobar"} + response = MagicMock() + response.json = MagicMock() + response.json.return_value = user + session = Session.return_value + session.get = MagicMock() + session.get.return_value = response + + users = Users(config, session) + assert users.get("foobar") == user diff --git a/tinkering.py b/tinkering.py index c507c41d..74be435e 100644 --- a/tinkering.py +++ b/tinkering.py @@ -1,9 +1,12 @@ -from posit.connect import make_client +from posit.connect.client import create_client -client = make_client() -for user in client.users.find({"username": "aaron"}): - print(user) - -print(client.users.find_one()) - -print(client.users.find_one({"guid": "f155520a-ca2e-4084-b0a0-12120b7d1add"})) +with create_client() as client: + users = client.users + print( + users.find({"username": "taylor_steinberg"}).find( + {"username": "taylor_steinberg"} + ) + ) + print(users.find_one({"username": "taylor_steinberg"})) + print(users.get("f55ca95d-ce52-43ed-b31b-48dc4a07fe13")) + print(users.to_pandas_data_frame())