Skip to content

Commit

Permalink
Test and refactor state module (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaskivskyi authored Nov 26, 2023
1 parent c51501a commit 04063b1
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 35 deletions.
12 changes: 6 additions & 6 deletions asusrouter/asusrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,12 +706,12 @@ async def async_set_state(
await self._async_check_state_dependency(state)

result = await set_state(
self.async_run_service,
state,
arguments,
expect_modify,
self._state,
self._identity,
callback=self.async_run_service,
state=state,
arguments=arguments,
expect_modify=expect_modify,
router_state=self._state,
identity=self._identity,
)

if result is True:
Expand Down
54 changes: 25 additions & 29 deletions asusrouter/modules/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import logging
from enum import Enum
from types import ModuleType
from typing import Any, Awaitable, Callable, Optional, Tuple
from typing import Any, Awaitable, Callable, Optional

from asusrouter.modules.connection import ConnectionState
from asusrouter.modules.data import AsusData, AsusDataState
from asusrouter.modules.identity import AsusDevice
from asusrouter.modules.parental_control import AsusParentalControl
from asusrouter.modules.port_forwarding import AsusPortForwarding
from asusrouter.modules.system import AsusSystem
from asusrouter.modules.vpnc import AsusVPNC
from asusrouter.modules.wireguard import AsusWireGuardClient, AsusWireGuardServer
from asusrouter.modules.wlan import AsusWLAN
from asusrouter.tools.converters import get_enum_key_by_value

from .led import AsusLED
from .openvpn import AsusOVPNClient, AsusOVPNServer
Expand Down Expand Up @@ -69,19 +69,20 @@ class AsusState(Enum):
def add_conditional_state(state: AsusState, data: AsusData) -> None:
"""A callback to add / change AsusStateMap."""

if not isinstance(state, AsusState) or not isinstance(data, AsusData):
_LOGGER.debug("Invalid state or data type: %s -> %s", state, data)
return

AsusStateMap[state] = data
_LOGGER.debug("Added conditional state rule: %s -> %s", state, data)


def get_datatype(state: Optional[Any]) -> Optional[AsusData]:
"""Get the datatype."""

asus_state = AsusState(type(state))
asus_state = get_enum_key_by_value(AsusState, type(state), default=AsusState.NONE)

if state is not None:
return AsusStateMap.get(asus_state)

return None
return AsusStateMap.get(asus_state)


def _get_module_name(state: AsusState) -> Optional[str]:
Expand Down Expand Up @@ -127,10 +128,7 @@ def _has_method(module: ModuleType, method: str) -> bool:
async def set_state(
callback: Callable[..., Awaitable[bool]],
state: AsusState,
arguments: Optional[dict[str, Any]] = None,
expect_modify: bool = False,
router_state: Optional[dict[AsusData, AsusDataState]] = None,
identity: Optional[AsusDevice] = None,
**kwargs: Any,
) -> bool:
"""Set the state."""

Expand All @@ -140,19 +138,16 @@ async def set_state(
# Process the data if module found
if submodule and _has_method(submodule, "set_state"):
# Determine the extra parameter
extra_param: Optional[dict[AsusData, AsusDataState] | AsusDevice] = None
if getattr(submodule, "REQUIRE_STATE", False):
extra_param = router_state
elif getattr(submodule, "REQUIRE_IDENTITY", False):
extra_param = identity
kwargs["extra_param"] = kwargs.get("router_state")
if getattr(submodule, "REQUIRE_IDENTITY", False):
kwargs["extra_param"] = kwargs.get("identity")

# Call the function with the determined parameters
return await submodule.set_state(
callback=callback,
state=state,
arguments=arguments,
expect_modify=expect_modify,
extra_param=extra_param,
**kwargs,
)

return False
Expand All @@ -168,7 +163,7 @@ def save_state(

# Get the correct data key
datatype = get_datatype(state)
if datatype is None:
if datatype is None or datatype not in library:
return

# Save the state
Expand All @@ -179,22 +174,23 @@ def save_state(
async def keep_state(
callback: Callable[..., Awaitable[Any]],
states: Optional[AsusState | list[AsusState]],
identity: Optional[AsusDevice],
**kwargs: Any,
) -> None:
"""Keep the state."""

if states is None:
return

# Make sure the state is a list
if not isinstance(states, list):
states = [states]
states = [states] if not isinstance(states, list) else states

# Process each state
for state in states:
# Get the module
submodule = _get_module(state)

# Process the data if module found
if submodule and _has_method(submodule, "keep_state"):
await submodule.keep_state(callback, state, identity)
awaitables = [
submodule.keep_state(callback, state, **kwargs)
for state in states
if (submodule := _get_module(state)) and _has_method(submodule, "keep_state")
]

# Execute all awaitables
for awaitable in awaitables:
await awaitable
Loading

0 comments on commit 04063b1

Please sign in to comment.