Skip to content

Commit

Permalink
Make client accept a function for websocket uri and hadnshake metadata (
Browse files Browse the repository at this point in the history
#62)

Why
===

We're seeing stale tokens as part of the connection to pid2, we have no
way of regenerating the token

What changed
============

Make client accept a function for websocket uri and hadnshake metadata

Test plan
=========

Updated tests, unfortunately no comprehensive suite here so we'll test
this in our internal repo
  • Loading branch information
masad-frost authored Aug 9, 2024
1 parent ba7c087 commit 71ba3df
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
13 changes: 7 additions & 6 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any, Callable, Optional, Union
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from typing import Any, Optional, Union

from replit_river.client_transport import ClientTransport
from replit_river.transport_options import TransportOptions
Expand All @@ -16,22 +16,23 @@


class Client:

def __init__(
self,
websocket_uri: str,
websocket_uri_factory: Callable[[], Awaitable[str]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[Any] = None,
handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None,
) -> None:
self._client_id = client_id
self._server_id = server_id
self._transport = ClientTransport(
websocket_uri=websocket_uri,
websocket_uri_factory=websocket_uri_factory,
client_id=client_id,
server_id=server_id,
transport_options=transport_options,
handshake_metadata=handshake_metadata,
handshake_metadata_factory=handshake_metadata_factory,
)

async def close(self) -> None:
Expand Down
20 changes: 14 additions & 6 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Any, Optional, Tuple

import websockets
Expand Down Expand Up @@ -42,26 +43,27 @@


class ClientTransport(Transport):

def __init__(
self,
websocket_uri: str,
websocket_uri_factory: Callable[[], Awaitable[str]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
handshake_metadata: Optional[Any] = None,
handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None,
):
super().__init__(
transport_id=client_id,
transport_options=transport_options,
is_server=False,
)
self._websocket_uri = websocket_uri
self._websocket_uri_factory = websocket_uri_factory
self._client_id = client_id
self._server_id = server_id
self._rate_limiter = LeakyBucketRateLimit(
transport_options.connection_retry_options
)
self._handshake_metadata = handshake_metadata
self._handshake_metadata_factory = handshake_metadata_factory
# We want to make sure there's only one session creation at a time
self._create_session_lock = asyncio.Lock()

Expand Down Expand Up @@ -107,12 +109,18 @@ async def _establish_new_connection(
break
rate_limit.consume_budget(client_id)
try:
ws = await websockets.connect(self._websocket_uri)
websocket_uri = await self._websocket_uri_factory()
ws = await websockets.connect(websocket_uri)
session_id = (
self.generate_session_id()
if not old_session
else old_session.session_id
)

handshake_metadata = None
if self._handshake_metadata_factory is not None:
handshake_metadata = await self._handshake_metadata_factory()

try:
(
handshake_request,
Expand All @@ -121,7 +129,7 @@ async def _establish_new_connection(
self._transport_id,
self._server_id,
session_id,
self._handshake_metadata,
handshake_metadata,
ws,
old_session,
)
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,14 @@ async def client(
transport_options: TransportOptions,
no_logging_error: NoErrors,
) -> AsyncGenerator[Client, None]:

async def websocket_uri_factory() -> str:
return "ws://localhost:8765"

try:
async with serve(server.serve, "localhost", 8765):
client = Client(
"ws://localhost:8765",
websocket_uri_factory,
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
Expand Down

0 comments on commit 71ba3df

Please sign in to comment.