diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index fcf2e1aa3..005f77c03 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -66,14 +66,14 @@ async def connect(self): logger.info(f"Connecting to WebSocket: {self._endpoint}") - websocket_url = f"{self._endpoint}/openai/realtime" + base_url = f"{self._endpoint}/openai/realtime" query_params = { "api-version": self._api_version, "deployment": self._deployment_name, "api-key": self._api_key, "OpenAI-Beta": "realtime=v1", } - url = f"{websocket_url}?{urlencode(query_params)}" + url = f"{base_url}?{urlencode(query_params)}" self.websocket = await websockets.connect(url) logger.info("Successfully connected to AzureOpenAI Realtime API") diff --git a/tests/unit/target/test_realtime_target.py b/tests/unit/target/test_realtime_target.py index 6872f82f3..1abc9b1c9 100644 --- a/tests/unit/target/test_realtime_target.py +++ b/tests/unit/target/test_realtime_target.py @@ -11,19 +11,18 @@ @pytest.fixture def target(duckdb_instance): - return RealtimeTarget( - api_key="test_key", endpoint="wss://test_url", deployment_name="test_deployment", api_version="v1" - ) + return RealtimeTarget(api_key="test_key", endpoint="wss://test_url", deployment_name="test", api_version="v1") @pytest.mark.asyncio async def test_connect_success(target): # Mock the websockets.connect method + url = ("wss://test_url/openai/realtime?api-version=v1&deployment=test&api-key=test_key&OpenAI-Beta=realtime%3Dv1",) + with patch("websockets.connect", new_callable=AsyncMock) as mock_connect: await target.connect() mock_connect.assert_called_once_with( - "wss://test_url/openai/realtime?api-version=v1&deployment=test_deployment&api-key=test_key", - extra_headers={"Authorization": "Bearer test_key", "OpenAI-Beta": "realtime=v1"}, + url, ) assert target.websocket == mock_connect.return_value