diff --git a/asusrouter/modules/port_forwarding.py b/asusrouter/modules/port_forwarding.py index cf77b69..3448dd6 100644 --- a/asusrouter/modules/port_forwarding.py +++ b/asusrouter/modules/port_forwarding.py @@ -1,8 +1,11 @@ """Port forwarding module.""" +import logging from dataclasses import dataclass from enum import IntEnum -from typing import Any, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable + +_LOGGER = logging.getLogger(__name__) KEY_PORT_FORWARDING_LIST = "vts_rulelist" KEY_PORT_FORWARDING_STATE = "vts_enable_x" @@ -31,22 +34,24 @@ class AsusPortForwarding(IntEnum): async def set_state( callback: Callable[..., Awaitable[bool]], state: AsusPortForwarding, - arguments: Optional[dict[str, Any]] = None, - expect_modify: bool = False, - _: Optional[dict[Any, Any]] = None, + **kwargs: Any, ) -> bool: """Set the parental control state.""" - # Check if arguments are available - if not arguments: - arguments = {} + # Check if state is available and valid + if not isinstance(state, AsusPortForwarding) or not state.value in (0, 1): + _LOGGER.debug("No state found in arguments") + return False - arguments[KEY_PORT_FORWARDING_STATE] = 1 if state == AsusPortForwarding.ON else 0 + arguments = {KEY_PORT_FORWARDING_STATE: 1 if state == AsusPortForwarding.ON else 0} # Get the correct service call service = "restart_firewall" # Call the service return await callback( - service, arguments=arguments, apply=True, expect_modify=expect_modify + service=service, + arguments=arguments, + apply=True, + expect_modify=kwargs.get("expect_modify", False), ) diff --git a/tests/modules/test_port_forwarding.py b/tests/modules/test_port_forwarding.py new file mode 100644 index 0000000..fc0c9f4 --- /dev/null +++ b/tests/modules/test_port_forwarding.py @@ -0,0 +1,46 @@ +"""Tests for the port forwarding module.""" + +from unittest.mock import AsyncMock + +import pytest + +from asusrouter.modules.port_forwarding import ( + KEY_PORT_FORWARDING_STATE, + AsusPortForwarding, + set_state, +) + +async_callback = AsyncMock() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "state, expect_modify, expect_call, expected_args", + [ + # Correct states + (AsusPortForwarding.ON, True, True, {KEY_PORT_FORWARDING_STATE: 1}), + (AsusPortForwarding.OFF, False, True, {KEY_PORT_FORWARDING_STATE: 0}), + # Wrong states + (AsusPortForwarding.UNKNOWN, False, False, {}), + (None, False, False, {}), + ], +) +async def test_set_state(state, expect_modify, expect_call, expected_args): + """Test set_state.""" + + # Call the set_state function + await set_state(callback=async_callback, state=state, expect_modify=expect_modify) + + # Check if the callback function was called + if expect_call: + async_callback.assert_called_once_with( + service="restart_firewall", + arguments=expected_args, + apply=True, + expect_modify=expect_modify, + ) + else: + async_callback.assert_not_called() + + # Reset the mock callback function + async_callback.reset_mock()