From 89bb161ed74aa83690f34de45fe6dfd3ac064a42 Mon Sep 17 00:00:00 2001 From: Yevhenii Vaskivskyi Date: Sun, 19 Nov 2023 13:57:02 +0100 Subject: [PATCH] Add some tests and code improvements (#377) --- .coveragerc | 5 + asusrouter/modules/endpoint/__init__.py | 3 +- asusrouter/modules/endpoint/command.py | 2 +- asusrouter/modules/endpoint/devicemap.py | 97 ++++--- requirements_test.txt | 1 + tests/__init__.py | 2 +- tests/modules/__init__.py | 1 + tests/modules/endpoint/__init__.py | 1 + tests/modules/endpoint/test_command.py | 27 ++ tests/modules/endpoint/test_devicemap.py | 345 +++++++++++++++++++++++ tests/modules/endpoint/test_endpoint.py | 183 ++++++++++++ tests/tools/test_readers.py | 9 +- 12 files changed, 630 insertions(+), 46 deletions(-) create mode 100644 .coveragerc create mode 100644 tests/modules/__init__.py create mode 100644 tests/modules/endpoint/__init__.py create mode 100644 tests/modules/endpoint/test_command.py create mode 100644 tests/modules/endpoint/test_devicemap.py create mode 100644 tests/modules/endpoint/test_endpoint.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..69b9a42 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +source = asusrouter +omit = + # Constants modules + asusrouter/modules/endpoint/devicemap_const.py diff --git a/asusrouter/modules/endpoint/__init__.py b/asusrouter/modules/endpoint/__init__.py index fa4440b..1693e94 100644 --- a/asusrouter/modules/endpoint/__init__.py +++ b/asusrouter/modules/endpoint/__init__.py @@ -11,7 +11,6 @@ from asusrouter.error import AsusRouter404Error from asusrouter.modules.data import AsusData, AsusDataState from asusrouter.modules.firmware import Firmware -from asusrouter.modules.flags import Flag from asusrouter.modules.wlan import Wlan _LOGGER = logging.getLogger(__name__) @@ -126,7 +125,7 @@ def data_set(data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: return data -def data_get(data: dict[str, Any], key: str) -> Any: +def data_get(data: dict[str, Any], key: str) -> Optional[Any]: """Extract value from the data dict and update the data dict.""" # Get the value diff --git a/asusrouter/modules/endpoint/command.py b/asusrouter/modules/endpoint/command.py index 8013fb9..9e2dd06 100644 --- a/asusrouter/modules/endpoint/command.py +++ b/asusrouter/modules/endpoint/command.py @@ -7,7 +7,7 @@ from asusrouter.tools.readers import read_json_content -def read(content: str) -> dict[str, Any]: # pylint: disable=unused-argument +def read(content: str) -> dict[str, Any]: """Read state data""" # Read the json content diff --git a/asusrouter/modules/endpoint/devicemap.py b/asusrouter/modules/endpoint/devicemap.py index 2bc6ee4..7e1a35a 100644 --- a/asusrouter/modules/endpoint/devicemap.py +++ b/asusrouter/modules/endpoint/devicemap.py @@ -28,8 +28,11 @@ def read(content: str) -> dict[str, Any]: devicemap: dict[str, Any] = {} # Parse the XML data - xml_content: dict[str, Any] = xmltodict.parse(content).get("devicemap", {}) - if not xml_content: + try: + xml_content: dict[str, Any] = xmltodict.parse(content).get("devicemap", {}) + if not xml_content: + return devicemap + except xmltodict.expat.ExpatError: # type: ignore return devicemap # Go through the data and fill the dict @@ -42,6 +45,8 @@ def read(content: str) -> dict[str, Any]: # Clear values from useless symbols for output_group, clear_map in DEVICEMAP_CLEAR.items(): + if output_group not in devicemap: + continue for key, clear_value in clear_map.items(): # If the key is not in the devicemap, continue if key not in devicemap[output_group]: @@ -61,22 +66,30 @@ def read(content: str) -> dict[str, Any]: return devicemap -# This method performs reading of the devicemap by index -# to simplify the original read_devicemap method def read_index(xml_content: dict[str, Any]) -> dict[str, Any]: - """Read devicemap by index.""" + """Read devicemap by index. + + This method performs reading of the devicemap by index + to simplify the original read_devicemap method.""" # Create a dict to store the data devicemap: dict[str, Any] = {} # Get values for which we only know their order (index) for output_group, input_group, input_values in DEVICEMAP_BY_INDEX: - # Create a dict to store the data - output_group_data: dict[str, Any] = {} + # Create an empty dictionary for the output group + devicemap[output_group] = {} - # Go through the input values and fill the dict - for index, input_value in enumerate(input_values): - output_group_data[input_value] = xml_content[input_group][index] + # Check that the input group is in the xml content + if input_group not in xml_content: + continue + + # Use dict comprehension to build output_group_data + output_group_data = { + input_value: xml_content[input_group][index] + for index, input_value in enumerate(input_values) + if index < len(xml_content[input_group]) + } # Add the output group data to the devicemap devicemap[output_group] = output_group_data @@ -85,10 +98,11 @@ def read_index(xml_content: dict[str, Any]) -> dict[str, Any]: return devicemap -# This method performs reading of the devicemap by key -# to simplify the original read_devicemap method def read_key(xml_content: dict[str, Any]) -> dict[str, Any]: - """Read devicemap by key.""" + """Read devicemap by key. + + This method performs reading of the devicemap by key + to simplify the original read_devicemap method.""" # Create a dict to store the data devicemap: dict[str, Any] = {} @@ -100,27 +114,25 @@ def read_key(xml_content: dict[str, Any]) -> dict[str, Any]: # Go through the input values and fill the dict for input_value in input_values: - # Check if the input group is a string - if isinstance(xml_content.get(input_group), str): - # Check if the input value is in the input group - if input_value in xml_content[input_group]: - # Add the input value to the output group data and remove the key - output_group_data[input_value] = xml_content[input_group].replace( + # Get the input group data + xml_input_group = xml_content.get(input_group) + + # If the input group data is None, skip this iteration + if xml_input_group is None: + continue + + # If the input group data is a string, convert it to a list + if isinstance(xml_input_group, str): + xml_input_group = [xml_input_group] + + # Go through the input group data and check if the input value is in it + for value in xml_input_group: + if input_value in value: + # Add the input value to the output group data + output_group_data[input_value] = value.replace( f"{input_value}=", "" ) - # Check if the input group is a list - else: - # Go through the input group and check if the input value is in it - xml_input_group = xml_content.get(input_group) - if not xml_input_group: - continue - for value in xml_input_group: - if input_value in value: - # Add the input value to the output group data - output_group_data[input_value] = value.replace( - f"{input_value}=", "" - ) - break + break # Add the output group data to the devicemap devicemap[output_group] = output_group_data @@ -142,17 +154,24 @@ def read_special(xml_content: dict[str, Any]) -> dict[str, Any]: def read_uptime_string(content: str) -> datetime | None: """Read uptime string and return proper datetime object.""" + # Split the content into the date/time part and the seconds part + uptime_parts = content.split("(") + if len(uptime_parts) < 2: + return None + + # Extract the number of seconds from the seconds part + seconds_match = re.search("([0-9]+)", uptime_parts[1]) + if not seconds_match: + return None + try: - part = content.split("(") - match = re.search("([0-9]+)", part[1]) - if not match: - return None - seconds = int(match.group()) - when = dtparse(part[0]) - uptime = when - timedelta(seconds=seconds) + seconds = int(seconds_match.group()) + when = dtparse(uptime_parts[0]) except ValueError: return None + uptime = when - timedelta(seconds=seconds) + return uptime diff --git a/requirements_test.txt b/requirements_test.txt index c00e319..032a155 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,3 +1,4 @@ # Pytest for running tests pytest>=7.4.3 +pytest-asyncio>=0.21.1 pytest-cov>=4.1.0 diff --git a/tests/__init__.py b/tests/__init__.py index a8991d4..2d555c0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -"""Tests for AsusRouter""" +"""Tests for AsusRouter.""" diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py new file mode 100644 index 0000000..f6f3f2d --- /dev/null +++ b/tests/modules/__init__.py @@ -0,0 +1 @@ +"""Tests for the AsusRouter modules.""" diff --git a/tests/modules/endpoint/__init__.py b/tests/modules/endpoint/__init__.py new file mode 100644 index 0000000..2ed07c9 --- /dev/null +++ b/tests/modules/endpoint/__init__.py @@ -0,0 +1 @@ +"""Tests for AsusRouter endpoint module.""" diff --git a/tests/modules/endpoint/test_command.py b/tests/modules/endpoint/test_command.py new file mode 100644 index 0000000..704112d --- /dev/null +++ b/tests/modules/endpoint/test_command.py @@ -0,0 +1,27 @@ +"""Test AsusRouter command endpoint module.""" + +from unittest.mock import patch + +from asusrouter.modules.endpoint import command + + +def test_read(): + """Test read function.""" + + # Test data + content = '{"key1": "value1", "key2": "value2"}' + expected_command = {"key1": "value1", "key2": "value2"} + + # Mock the read_json_content function + with patch( + "asusrouter.modules.endpoint.command.read_json_content", + return_value=expected_command, + ) as mock_read_json_content: + # Call the function + result = command.read(content) + + # Check the result + assert result == expected_command + + # Check that read_json_content was called with the correct argument + mock_read_json_content.assert_called_once_with(content) diff --git a/tests/modules/endpoint/test_devicemap.py b/tests/modules/endpoint/test_devicemap.py new file mode 100644 index 0000000..fa2af8b --- /dev/null +++ b/tests/modules/endpoint/test_devicemap.py @@ -0,0 +1,345 @@ +"""Test AsusRouter devicemap endpoint module.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import ANY, MagicMock, patch + +import pytest + +from asusrouter.modules.data import AsusData, AsusDataState +from asusrouter.modules.endpoint import devicemap + + +def generate_xml_content(groups): + """Generate XML content based on input parameters.""" + content = "\n" + for group, keys in groups.items(): + content += f" <{group}>\n" + for key, value in keys.items(): + content += f" <{key}>{value}\n" + content += f" \n" + content += "\n" + return content + + +@pytest.mark.parametrize( + "content", + [ + "", # Empty devicemap + "non-xml", # Invalid XML + ], +) +def test_read_invalid(content): + """Test read function with empty devicemap.""" + + assert devicemap.read(content) == {} + + +@pytest.fixture +def common_group(): + """Return the common test data intermediate.""" + + return { + "group1": {"key1": "value1"}, + "group2": {"key2": "value2"}, + "group3": {"key3": "value3_test"}, + } + + +@pytest.fixture +def common_test_data_result(): + """Return the common test data result.""" + + return { + "group1": {"key1": "value1"}, + "group2": {"key2": "value2"}, + "group3": {"key3": "value3"}, + } + + +@pytest.fixture +def mock_functions(): + """Return the mock functions.""" + + with patch( + "asusrouter.modules.endpoint.devicemap.read_index", + return_value={"group1": {"key1": "value1"}, "group3": {"key3": "value3_test"}}, + ) as mock_read_index, patch( + "asusrouter.modules.endpoint.devicemap.read_key", + return_value={"group2": {"key2": "value2"}}, + ) as mock_read_key, patch( + "asusrouter.modules.endpoint.devicemap.merge_dicts", + side_effect=lambda x, y: {**x, **y}, + ) as mock_merge_dicts, patch( + "asusrouter.modules.endpoint.devicemap.clean_dict", side_effect=lambda x: x + ) as mock_clean_dict, patch( + "asusrouter.modules.endpoint.devicemap.clean_dict_key_prefix", + side_effect=lambda x, _: x, + ) as mock_clean_dict_key_prefix, patch( + "asusrouter.modules.endpoint.devicemap.DEVICEMAP_CLEAR", + new={"group3": {"key3": "_test", "key4": "_test"}, "group4": {"key5": "_test"}}, + ) as mock_devicemap_clear: + yield { + "read_index": mock_read_index, + "read_key": mock_read_key, + "merge_dicts": mock_merge_dicts, + "clean_dict": mock_clean_dict, + "clean_dict_key_prefix": mock_clean_dict_key_prefix, + "devicemap_clear": mock_devicemap_clear, + } + + +def test_read_with_data( + mock_functions, # pylint: disable=redefined-outer-name + common_test_data_result, # pylint: disable=redefined-outer-name + common_group, # pylint: disable=redefined-outer-name +): + """Test read function.""" + + # Test data + content = generate_xml_content(common_group) + expected_devicemap = common_test_data_result + + # Call the function + result = devicemap.read(content) + + # Check the result + assert result == expected_devicemap + + # Check the calls to the mocked functions + mock_functions["read_index"].assert_called_with(common_group) + mock_functions["read_key"].assert_called_with(common_group) + assert mock_functions["merge_dicts"].call_count == 2 + mock_functions["clean_dict"].assert_called_with(expected_devicemap) + assert mock_functions["clean_dict_key_prefix"].call_count == 3 + + +@pytest.fixture +def const_devicemap(): + """Return the const devicemap.""" + + return [ + ("output_group1", "group1", ["input_value1"]), + ("output_group2", "group2", ["input_value3", "input_value4"]), + ("output_group3", "group3", ["input_value5"]), + ("output_group4", "group4", ["input_value6"]), + ("output_group5", "group5", ["input_value2"]), + ("output_group6", "group6", ["input_value7"]), + ] + + +@pytest.fixture +def const_devicemap_result(): + """Return the const devicemap result.""" + + return { + "output_group1": {"input_value1": "value1"}, + "output_group2": {"input_value3": "value3", "input_value4": "value4"}, + "output_group3": {"input_value5": "value5"}, + "output_group4": {"input_value6": "value6"}, + "output_group5": {}, + "output_group6": {}, + } + + +@pytest.fixture +def input_data(): + """Return the input data for the tests.""" + return { + "group1": ["value1"], + "group2": ["value3", "value4"], + "group3": ["value5"], + "group4": ["value6"], + # "group5": [], + "group6": [], + } + + +@pytest.fixture +def input_data_key(): + """Return the input data for the read_key test.""" + return { + "group1": ["input_value1=value1"], + "group2": ["input_value3=value3", "input_value4=value4"], + "group3": "input_value5=value5", + "group4": "input_value6=value6", + "group5": None, + "group6": [], + } + + +def test_read_index( + const_devicemap, # pylint: disable=redefined-outer-name + const_devicemap_result, # pylint: disable=redefined-outer-name + input_data, # pylint: disable=redefined-outer-name +): + """Test read_index function.""" + with patch.object(devicemap, "DEVICEMAP_BY_INDEX", new=const_devicemap): + # Call the function + result = devicemap.read_index(input_data) + + # Check the result + assert result == const_devicemap_result + + +def test_read_key( + const_devicemap, # pylint: disable=redefined-outer-name + const_devicemap_result, # pylint: disable=redefined-outer-name + input_data_key, # pylint: disable=redefined-outer-name +): + """Test read_key function.""" + with patch.object(devicemap, "DEVICEMAP_BY_KEY", new=const_devicemap): + # Call the function + result = devicemap.read_key(input_data_key) + + # Check the result + assert result == const_devicemap_result + + +def test_read_special( + input_data, # pylint: disable=redefined-outer-name +): + """Test read_special function.""" + + result = devicemap.read_special(input_data) + assert result == {} # pylint: disable=C1803 + + +@pytest.mark.parametrize( + "content, result", + [ + # Test with a valid content string + ( + "Thu, 16 Nov 2023 07:17:45 +0100(219355 secs since boot)", + datetime(2023, 11, 16, 7, 17, 45, tzinfo=timezone(timedelta(hours=1))) + - timedelta(seconds=219355), + ), + # Test with an invalid content string (no seconds) + ("Thu, 16 Nov 2023 07:17:45 +0100(no secs since boot)", None), + # Test with an invalid content string (bad format) + ("bad format", None), + # Test with a content string that has an invalid date + ("Not a date (219355 secs since boot)", None), + # Test with a content string that has an invalid number of seconds + ("Thu, 16 Nov 2023 07:17:45 +0100(not a number secs since boot)", None), + ], +) +def test_read_uptime_string(content, result): + """Test read_uptime_string function.""" + + assert devicemap.read_uptime_string(content) == result + + +@pytest.mark.parametrize( + "boottime_return, expected_flags", + [ + (("boottime", False), {}), + (("boottime", True), {"reboot": True}), + ], +) +@patch("asusrouter.modules.endpoint.devicemap.process_boottime") +@patch("asusrouter.modules.endpoint.devicemap.process_ovpn") +def test_process( + mock_process_ovpn, + mock_process_boottime, + boottime_return, + expected_flags, +): + """Test process function.""" + + # Prepare the mock functions + mock_process_boottime.return_value = boottime_return + mock_process_ovpn.return_value = "openvpn" + + # Prepare the test data + data = {"history": {AsusData.BOOTTIME: AsusDataState(data="prev_boottime")}} + + # Call the function with the test data + result = devicemap.process(data) + + # Check the result + assert result == { + AsusData.DEVICEMAP: data, + AsusData.BOOTTIME: boottime_return[0], + AsusData.OPENVPN: "openvpn", + AsusData.FLAGS: expected_flags, + } + + # Check that the mock functions were called with the correct arguments + mock_process_boottime.assert_called_once_with(data, "prev_boottime") + mock_process_ovpn.assert_called_once_with(data) + + +@pytest.mark.parametrize( + "prev_boottime_delta, expected_result", + [ + (timedelta(seconds=1), ({"datetime": ANY}, False)), + (timedelta(seconds=3), ({"datetime": ANY}, True)), + ], +) +@patch("asusrouter.modules.endpoint.devicemap.read_uptime_string") +def test_process_boottime( + mock_read_uptime_string, prev_boottime_delta, expected_result +): + """Test process_boottime function.""" + + # Prepare the mock function + mock_read_uptime_string.return_value = datetime.now() + + # Prepare the test data + devicemap_data = {"sys": {"uptimeStr": "uptime string"}} + prev_boottime = {"datetime": datetime.now() - prev_boottime_delta} + + # Call the function with the test data + result = devicemap.process_boottime(devicemap_data, prev_boottime) + + # Check the result + assert result == expected_result + + # Check that the mock function was called with the correct argument + mock_read_uptime_string.assert_called_once_with("uptime string") + + +@patch("asusrouter.modules.endpoint.devicemap.AsusOVPNClient") +@patch("asusrouter.modules.endpoint.devicemap.AsusOVPNServer") +@patch("asusrouter.modules.endpoint.devicemap.safe_int") +def test_process_ovpn(mock_safe_int, mock_asusovpnserver, mock_asusovnclient): + """Test process_ovpn function.""" + + # Prepare the mock functions + mock_asusovnclient.return_value = MagicMock() + mock_asusovpnserver.return_value = MagicMock() + mock_safe_int.return_value = 0 + + # Prepare the test data + devicemap_data = { + "vpn": { + "client1_state": "state", + "client1_errno": "errno", + "server1_state": "state", + } + } + + # Call the function with the test data + result = devicemap.process_ovpn(devicemap_data) + + # Check the result + expected_result = { + "client": { + 1: { + "state": mock_asusovnclient.return_value, + "errno": 0, + } + }, + "server": { + 1: { + "state": mock_asusovpnserver.return_value, + } + }, + } + assert result == expected_result + + # Check that the mock functions were called with the correct arguments + mock_asusovnclient.assert_called_once_with(0) + mock_asusovpnserver.assert_called_once_with(0) + mock_safe_int.assert_any_call("state", default=0) + mock_safe_int.assert_any_call("errno") diff --git a/tests/modules/endpoint/test_endpoint.py b/tests/modules/endpoint/test_endpoint.py new file mode 100644 index 0000000..91affef --- /dev/null +++ b/tests/modules/endpoint/test_endpoint.py @@ -0,0 +1,183 @@ +"""Test for the main endpoint module.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from asusrouter.error import AsusRouter404Error +from asusrouter.modules.endpoint import ( + Endpoint, + _get_module, + check_available, + data_get, + data_set, + process, + read, +) + + +def test_get_module(): + """Test _get_module method.""" + + # Test valid endpoint + with patch("importlib.import_module") as mock_import: + mock_import.return_value = "mocked_module" + result = _get_module(Endpoint.RGB) + assert result == "mocked_module" + mock_import.assert_called_once_with("asusrouter.modules.endpoint.rgb") + + # Test invalid endpoint + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ModuleNotFoundError + result = _get_module(Endpoint.FIRMWARE) + assert result is None + mock_import.assert_called_once_with("asusrouter.modules.endpoint.firmware") + + +def test_read(): + """Test read method.""" + + # Mock the module and its read method + mock_module = MagicMock() + mock_module.read.return_value = {"mocked": "data"} + + # Test valid endpoint + with patch( + "asusrouter.modules.endpoint._get_module", return_value=mock_module + ) as mock_get_module: + result = read(Endpoint.FIRMWARE, "content") + assert result == {"mocked": "data"} + mock_get_module.assert_called_once_with(Endpoint.FIRMWARE) + mock_module.read.assert_called_once_with("content") + + # Test invalid endpoint + with patch( + "asusrouter.modules.endpoint._get_module", return_value=None + ) as mock_get_module: + result = read(Endpoint.RGB, "content") + assert result == {} + mock_get_module.assert_called_once_with(Endpoint.RGB) + + +@pytest.mark.parametrize( + "require_history,require_firmware,require_wlan,call_count", + [ + (True, False, False, 1), + (False, True, False, 1), + (False, False, True, 1), + (False, False, False, 0), + (True, True, True, 3), + ], +) +def test_process(require_history, require_firmware, require_wlan, call_count): + """Test process method.""" + + # Mock the module and its process method + mock_module = MagicMock() + mock_module.process.return_value = {"mocked": "data"} + + # Mock the data_set function + mock_data_set = MagicMock() + + # Define a side effect function for getattr + def getattr_side_effect(_, attr, default=None): + if attr == "REQUIRE_HISTORY": + return require_history + if attr == "REQUIRE_FIRMWARE": + return require_firmware + if attr == "REQUIRE_WLAN": + return require_wlan + return default + + # Test valid endpoint + with patch( + "asusrouter.modules.endpoint._get_module", return_value=mock_module + ), patch("asusrouter.modules.endpoint.data_set", mock_data_set), patch( + "asusrouter.modules.endpoint.getattr", side_effect=getattr_side_effect + ): + result = process(Endpoint.DEVICEMAP, {"key": "value"}) + assert result == {"mocked": "data"} + mock_module.process.assert_called_once_with({"key": "value"}) + assert mock_data_set.call_count == call_count + + +def test_process_no_module(): + """Test process method when no module is found.""" + + # Mock the _get_module function to return None + with patch( + "asusrouter.modules.endpoint._get_module", return_value=None + ) as mock_get_module: + result = process(Endpoint.RGB, {"key": "value"}) + assert result == {} + mock_get_module.assert_called_once_with(Endpoint.RGB) + + +def test_data_set(): + """Test data_set function.""" + + # Test data + data = {"key1": "value1"} + kwargs = {"key2": "value2", "key3": "value3"} + + # Call the function + result = data_set(data, **kwargs) + + # Check the result + assert result == {"key1": "value1", "key2": "value2", "key3": "value3"} + + +@pytest.mark.parametrize( + "data, key, expected, data_left", + [ + # Key exists + ({"key1": "value1", "key2": "value2"}, "key1", "value1", {"key2": "value2"}), + # Key does not exist + ( + {"key1": "value1", "key2": "value2"}, + "key3", + None, + {"key1": "value1", "key2": "value2"}, + ), + # Empty data + ({}, "key1", None, {}), + ], +) +def test_data_get(data, key, expected, data_left): + """Test data_get function.""" + + # Call the function + result = data_get(data, key) + + # Check the result + assert result == expected + assert data == data_left + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "api_query_return, expected_result", + [ + # Test case: status 200 + ((200, None, None), True), + # Test case: status not 200 + ((403, None, None), False), + # Test case: AsusRouter404Error is raised + (AsusRouter404Error(), False), + ], +) +async def test_check_available(api_query_return, expected_result): + """Test check_available function.""" + + # Mock the api_query function + api_query = AsyncMock() + if isinstance(api_query_return, Exception): + api_query.side_effect = api_query_return + else: + api_query.return_value = api_query_return + + # Call the function + result = await check_available(Endpoint.DEVICEMAP, api_query) + + # Check the result + assert result == expected_result diff --git a/tests/tools/test_readers.py b/tests/tools/test_readers.py index 7062a07..39eb4c9 100644 --- a/tests/tools/test_readers.py +++ b/tests/tools/test_readers.py @@ -24,6 +24,7 @@ ) def test_merge_dicts(dict1, dict2, expected): """Test merge_dicts method.""" + assert readers.merge_dicts(dict1, dict2) == expected @@ -141,11 +142,12 @@ def test_read_js_variables(content, expected): ) def test_read_json_content(content, expected): """Test read_json_content method.""" + assert readers.read_json_content(content) == expected @pytest.mark.parametrize( - "input, expected", + "content, expected", [ # Test valid MAC addresses ("01:23:45:67:89:AB", True), @@ -160,6 +162,7 @@ def test_read_json_content(content, expected): (None, False), ], ) -def test_readable_mac(input, expected): +def test_readable_mac(content, expected): """Test readable_mac method.""" - assert readers.readable_mac(input) == expected + + assert readers.readable_mac(content) == expected