Skip to content

Commit

Permalink
refactor: checking open pull requests:
Browse files Browse the repository at this point in the history
- Yakifo#119 :: not working in test cases, only comment
- Yakifo#153 :: updated
- Yakifo#61 :: removed try block as described
- Yakifo#72 :: function included
  • Loading branch information
MVladislav committed Jan 12, 2025
1 parent a024b17 commit 40d1214
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 26 deletions.
92 changes: 67 additions & 25 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
"auth": {"allow-anonymous": True, "password-file": None},
}

# Default port numbers
DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}

AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80

Expand Down Expand Up @@ -268,10 +270,8 @@ async def start(self) -> None:
msg = "Can't read cert files '{}' or '{}' : {}".format(listener["certfile"], listener["keyfile"], fnfe)
raise BrokerError(msg) from fnfe

address, s_port = listener["bind"].split(":")
port = 0
try:
port = int(s_port)
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = "Invalid port value in bind value: {}".format(listener["bind"])
raise BrokerError(msg) from e
Expand Down Expand Up @@ -674,30 +674,28 @@ def retain_message(
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
del self._retained_messages[topic_name]

# NOTE: issue #61 remove try block
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
try:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
return 0x80
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE):
return 0x80
qos_conf = self.config.get("max-qos", qos)
if isinstance(qos_conf, int):
qos = min(qos, qos_conf)
if topic_filter not in self._subscriptions:
self._subscriptions[topic_filter] = []
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
self._subscriptions[topic_filter].append((session, qos))
else:
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
return qos
except KeyError:
topic_filter, qos = subscription
if "#" in topic_filter and not topic_filter.endswith("#"):
# [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter
return 0x80
if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE):
return 0x80
qos_conf = self.config.get("max-qos", qos)
if isinstance(qos_conf, int):
qos = min(qos, qos_conf)
if topic_filter not in self._subscriptions:
self._subscriptions[topic_filter] = []
if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]):
self._subscriptions[topic_filter].append((session, qos))
else:
self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}")
return qos

def _del_subscription(self, a_filter: str, session: Session) -> int:
"""Delete a session subscription on a given topic.
Expand Down Expand Up @@ -922,3 +920,47 @@ def _get_handler(self, session: Session) -> BrokerProtocolHandler | None:
if client_id:
return self._sessions.get(client_id, (None, None))[1]
return None

@classmethod
def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]:
"""Split an address:port pair into separate IP address and port. with IPv6 special-case handling.
NOTE: issue #72
- Address can be specified using one of the following methods:
- 1883 - Port number only (listen all interfaces)
- :1883 - Port number only (listen all interfaces)
- 0.0.0.0:1883 - IPv4 address
- [::]:1883 - IPv6 address
- empty string - all interfaces default port
"""

def _parse_port(port_str: str) -> int:
port_str = port_str.removeprefix(":")

if not port_str:
return default_port

return int(port_str)

if port_str.startswith("["): # IPv6 literal
try:
addr_end = port_str.index("]")
except ValueError as e:
msg = "Expecting '[' to be followed by ']'"
raise ValueError(msg) from e

return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :]))

if ":" in port_str:
# Address : port
address, port_str = port_str.rsplit(":", 1)
return (address or None, _parse_port(port_str))

# Address or port
try:
# Port number?
return (None, _parse_port(port_str))
except ValueError:
# Address, default port
return (port_str, default_port)
3 changes: 2 additions & 1 deletion amqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,8 @@ def cancel_tasks() -> None:
while self.client_tasks:
task = self.client_tasks.popleft()
if not task.done():
task.cancel()
# task.set_exception(ClientError("Connection lost"))
task.cancel() # NOTE: issue #153

self.logger.debug("Monitoring broker disconnection")
# Wait for disconnection from broker (like connection lost)
Expand Down
1 change: 1 addition & 0 deletions amqtt/mqtt/protocol/client_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ async def handle_connection_closed(self) -> None:
self.logger.debug("Broker closed connection")
if self._disconnect_waiter is not None and not self._disconnect_waiter.done():
self._disconnect_waiter.set_result(None)
# await self.stop() # NOTE: issue #119

async def wait_disconnect(self) -> None:
if self._disconnect_waiter is not None:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EVENT_BROKER_POST_START,
EVENT_BROKER_PRE_SHUTDOWN,
EVENT_BROKER_PRE_START,
Broker,
)
from amqtt.client import MQTTClient
from amqtt.errors import ConnectError
Expand Down Expand Up @@ -44,6 +45,23 @@ async def async_magic():
MagicMock.__await__ = lambda _: async_magic().__await__()


@pytest.mark.parametrize(
"input_str, output_addr, output_port",
[
("1234", None, 1234),
(":1234", None, 1234),
("0.0.0.0:1234", "0.0.0.0", 1234),
("[::]:1234", "[::]", 1234),
("0.0.0.0", "0.0.0.0", 5678),
("[::]", "[::]", 5678),
("localhost", "localhost", 5678),
("localhost:1234", "localhost", 1234),
],
)
def test_split_bindaddr_port(input_str, output_addr, output_port):
assert Broker._split_bindaddr_port(input_str, 5678) == (output_addr, output_port)


@pytest.mark.asyncio
async def test_start_stop(broker, mock_plugin_manager):
mock_plugin_manager.assert_has_calls(
Expand Down

0 comments on commit 40d1214

Please sign in to comment.