diff --git a/asusrouter/asusrouter.py b/asusrouter/asusrouter.py index fe21b27..6489dd8 100644 --- a/asusrouter/asusrouter.py +++ b/asusrouter/asusrouter.py @@ -706,12 +706,12 @@ async def async_set_state( await self._async_check_state_dependency(state) result = await set_state( - self.async_run_service, - state, - arguments, - expect_modify, - self._state, - self._identity, + callback=self.async_run_service, + state=state, + arguments=arguments, + expect_modify=expect_modify, + router_state=self._state, + identity=self._identity, ) if result is True: diff --git a/asusrouter/modules/state.py b/asusrouter/modules/state.py index 180329d..f1881a5 100644 --- a/asusrouter/modules/state.py +++ b/asusrouter/modules/state.py @@ -6,17 +6,17 @@ import logging from enum import Enum from types import ModuleType -from typing import Any, Awaitable, Callable, Optional, Tuple +from typing import Any, Awaitable, Callable, Optional from asusrouter.modules.connection import ConnectionState from asusrouter.modules.data import AsusData, AsusDataState -from asusrouter.modules.identity import AsusDevice from asusrouter.modules.parental_control import AsusParentalControl from asusrouter.modules.port_forwarding import AsusPortForwarding from asusrouter.modules.system import AsusSystem from asusrouter.modules.vpnc import AsusVPNC from asusrouter.modules.wireguard import AsusWireGuardClient, AsusWireGuardServer from asusrouter.modules.wlan import AsusWLAN +from asusrouter.tools.converters import get_enum_key_by_value from .led import AsusLED from .openvpn import AsusOVPNClient, AsusOVPNServer @@ -69,6 +69,10 @@ class AsusState(Enum): def add_conditional_state(state: AsusState, data: AsusData) -> None: """A callback to add / change AsusStateMap.""" + if not isinstance(state, AsusState) or not isinstance(data, AsusData): + _LOGGER.debug("Invalid state or data type: %s -> %s", state, data) + return + AsusStateMap[state] = data _LOGGER.debug("Added conditional state rule: %s -> %s", state, data) @@ -76,12 +80,9 @@ def add_conditional_state(state: AsusState, data: AsusData) -> None: def get_datatype(state: Optional[Any]) -> Optional[AsusData]: """Get the datatype.""" - asus_state = AsusState(type(state)) + asus_state = get_enum_key_by_value(AsusState, type(state), default=AsusState.NONE) - if state is not None: - return AsusStateMap.get(asus_state) - - return None + return AsusStateMap.get(asus_state) def _get_module_name(state: AsusState) -> Optional[str]: @@ -127,10 +128,7 @@ def _has_method(module: ModuleType, method: str) -> bool: async def set_state( callback: Callable[..., Awaitable[bool]], state: AsusState, - arguments: Optional[dict[str, Any]] = None, - expect_modify: bool = False, - router_state: Optional[dict[AsusData, AsusDataState]] = None, - identity: Optional[AsusDevice] = None, + **kwargs: Any, ) -> bool: """Set the state.""" @@ -140,19 +138,16 @@ async def set_state( # Process the data if module found if submodule and _has_method(submodule, "set_state"): # Determine the extra parameter - extra_param: Optional[dict[AsusData, AsusDataState] | AsusDevice] = None if getattr(submodule, "REQUIRE_STATE", False): - extra_param = router_state - elif getattr(submodule, "REQUIRE_IDENTITY", False): - extra_param = identity + kwargs["extra_param"] = kwargs.get("router_state") + if getattr(submodule, "REQUIRE_IDENTITY", False): + kwargs["extra_param"] = kwargs.get("identity") # Call the function with the determined parameters return await submodule.set_state( callback=callback, state=state, - arguments=arguments, - expect_modify=expect_modify, - extra_param=extra_param, + **kwargs, ) return False @@ -168,7 +163,7 @@ def save_state( # Get the correct data key datatype = get_datatype(state) - if datatype is None: + if datatype is None or datatype not in library: return # Save the state @@ -179,7 +174,7 @@ def save_state( async def keep_state( callback: Callable[..., Awaitable[Any]], states: Optional[AsusState | list[AsusState]], - identity: Optional[AsusDevice], + **kwargs: Any, ) -> None: """Keep the state.""" @@ -187,14 +182,15 @@ async def keep_state( return # Make sure the state is a list - if not isinstance(states, list): - states = [states] + states = [states] if not isinstance(states, list) else states # Process each state - for state in states: - # Get the module - submodule = _get_module(state) - - # Process the data if module found - if submodule and _has_method(submodule, "keep_state"): - await submodule.keep_state(callback, state, identity) + awaitables = [ + submodule.keep_state(callback, state, **kwargs) + for state in states + if (submodule := _get_module(state)) and _has_method(submodule, "keep_state") + ] + + # Execute all awaitables + for awaitable in awaitables: + await awaitable diff --git a/tests/modules/test_state.py b/tests/modules/test_state.py new file mode 100644 index 0000000..f56ec09 --- /dev/null +++ b/tests/modules/test_state.py @@ -0,0 +1,281 @@ +"""Tests for the state module.""" + +from typing import Any +from unittest import mock + +import pytest + +from asusrouter import AsusData +from asusrouter.modules.state import ( + AsusState, + _get_module, + _get_module_name, + _has_method, + add_conditional_state, + get_datatype, + keep_state, + save_state, + set_state, +) +from asusrouter.modules.system import AsusSystem +from asusrouter.modules.vpnc import AsusVPNC +from asusrouter.modules.wlan import AsusWLAN + +mock_state_map = { + AsusState.SYSTEM: AsusData.SYSTEM, + AsusState.VPNC: AsusData.VPNC, + AsusState.WLAN: AsusData.WLAN, + AsusState.NONE: None, +} + + +class MockModule: + """Mock module.""" + + def __init__(self, require_state=False, require_identity=False): + """Initialize the mock module.""" + + self.REQUIRE_STATE = require_state # pylint: disable=invalid-name + self.REQUIRE_IDENTITY = require_identity # pylint: disable=invalid-name + + self.update_state = mock.Mock() + self.offset_time = mock.Mock() + + async def set_state(self, **_: Any): + """Set the state.""" + + return True + + async def keep_state(self, *_: Any, **__: Any): + """Keep the state.""" + + return None + + +@pytest.mark.parametrize( + "state, data, success", + [ + # Existing values of AsusState + (AsusState.LED, AsusData.LED, True), + (AsusState.WIREGUARD_CLIENT, AsusData.WIREGUARD_CLIENT, True), + (AsusState.PARENTAL_CONTROL, AsusData.PARENTAL_CONTROL, True), + # Partial data + (AsusState.PORT_FORWARDING, None, False), + (None, AsusData.LED, False), + # None + (None, None, False), + # Wrong types + (1, AsusData.SYSTEM, False), + (AsusState.CONNECTION, 1, False), + ], +) +def test_add_conditional_state(state, data, success): + """Test add_conditional_state.""" + + # Try to add the state + with mock.patch("asusrouter.modules.state.AsusStateMap", mock_state_map): + add_conditional_state(state, data) + + # Check the result + if success: + assert mock_state_map[state] == data + else: + assert state not in mock_state_map + + +@pytest.mark.parametrize( + "state, expected", + [ + # Existing values of AsusState + (AsusSystem.REBOOT, AsusData.SYSTEM), + (AsusVPNC.ON, AsusData.VPNC), + (AsusWLAN.OFF, AsusData.WLAN), + # None + (None, None), + # Wrong types + (1, None), + ("string", None), + ], +) +def test_get_datatype(state, expected): + """Test get_datatype.""" + + # Try to get the datatype + with mock.patch("asusrouter.modules.state.AsusStateMap", mock_state_map): + result = get_datatype(state) + + # Check the result + assert result == expected + + +@pytest.mark.parametrize( + "state, expected", + [ + # Existing values of AsusState + (AsusState.SYSTEM, AsusData.SYSTEM.value), + (AsusState.VPNC, AsusData.VPNC.value), + (AsusState.WLAN, AsusData.WLAN.value), + # None + (None, None), + # Wrong types + (1, None), + ("string", None), + ], +) +def test_get_module_name(state, expected): + """Test _get_module_name.""" + + # Mock the get_datatype function + with mock.patch( + "asusrouter.modules.state.get_datatype", lambda s: mock_state_map.get(s, None) + ): + result = _get_module_name(state) + + # Check the result + assert result == expected + + +@pytest.mark.parametrize( + "state, module_name, expected_module, import_result, import_exception, expected", + [ + # Existing values of AsusState + (AsusState.SYSTEM, "system", "system", mock.MagicMock(), None, "mock_module"), + (AsusState.VPNC, "vpnc", "vpnc", mock.MagicMock(), None, "mock_module"), + (AsusState.WLAN, "wlan", "wlan", mock.MagicMock(), None, "mock_module"), + ( + AsusState.WIREGUARD_CLIENT, + "wireguard_client", + "wireguard", + mock.MagicMock(), + None, + "mock_module", + ), + # ModuleNotFoundError + ( + AsusState.SYSTEM, + "system", + "system", + None, + ModuleNotFoundError, + None, + ), + # None + (None, None, None, None, None, None), + # Wrong types + (1, None, None, None, None, None), + ("string", None, None, None, None, None), + ], +) +def test_get_module( + state, module_name, expected_module, import_result, import_exception, expected +): + """Test _get_module.""" + + # Mock the _get_module_name function + with mock.patch( + "asusrouter.modules.state._get_module_name", return_value=module_name + ): + # Mock the importlib.import_module function + with mock.patch("importlib.import_module") as mock_import: + mock_import.return_value = import_result + mock_import.side_effect = import_exception + result = _get_module(state) + if expected_module and import_result is not None: + mock_import.assert_called_once_with( + f"asusrouter.modules.{expected_module}" + ) + + # Check the result + assert (result is not None) == (expected is not None) + + +@pytest.mark.parametrize( + "module, method, expected", + [ + # Existing method + (MockModule(), "set_state", True), + # Non-existing method + (MockModule(), "non_existing_method", False), + ], +) +def test_has_method(module, method, expected): + """Test _has_method.""" + + result = _has_method(module, method) + + assert result == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "has_method, require_state, require_identity, expected", + [ + (True, False, False, True), + (True, True, False, True), + (True, False, True, True), + (False, False, False, False), + ], +) +async def test_set_state(has_method, require_state, require_identity, expected): + """Test set_state.""" + + # Mock the _get_module function + with mock.patch( + "asusrouter.modules.state._get_module", + return_value=MockModule(require_state, require_identity), + ): + # Mock the _has_method function + with mock.patch( + "asusrouter.modules.state._has_method", return_value=has_method + ): + result = await set_state(mock.AsyncMock(), AsusState.SYSTEM) + + assert result == expected + + +@pytest.mark.parametrize( + "state, datatype", + [ + (AsusState.VPNC, AsusData.VPNC), + (AsusState.VPNC, None), + ], +) +def test_save_state(state, datatype): + """Test save_state.""" + + # Mock the get_datatype function + with mock.patch("asusrouter.modules.state.get_datatype", return_value=datatype): + # Mock the AsusDataState objects + library = {AsusData.VPNC: MockModule()} + + # Call the function + save_state(state, library) # type: ignore + + # Check the calls to the AsusDataState methods + if datatype is not None: + library[datatype].update_state.assert_called_once_with(state, None) + library[datatype].offset_time.assert_called_once_with(None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "states, has_method, expected", + [ + ([AsusState.SYSTEM], True, None), + (AsusState.SYSTEM, True, None), # Single value, not a list + ([AsusState.SYSTEM], False, None), + (None, False, None), + ], +) +async def test_keep_state(states, has_method, expected): + """Test keep_state.""" + + # Mock the _get_module function + with mock.patch("asusrouter.modules.state._get_module", return_value=MockModule()): + # Mock the _has_method function + with mock.patch( + "asusrouter.modules.state._has_method", return_value=has_method + ): + result = await keep_state(mock.AsyncMock(), states) + + assert result == expected