diff --git a/src/viser/infra/_infra.py b/src/viser/infra/_infra.py index 5bf5a462..e3f1e1b4 100644 --- a/src/viser/infra/_infra.py +++ b/src/viser/infra/_infra.py @@ -252,6 +252,7 @@ def __init__( self._verbose = verbose self._client_api_version: Literal[0, 1] = client_api_version self._shutdown_event = threading.Event() + self._ws_server: websockets.WebSocketServer | None = None self._client_state_from_id: dict[int, _ClientHandleState] = {} @@ -266,7 +267,7 @@ def start(self) -> None: daemon=True, ).start() - # Wait for the thread to set self._event_loop and self._broadcast_buffer... + # Wait for ready signal from the background thread. ready_sem.acquire() # Broadcast buffer should be populated by the background worker. @@ -274,8 +275,10 @@ def start(self) -> None: def stop(self) -> None: """Stop the server.""" + assert self._ws_server is not None + self._ws_server.close() + self._ws_server = None self._thread_executor.shutdown(wait=True) - self._event_loop.stop() def on_client_connect(self, cb: Callable[[WebsockClientConnection], Any]) -> None: """Attach a callback to run for newly connected clients.""" @@ -316,7 +319,6 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None: # Need to make a new event loop for notebook compatbility. event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) - self._event_loop = event_loop self._broadcast_buffer = AsyncMessageBuffer( event_loop, persistent_messages=True ) @@ -462,28 +464,35 @@ async def viser_http_server( # Try to read + send over file. return (http.HTTPStatus.OK, response_headers, response_payload) - for _ in range(500): + for _ in range(1000): try: - event_loop.run_until_complete( - websockets.server.serve( - serve, - host, - port, - # Compression can be turned off to reduce client-side CPU usage. - # compression=None, - process_request=( - viser_http_server if http_server_root is not None else None - ), - ) + serve_future = websockets.server.serve( + serve, + host, + port, + # Compression can be turned off to reduce client-side CPU usage. + # compression=None, + process_request=( + viser_http_server if http_server_root is not None else None + ), ) + self._ws_server = serve_future.ws_server + event_loop.run_until_complete(serve_future) break except OSError: # Port not available. port += 1 continue + if self._ws_server is None: + raise RuntimeError("Failed to bind to port!") + self._port = port + ready_sem.release() event_loop.run_forever() + + # This will run only when the event loop ends, which happens when the + # websocket server is closed. rich.print("[bold](viser)[/bold] Server stopped")