Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test and refactor port_forwarding module #412

Merged
merged 1 commit into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions asusrouter/modules/port_forwarding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Port forwarding module."""

import logging
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Awaitable, Callable, Optional
from typing import Any, Awaitable, Callable

_LOGGER = logging.getLogger(__name__)

KEY_PORT_FORWARDING_LIST = "vts_rulelist"
KEY_PORT_FORWARDING_STATE = "vts_enable_x"
Expand Down Expand Up @@ -31,22 +34,24 @@ class AsusPortForwarding(IntEnum):
async def set_state(
callback: Callable[..., Awaitable[bool]],
state: AsusPortForwarding,
arguments: Optional[dict[str, Any]] = None,
expect_modify: bool = False,
_: Optional[dict[Any, Any]] = None,
**kwargs: Any,
) -> bool:
"""Set the parental control state."""

# Check if arguments are available
if not arguments:
arguments = {}
# Check if state is available and valid
if not isinstance(state, AsusPortForwarding) or not state.value in (0, 1):
_LOGGER.debug("No state found in arguments")
return False

arguments[KEY_PORT_FORWARDING_STATE] = 1 if state == AsusPortForwarding.ON else 0
arguments = {KEY_PORT_FORWARDING_STATE: 1 if state == AsusPortForwarding.ON else 0}

# Get the correct service call
service = "restart_firewall"

# Call the service
return await callback(
service, arguments=arguments, apply=True, expect_modify=expect_modify
service=service,
arguments=arguments,
apply=True,
expect_modify=kwargs.get("expect_modify", False),
)
46 changes: 46 additions & 0 deletions tests/modules/test_port_forwarding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Tests for the port forwarding module."""

from unittest.mock import AsyncMock

import pytest

from asusrouter.modules.port_forwarding import (
KEY_PORT_FORWARDING_STATE,
AsusPortForwarding,
set_state,
)

async_callback = AsyncMock()


@pytest.mark.asyncio
@pytest.mark.parametrize(
"state, expect_modify, expect_call, expected_args",
[
# Correct states
(AsusPortForwarding.ON, True, True, {KEY_PORT_FORWARDING_STATE: 1}),
(AsusPortForwarding.OFF, False, True, {KEY_PORT_FORWARDING_STATE: 0}),
# Wrong states
(AsusPortForwarding.UNKNOWN, False, False, {}),
(None, False, False, {}),
],
)
async def test_set_state(state, expect_modify, expect_call, expected_args):
"""Test set_state."""

# Call the set_state function
await set_state(callback=async_callback, state=state, expect_modify=expect_modify)

# Check if the callback function was called
if expect_call:
async_callback.assert_called_once_with(
service="restart_firewall",
arguments=expected_args,
apply=True,
expect_modify=expect_modify,
)
else:
async_callback.assert_not_called()

# Reset the mock callback function
async_callback.reset_mock()