Skip to content

Commit

Permalink
refactor: update websockets 14.1
Browse files Browse the repository at this point in the history
  • Loading branch information
MVladislav committed Jan 11, 2025
1 parent 1a72fa8 commit 6ddfefb
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 109 deletions.
10 changes: 5 additions & 5 deletions amqtt/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from websockets import ConnectionClosed
from websockets.legacy.protocol import WebSocketCommonProtocol
from websockets.asyncio.connection import Connection


class ReaderAdapter:
Expand Down Expand Up @@ -54,10 +54,10 @@ async def close(self) -> None:
class WebSocketsReader(ReaderAdapter):
"""WebSockets API reader adapter.
This adapter relies on WebSocketCommonProtocol to read from a WebSocket.
This adapter relies on Connection to read from a WebSocket.
"""

def __init__(self, protocol: WebSocketCommonProtocol) -> None:
def __init__(self, protocol: Connection) -> None:
self._protocol = protocol
self._stream = io.BytesIO(b"")

Expand Down Expand Up @@ -88,10 +88,10 @@ async def _feed_buffer(self, n: int = 1) -> None:
class WebSocketsWriter(WriterAdapter):
"""WebSockets API writer adapter.
This adapter relies on WebSocketCommonProtocol to write to a WebSocket.
This adapter relies on Connection to write to a WebSocket.
"""

def __init__(self, protocol: WebSocketCommonProtocol) -> None:
def __init__(self, protocol: Connection) -> None:
self._protocol = protocol
self._stream = io.BytesIO(b"")

Expand Down
8 changes: 4 additions & 4 deletions amqtt/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any

from transitions import Machine, MachineError
import websockets
from websockets.legacy.server import WebSocketServerProtocol
import websockets.asyncio.server
from websockets.asyncio.server import ServerConnection

from amqtt.adapters import (
ReaderAdapter,
Expand Down Expand Up @@ -257,7 +257,7 @@ async def start(self) -> None:
msg = "Invalid port value in bind value: {}".format(listener["bind"])
raise BrokerException(msg) from e

instance: asyncio.Server | websockets.WebSocketServer | None = None
instance: asyncio.Server | websockets.asyncio.server.Server | None = None
if listener["type"] == "tcp":
cb_partial = partial(self.stream_connected, listener_name=listener_name)
instance = await asyncio.start_server(
Expand Down Expand Up @@ -324,7 +324,7 @@ async def shutdown(self) -> None:
async def internal_message_broadcast(self, topic: str, data: bytes, qos: int | None = None) -> None:
return await self._broadcast_message(None, topic, data, qos)

async def ws_connected(self, websocket: WebSocketServerProtocol, uri: Any, listener_name: str) -> None: # noqa: ARG002
async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> None:
await self.client_connected(listener_name, WebSocketsReader(websocket), WebSocketsWriter(websocket))

async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None:
Expand Down
21 changes: 12 additions & 9 deletions amqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import wraps
import logging
import ssl
from typing import Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
from urllib.parse import urlparse, urlunparse

import websockets
Expand All @@ -26,6 +26,9 @@
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import gen_client_id

if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection

_defaults: dict[str, Any] = {
"keep_alive": 10,
"ping_delay": 1,
Expand Down Expand Up @@ -104,7 +107,7 @@ def __init__(self, client_id: str | None = None, config: dict[str, Any] | None =
self._disconnect_task: asyncio.Task[Any] | None = None
self._connected_state = asyncio.Event()
self._no_more_connections = asyncio.Event()
self.extra_headers: dict[str, Any] | HeadersLike = {}
self.additional_headers: dict[str, Any] | HeadersLike = {}

# Init plugins manager
context = ClientContext()
Expand All @@ -119,7 +122,7 @@ async def connect(
cafile: str | None = None,
capath: str | None = None,
cadata: str | None = None,
extra_headers: dict[str, Any] | HeadersLike | None = None,
additional_headers: dict[str, Any] | HeadersLike | None = None,
) -> int:
"""Connect to a remote broker.
Expand All @@ -136,14 +139,14 @@ async def connect(
:param cafile: server certificate authority file (optional, used for secured connection)
:param capath: server certificate authority path (optional, used for secured connection)
:param cadata: server certificate authority data (optional, used for secured connection)
:param extra_headers: a dictionary with additional http headers that should be sent on the initial connection (optional,
used only with websocket connections)
:param additional_headers: a dictionary with additional http headers that should be sent on the initial connection
(optional, used only with websocket connections)
:return: `CONNACK <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033>`_ return code
:raise: :class:`amqtt.client.ConnectException` if connection fails
"""
extra_headers = extra_headers if extra_headers is not None else {}
additional_headers = additional_headers if additional_headers is not None else {}
self.session = self._init_session(uri, cleansession, cafile, capath, cadata)
self.extra_headers = extra_headers
self.additional_headers = additional_headers
self.logger.debug(f"Connecting to: {uri}")

try:
Expand Down Expand Up @@ -467,10 +470,10 @@ async def _connect_coro(self) -> int:
reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)
elif scheme in ("ws", "wss") and self.session.broker_uri:
websocket = await websockets.connect(
websocket: ClientConnection = await websockets.connect(
self.session.broker_uri,
subprotocols=[websockets.Subprotocol("mqtt")],
extra_headers=self.extra_headers,
additional_headers=self.additional_headers,
**kwargs,
)
reader = WebSocketsReader(websocket)
Expand Down
2 changes: 1 addition & 1 deletion amqtt/scripts/pub_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None:
cafile=arguments["--ca-file"],
capath=arguments["--ca-path"],
cadata=arguments["--ca-data"],
extra_headers=_get_extra_headers(arguments),
additional_headers=_get_extra_headers(arguments),
)

qos = _get_qos(arguments)
Expand Down
2 changes: 1 addition & 1 deletion amqtt/scripts/sub_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def do_sub(client: MQTTClient, arguments: dict[str, Any]) -> None:
cafile=arguments["--ca-file"],
capath=arguments["--ca-path"],
cadata=arguments["--ca-data"],
extra_headers=_get_extra_headers(arguments),
additional_headers=_get_extra_headers(arguments),
)

qos = _get_qos(arguments)
Expand Down
11 changes: 2 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,11 @@ license = { text = "MIT" }
authors = [{ name = "aMQTT Contributors" }]

dependencies = [
# "transitions==0.8.0",
# "websockets>=9.0,<11.0",
# "passlib==1.7.0",
# "docopt==0.6.0",
# "PyYAML>=5.4.0,<7.0",
# # "coveralls==4.0.1",
"transitions==0.9.2",
"websockets==13.1",
# "websockets==14.1",
"websockets==14.1", # >=9.0,<11.0 # 14.1
"passlib==1.7.4",
"docopt==0.6.2",
"PyYAML==6.0.2",
"PyYAML==6.0.2", # >=5.4.0,<7.0
]

[dependency-groups]
Expand Down
Loading

0 comments on commit 6ddfefb

Please sign in to comment.