diff --git a/replit_river/client.py b/replit_river/client.py index 667bcd7..5ae4590 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncIterable, AsyncIterator -from typing import Any, Callable, Optional, Union +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from typing import Any, Optional, Union from replit_river.client_transport import ClientTransport from replit_river.transport_options import TransportOptions @@ -16,22 +16,23 @@ class Client: + def __init__( self, - websocket_uri: str, + websocket_uri_factory: Callable[[], Awaitable[str]], client_id: str, server_id: str, transport_options: TransportOptions, - handshake_metadata: Optional[Any] = None, + handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None, ) -> None: self._client_id = client_id self._server_id = server_id self._transport = ClientTransport( - websocket_uri=websocket_uri, + websocket_uri_factory=websocket_uri_factory, client_id=client_id, server_id=server_id, transport_options=transport_options, - handshake_metadata=handshake_metadata, + handshake_metadata_factory=handshake_metadata_factory, ) async def close(self) -> None: diff --git a/replit_river/client_transport.py b/replit_river/client_transport.py index a3b3eb1..a888b14 100644 --- a/replit_river/client_transport.py +++ b/replit_river/client_transport.py @@ -1,5 +1,6 @@ import asyncio import logging +from collections.abc import Awaitable, Callable from typing import Any, Optional, Tuple import websockets @@ -42,26 +43,27 @@ class ClientTransport(Transport): + def __init__( self, - websocket_uri: str, + websocket_uri_factory: Callable[[], Awaitable[str]], client_id: str, server_id: str, transport_options: TransportOptions, - handshake_metadata: Optional[Any] = None, + handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None, ): super().__init__( transport_id=client_id, transport_options=transport_options, is_server=False, ) - self._websocket_uri = websocket_uri + self._websocket_uri_factory = websocket_uri_factory self._client_id = client_id self._server_id = server_id self._rate_limiter = LeakyBucketRateLimit( transport_options.connection_retry_options ) - self._handshake_metadata = handshake_metadata + self._handshake_metadata_factory = handshake_metadata_factory # We want to make sure there's only one session creation at a time self._create_session_lock = asyncio.Lock() @@ -107,12 +109,18 @@ async def _establish_new_connection( break rate_limit.consume_budget(client_id) try: - ws = await websockets.connect(self._websocket_uri) + websocket_uri = await self._websocket_uri_factory() + ws = await websockets.connect(websocket_uri) session_id = ( self.generate_session_id() if not old_session else old_session.session_id ) + + handshake_metadata = None + if self._handshake_metadata_factory is not None: + handshake_metadata = await self._handshake_metadata_factory() + try: ( handshake_request, @@ -121,7 +129,7 @@ async def _establish_new_connection( self._transport_id, self._server_id, session_id, - self._handshake_metadata, + handshake_metadata, ws, old_session, ) diff --git a/tests/conftest.py b/tests/conftest.py index f413770..939eb39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -137,10 +137,14 @@ async def client( transport_options: TransportOptions, no_logging_error: NoErrors, ) -> AsyncGenerator[Client, None]: + + async def websocket_uri_factory() -> str: + return "ws://localhost:8765" + try: async with serve(server.serve, "localhost", 8765): client = Client( - "ws://localhost:8765", + websocket_uri_factory, client_id="test_client", server_id="test_server", transport_options=transport_options,