Skip to content

Commit

Permalink
refactor: more cleanup/linting
Browse files Browse the repository at this point in the history
  • Loading branch information
MVladislav committed Jan 12, 2025
1 parent 6ddfefb commit a024b17
Show file tree
Hide file tree
Showing 55 changed files with 552 additions and 567 deletions.
9 changes: 3 additions & 6 deletions amqtt/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,13 @@ async def _feed_buffer(self, n: int = 1) -> None:
:param n: if given, feed buffer until it contains at least n bytes.
"""
buffer = bytearray(self._stream.read())
message: str | bytes | None = None
while len(buffer) < n:
try:
with suppress(ConnectionClosed):
message = await self._protocol.recv()
except ConnectionClosed:
message = None
if message is None:
break
if not isinstance(message, bytes):
msg = "message must be bytes"
raise TypeError(msg)
message = message.encode("utf-8") if isinstance(message, str) else message
buffer.extend(message)
self._stream = io.BytesIO(buffer)

Expand Down
96 changes: 58 additions & 38 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
from asyncio import CancelledError, futures
from collections import deque
from collections.abc import Generator
from enum import Enum
from functools import partial
import logging
import re
import ssl
from typing import Any
from typing import Any, ClassVar

from transitions import Machine, MachineError
import websockets.asyncio.server
Expand All @@ -20,18 +21,24 @@
WebSocketsWriter,
WriterAdapter,
)
from amqtt.errors import AMQTTException, BrokerException, MQTTException, NoDataException
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, Session
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import format_client_message, gen_client_id

from .plugins.manager import BaseContext, PluginManager

_defaults: dict[str, int | bool | dict[Any, Any]] = {
type _CONFIG_LISTENER = dict[str, int | bool | dict[str, Any]]
type _BROADCAST = dict[str, Session | str | bytes | int | None]

_defaults: _CONFIG_LISTENER = {
"timeout-disconnect-delay": 2,
"auth": {"allow-anonymous": True, "password-file": None},
}


AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80

EVENT_BROKER_PRE_START = "broker_pre_start"
EVENT_BROKER_POST_START = "broker_post_start"
EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown"
Expand Down Expand Up @@ -59,7 +66,12 @@ def __init__(self, source_session: Session | None, topic: str, data: bytes, qos:


class Server:
def __init__(self, listener_name: str, server_instance: Any, max_connections: int = -1) -> None:
def __init__(
self,
listener_name: str,
server_instance: asyncio.Server | websockets.asyncio.server.Server,
max_connections: int = -1,
) -> None:
self.logger = logging.getLogger(__name__)
self.instance = server_instance
self.conn_count = 0
Expand Down Expand Up @@ -100,7 +112,7 @@ class BrokerContext(BaseContext):

def __init__(self, broker: "Broker") -> None:
super().__init__()
self.config: dict[str, int | bool | dict[Any, Any]] | None = None
self.config: _CONFIG_LISTENER | None = None
self._broker_instance = broker

async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
Expand All @@ -110,17 +122,17 @@ def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | No
self._broker_instance.retain_message(None, topic_name, data, qos)

@property
def sessions(self) -> Any:
for session in self._broker_instance._sessions.values(): # noqa: SLF001
def sessions(self) -> Generator[Session]:
for session in self._broker_instance._sessions.values():
yield session[0]

@property
def retained_messages(self) -> dict[Any, Any]:
return self._broker_instance._retained_messages # noqa: SLF001
def retained_messages(self) -> dict[str, RetainedApplicationMessage]:
return self._broker_instance._retained_messages

@property
def subscriptions(self) -> dict[Any, Any]:
return self._broker_instance._subscriptions # noqa: SLF001
def subscriptions(self) -> dict[str, list[tuple[Session, int]]]:
return self._broker_instance._subscriptions


class Broker:
Expand All @@ -132,7 +144,7 @@ class Broker:
"""

states = [
states: ClassVar[list[str]] = [
"new",
"starting",
"started",
Expand All @@ -145,7 +157,7 @@ class Broker:

def __init__(
self,
config: dict[str, Any] | None = None,
config: _CONFIG_LISTENER | None = None,
loop: asyncio.AbstractEventLoop | None = None,
plugin_namespace: str | None = None,
) -> None:
Expand All @@ -161,7 +173,7 @@ def __init__(
self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {}
self._subscriptions: dict[str, list[tuple[Session, int]]] = {}
self._retained_messages: dict[str, RetainedApplicationMessage] = {}
self._broadcast_queue: asyncio.Queue[Any] = asyncio.Queue()
self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()

self._broadcast_task: asyncio.Task[Any] | None = None
self._broadcast_shutdown_waiter: asyncio.Future[Any] = futures.Future()
Expand All @@ -172,18 +184,25 @@ def __init__(
namespace = plugin_namespace or "amqtt.broker.plugins"
self.plugins_manager = PluginManager(namespace, context, self._loop)

def _build_listeners_config(self, broker_config: dict[str, Any]) -> None:
def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
self.listeners_config = {}
try:
listeners_config = broker_config["listeners"]
defaults = listeners_config["default"]
listeners_config = broker_config.get("listeners")
if not isinstance(listeners_config, dict):
msg = "Listener config not found or invalid"
raise BrokerError(msg)
defaults = listeners_config.get("default")
if defaults is None:
msg = "Listener config has not default included or is invalid"
raise BrokerError(msg)

for listener_name, listener_conf in listeners_config.items():
config = defaults.copy()
config.update(listener_conf)
self.listeners_config[listener_name] = config
except KeyError as ke:
msg = f"Listener config not found or invalid: {ke}"
raise BrokerException(msg) from ke
raise BrokerError(msg) from ke

def _init_states(self) -> None:
self.transitions = Machine(states=Broker.states, initial="new")
Expand Down Expand Up @@ -212,7 +231,7 @@ async def start(self) -> None:
# Backwards compat: MachineError is raised by transitions < 0.5.0.
self.logger.warning(f"[WARN-0001] Invalid method call at this moment: {exc}")
msg = f"Broker instance can't be started: {exc}"
raise BrokerException(msg) from exc
raise BrokerError(msg) from exc

await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
try:
Expand Down Expand Up @@ -244,18 +263,18 @@ async def start(self) -> None:
sc.verify_mode = ssl.CERT_OPTIONAL
except KeyError as ke:
msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}"
raise BrokerException(msg) from ke
raise BrokerError(msg) from ke
except FileNotFoundError as fnfe:
msg = "Can't read cert files '{}' or '{}' : {}".format(listener["certfile"], listener["keyfile"], fnfe)
raise BrokerException(msg) from fnfe
raise BrokerError(msg) from fnfe

address, s_port = listener["bind"].split(":")
port = 0
try:
port = int(s_port)
except ValueError as e:
msg = "Invalid port value in bind value: {}".format(listener["bind"])
raise BrokerException(msg) from e
raise BrokerError(msg) from e

instance: asyncio.Server | websockets.asyncio.server.Server | None = None
if listener["type"] == "tcp":
Expand Down Expand Up @@ -292,7 +311,7 @@ async def start(self) -> None:
self.logger.exception("Broker startup failed")
self.transitions.starting_fail()
msg = f"Broker instance can't be started: {e}"
raise BrokerException(msg) from e
raise BrokerError(msg) from e

async def shutdown(self) -> None:
"""Stop broker instance.
Expand Down Expand Up @@ -330,12 +349,12 @@ async def ws_connected(self, websocket: ServerConnection, listener_name: str) ->
async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))

async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: # noqa: C901, PLR0915, PLR0912
async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:
# Wait for connection available on listener
server = self._servers.get(listener_name, None)
if not server:
msg = f"Invalid listener name '{listener_name}'"
raise BrokerException(msg)
raise BrokerError(msg)
await server.acquire_connection()

remote_info = writer.get_peer_info()
Expand All @@ -346,7 +365,7 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ
# Wait for first packet and expect a CONNECT
try:
handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager)
except AMQTTException as exc:
except AMQTTError as exc:
self.logger.warning(
f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:"
f"Can't read first packet an CONNECT: {exc}",
Expand All @@ -355,15 +374,15 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ
self.logger.debug("Connection closed")
server.release_connection()
return
except MQTTException:
except MQTTError:
self.logger.exception(
f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}",
)
await writer.close()
server.release_connection()
self.logger.debug("Connection closed")
return
except NoDataException as ne:
except NoDataError as ne:
self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails
server.release_connection()
return
Expand All @@ -385,7 +404,7 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ

if client_session.client_id is None:
msg = "Client ID was not correct created/set."
raise BrokerException(msg)
raise BrokerError(msg)

timeout_disconnect_delay = self.config.get("timeout-disconnect-delay")
if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int):
Expand Down Expand Up @@ -493,9 +512,10 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ
return_codes = [
await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics
]

await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)
for index, subscription in enumerate(subscriptions.topics):
if return_codes[index] != 0x80:
if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED:
await self.plugins_manager.fire_event(
EVENT_BROKER_CLIENT_SUBSCRIBED,
client_id=client_session.client_id,
Expand Down Expand Up @@ -577,7 +597,7 @@ async def _stop_handler(self, handler: BrokerProtocolHandler) -> None:
except Exception:
self.logger.exception("Failed to stop handler")

async def authenticate(self, session: Session, listener: Any) -> bool: # noqa: ARG002
async def authenticate(self, session: Session, _: dict[str, Any]) -> bool:
"""Call the authenticate method on registered plugins to test user authentication.
User is considered authenticated if all plugins called returns True.
Expand Down Expand Up @@ -708,7 +728,7 @@ def _del_all_subscriptions(self, session: Session) -> None:
:param session:
:return:
"""
filter_queue: deque[Any] = deque()
filter_queue: deque[str] = deque()
for topic in self._subscriptions:
if self._del_subscription(topic, session):
filter_queue.append(topic)
Expand All @@ -725,7 +745,7 @@ def matches(self, topic: str, a_filter: str) -> bool:
return bool(match_pattern.fullmatch(topic))

async def _broadcast_loop(self) -> None:
running_tasks: deque[Any] = deque()
running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = deque()
try:
while True:
while running_tasks and running_tasks[0].done():
Expand Down Expand Up @@ -758,7 +778,7 @@ async def _broadcast_loop(self) -> None:
if running_tasks:
await asyncio.gather(*running_tasks)

async def _run_broadcast(self, running_tasks: deque[Any]) -> None:
async def _run_broadcast(self, running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]]) -> None:
broadcast = await self._broadcast_queue.get()

if self.logger.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -800,7 +820,7 @@ async def _run_broadcast(self, running_tasks: deque[Any]) -> None:
)
running_tasks.append(task)

async def _retain_broadcast_message(self, broadcast: dict[Any, Any], qos: int, target_session: Session) -> None:
async def _retain_broadcast_message(self, broadcast: dict[str, Any], qos: int, target_session: Session) -> None:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
f"retaining application message from {format_client_message(session=broadcast['session'])}"
Expand All @@ -818,7 +838,7 @@ async def _shutdown_broadcast_loop(self) -> None:
self._broadcast_shutdown_waiter.set_result(True)
try:
await asyncio.wait_for(self._broadcast_task, timeout=30)
except BaseException as e:
except TimeoutError as e:
self.logger.warning(f"Failed to cleanly shutdown broadcast loop: {e}")

if not self._broadcast_queue.empty():
Expand All @@ -831,7 +851,7 @@ async def _broadcast_message(
data: bytes | None,
force_qos: int | None = None,
) -> None:
broadcast: dict[str, Session | str | bytes | int | None] = {"session": session, "topic": topic, "data": data}
broadcast: _BROADCAST = {"session": session, "topic": topic, "data": data}
if force_qos is not None:
broadcast["qos"] = force_qos
await self._broadcast_queue.put(broadcast)
Expand Down
Loading

0 comments on commit a024b17

Please sign in to comment.