From 2f634a6975bcacc37753fb0b6633dc4c3f355ebe Mon Sep 17 00:00:00 2001 From: Yevhenii Vaskivskyi Date: Sat, 2 Dec 2023 23:15:41 +0100 Subject: [PATCH] Test and refactor `openvpn` module (#413) --- asusrouter/modules/openvpn.py | 40 +++++----- tests/modules/test_openvpn.py | 133 ++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 19 deletions(-) create mode 100644 tests/modules/test_openvpn.py diff --git a/asusrouter/modules/openvpn.py b/asusrouter/modules/openvpn.py index 60ca87a..feae0a7 100644 --- a/asusrouter/modules/openvpn.py +++ b/asusrouter/modules/openvpn.py @@ -4,10 +4,10 @@ import logging from enum import IntEnum -from typing import Any, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable from asusrouter.modules.firmware import Firmware -from asusrouter.modules.identity import AsusDevice +from asusrouter.tools.converters import get_arguments _LOGGER = logging.getLogger(__name__) @@ -43,22 +43,27 @@ class AsusOVPNServer(IntEnum): async def set_state( callback: Callable[..., Awaitable[bool]], state: AsusOVPNClient | AsusOVPNServer, - arguments: Optional[dict[str, Any]] = None, - expect_modify: bool = False, - identity: Optional[AsusDevice] = None, + **kwargs: Any, ) -> bool: """Set the OpenVPN state.""" - # Check if arguments are available - if not arguments: - arguments = {} + # Check if state is available + if not isinstance(state, (AsusOVPNClient, AsusOVPNServer)) or not state.value in ( + 0, + 1, + ): + _LOGGER.debug("No state found in arguments") + return False + + # Get the arguments + vpn_id, identity = get_arguments(("id", "identity"), **kwargs) - # Get the id from arguments - vpn_id = arguments.get("id") if not vpn_id: _LOGGER.debug("No VPN id found in arguments") return False + service_arguments = {"id": vpn_id} + service_map: dict[Any, str] # Get the correct service call @@ -83,12 +88,9 @@ async def set_state( AsusOVPNServer.OFF: "stop_openvpnd;restart_samba;restart_dnsmasq;", } service = service_map.get(state) if isinstance(state, AsusOVPNServer) else None - arguments = { - "VPNServer_enable": "1" if state == AsusOVPNServer.ON else "0", - } - - # Add `id` to arguments for proper state save - arguments["id"] = vpn_id + service_arguments["VPNServer_enable"] = ( + "1" if state == AsusOVPNServer.ON else "0" + ) if not service: _LOGGER.debug("Unknown state %s", state) @@ -97,13 +99,13 @@ async def set_state( _LOGGER.debug( "Triggering state set with parameters: service=%s, arguments=%s", service, - arguments, + service_arguments, ) # Run the service return await callback( service=service, - arguments=arguments, + arguments=service_arguments, apply=True, - expect_modify=expect_modify, + expect_modify=kwargs.get("expect_modify", False), ) diff --git a/tests/modules/test_openvpn.py b/tests/modules/test_openvpn.py new file mode 100644 index 0000000..66e26f7 --- /dev/null +++ b/tests/modules/test_openvpn.py @@ -0,0 +1,133 @@ +"""Tests for the openvpn module.""" + +from unittest.mock import AsyncMock + +import pytest + +from asusrouter.modules.firmware import Firmware +from asusrouter.modules.identity import AsusDevice +from asusrouter.modules.openvpn import AsusOVPNClient, AsusOVPNServer, set_state + +FW_MAJOR = "3.0.0.4" +FW_MINOR_OLD = 386 +FW_MINOR_NEW = 388 + +identity_mock = { + "merlin_new": AsusDevice(merlin=True, firmware=Firmware(FW_MAJOR, FW_MINOR_NEW, 0)), + "merlin_old": AsusDevice(merlin=True, firmware=Firmware(FW_MAJOR, FW_MINOR_OLD, 0)), + "stock_new": AsusDevice(merlin=False, firmware=Firmware(FW_MAJOR, FW_MINOR_NEW, 0)), + "stock_old": AsusDevice(merlin=False, firmware=Firmware(FW_MAJOR, FW_MINOR_OLD, 0)), +} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "state, vpn_id, identity, expect_modify, expect_call, expected_args, expected_service", + [ + # Correct states + (AsusOVPNClient.ON, 1, "merlin_new", True, True, {"id": 1}, "start_vpnclient1"), + ( + AsusOVPNClient.OFF, + 1, + "merlin_new", + False, + True, + {"id": 1}, + "stop_vpnclient1", + ), + (AsusOVPNServer.ON, 1, "merlin_new", True, True, {"id": 1}, "start_vpnserver1"), + ( + AsusOVPNServer.OFF, + 1, + "merlin_new", + False, + True, + {"id": 1}, + "stop_vpnserver1", + ), + # Different identity - merlin old + (AsusOVPNClient.ON, 1, "merlin_old", True, True, {"id": 1}, "start_vpnclient1"), + ( + AsusOVPNClient.OFF, + 1, + "merlin_old", + False, + True, + {"id": 1}, + "stop_vpnclient1", + ), + # Different identity - stock old + (AsusOVPNClient.ON, 1, "stock_old", True, True, {"id": 1}, "start_vpnclient1"), + ( + AsusOVPNClient.OFF, + 1, + "stock_old", + False, + True, + {"id": 1}, + "stop_vpnclient1", + ), + # Get to modern API - stock new + ( + AsusOVPNServer.ON, + 1, + "stock_new", + True, + True, + {"id": 1, "VPNServer_enable": "1"}, + "restart_openvpnd;restart_chpass;restart_samba;restart_dnsmasq;", + ), + ( + AsusOVPNServer.OFF, + 1, + "stock_new", + False, + True, + {"id": 1, "VPNServer_enable": "0"}, + "stop_openvpnd;restart_samba;restart_dnsmasq;", + ), + # Modern API with unknown state - this should not happen but is handled + ( + AsusOVPNClient.ON, + 1, + "stock_new", + True, + False, + {}, + None, + ), + # Wrong states + (AsusOVPNClient.UNKNOWN, 1, "merlin_new", False, False, {}, None), + (None, 1, "merlin_new", False, False, {}, None), + # No ID + (AsusOVPNClient.ON, None, "merlin_new", True, False, {}, None), + # No identity - should use legacy service + (AsusOVPNClient.ON, 1, None, True, True, {"id": 1}, "start_vpnclient1"), + ], +) +async def test_set_state( + state, vpn_id, identity, expect_modify, expect_call, expected_args, expected_service +): + """Test set_state.""" + + # Create a mock callback function + callback = AsyncMock() + + # Compile the kwargs + kwargs = {"id": vpn_id, "identity": identity_mock[identity] if identity else None} + + # Call the set_state function + await set_state( + callback=callback, state=state, expect_modify=expect_modify, **kwargs + ) + + # Check if the callback function was called + if expect_call: + callback.assert_called_once_with( + service=expected_service, + arguments=expected_args, + apply=True, + expect_modify=expect_modify, + ) + else: + callback.assert_not_called()