Skip to content

Commit

Permalink
Test and refactor openvpn module (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaskivskyi authored Dec 2, 2023
1 parent 45cbe0e commit 2f634a6
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 19 deletions.
40 changes: 21 additions & 19 deletions asusrouter/modules/openvpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import logging
from enum import IntEnum
from typing import Any, Awaitable, Callable, Optional
from typing import Any, Awaitable, Callable

from asusrouter.modules.firmware import Firmware
from asusrouter.modules.identity import AsusDevice
from asusrouter.tools.converters import get_arguments

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,22 +43,27 @@ class AsusOVPNServer(IntEnum):
async def set_state(
callback: Callable[..., Awaitable[bool]],
state: AsusOVPNClient | AsusOVPNServer,
arguments: Optional[dict[str, Any]] = None,
expect_modify: bool = False,
identity: Optional[AsusDevice] = None,
**kwargs: Any,
) -> bool:
"""Set the OpenVPN state."""

# Check if arguments are available
if not arguments:
arguments = {}
# Check if state is available
if not isinstance(state, (AsusOVPNClient, AsusOVPNServer)) or not state.value in (
0,
1,
):
_LOGGER.debug("No state found in arguments")
return False

# Get the arguments
vpn_id, identity = get_arguments(("id", "identity"), **kwargs)

# Get the id from arguments
vpn_id = arguments.get("id")
if not vpn_id:
_LOGGER.debug("No VPN id found in arguments")
return False

service_arguments = {"id": vpn_id}

service_map: dict[Any, str]

# Get the correct service call
Expand All @@ -83,12 +88,9 @@ async def set_state(
AsusOVPNServer.OFF: "stop_openvpnd;restart_samba;restart_dnsmasq;",
}
service = service_map.get(state) if isinstance(state, AsusOVPNServer) else None
arguments = {
"VPNServer_enable": "1" if state == AsusOVPNServer.ON else "0",
}

# Add `id` to arguments for proper state save
arguments["id"] = vpn_id
service_arguments["VPNServer_enable"] = (
"1" if state == AsusOVPNServer.ON else "0"
)

if not service:
_LOGGER.debug("Unknown state %s", state)
Expand All @@ -97,13 +99,13 @@ async def set_state(
_LOGGER.debug(
"Triggering state set with parameters: service=%s, arguments=%s",
service,
arguments,
service_arguments,
)

# Run the service
return await callback(
service=service,
arguments=arguments,
arguments=service_arguments,
apply=True,
expect_modify=expect_modify,
expect_modify=kwargs.get("expect_modify", False),
)
133 changes: 133 additions & 0 deletions tests/modules/test_openvpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Tests for the openvpn module."""

from unittest.mock import AsyncMock

import pytest

from asusrouter.modules.firmware import Firmware
from asusrouter.modules.identity import AsusDevice
from asusrouter.modules.openvpn import AsusOVPNClient, AsusOVPNServer, set_state

FW_MAJOR = "3.0.0.4"
FW_MINOR_OLD = 386
FW_MINOR_NEW = 388

identity_mock = {
"merlin_new": AsusDevice(merlin=True, firmware=Firmware(FW_MAJOR, FW_MINOR_NEW, 0)),
"merlin_old": AsusDevice(merlin=True, firmware=Firmware(FW_MAJOR, FW_MINOR_OLD, 0)),
"stock_new": AsusDevice(merlin=False, firmware=Firmware(FW_MAJOR, FW_MINOR_NEW, 0)),
"stock_old": AsusDevice(merlin=False, firmware=Firmware(FW_MAJOR, FW_MINOR_OLD, 0)),
}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"state, vpn_id, identity, expect_modify, expect_call, expected_args, expected_service",
[
# Correct states
(AsusOVPNClient.ON, 1, "merlin_new", True, True, {"id": 1}, "start_vpnclient1"),
(
AsusOVPNClient.OFF,
1,
"merlin_new",
False,
True,
{"id": 1},
"stop_vpnclient1",
),
(AsusOVPNServer.ON, 1, "merlin_new", True, True, {"id": 1}, "start_vpnserver1"),
(
AsusOVPNServer.OFF,
1,
"merlin_new",
False,
True,
{"id": 1},
"stop_vpnserver1",
),
# Different identity - merlin old
(AsusOVPNClient.ON, 1, "merlin_old", True, True, {"id": 1}, "start_vpnclient1"),
(
AsusOVPNClient.OFF,
1,
"merlin_old",
False,
True,
{"id": 1},
"stop_vpnclient1",
),
# Different identity - stock old
(AsusOVPNClient.ON, 1, "stock_old", True, True, {"id": 1}, "start_vpnclient1"),
(
AsusOVPNClient.OFF,
1,
"stock_old",
False,
True,
{"id": 1},
"stop_vpnclient1",
),
# Get to modern API - stock new
(
AsusOVPNServer.ON,
1,
"stock_new",
True,
True,
{"id": 1, "VPNServer_enable": "1"},
"restart_openvpnd;restart_chpass;restart_samba;restart_dnsmasq;",
),
(
AsusOVPNServer.OFF,
1,
"stock_new",
False,
True,
{"id": 1, "VPNServer_enable": "0"},
"stop_openvpnd;restart_samba;restart_dnsmasq;",
),
# Modern API with unknown state - this should not happen but is handled
(
AsusOVPNClient.ON,
1,
"stock_new",
True,
False,
{},
None,
),
# Wrong states
(AsusOVPNClient.UNKNOWN, 1, "merlin_new", False, False, {}, None),
(None, 1, "merlin_new", False, False, {}, None),
# No ID
(AsusOVPNClient.ON, None, "merlin_new", True, False, {}, None),
# No identity - should use legacy service
(AsusOVPNClient.ON, 1, None, True, True, {"id": 1}, "start_vpnclient1"),
],
)
async def test_set_state(
state, vpn_id, identity, expect_modify, expect_call, expected_args, expected_service
):
"""Test set_state."""

# Create a mock callback function
callback = AsyncMock()

# Compile the kwargs
kwargs = {"id": vpn_id, "identity": identity_mock[identity] if identity else None}

# Call the set_state function
await set_state(
callback=callback, state=state, expect_modify=expect_modify, **kwargs
)

# Check if the callback function was called
if expect_call:
callback.assert_called_once_with(
service=expected_service,
arguments=expected_args,
apply=True,
expect_modify=expect_modify,
)
else:
callback.assert_not_called()

0 comments on commit 2f634a6

Please sign in to comment.