diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index c80e519c6..31bdd3f76 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -75,6 +75,7 @@ async def connect(self): url = f"{self._endpoint}?{urlencode(query_params)}" self.websocket = await websockets.connect(url) + # self._existing_conversation = {conversation_id: self.websocket} logger.info("Successfully connected to AzureOpenAI Realtime API") def _set_system_prompt_and_config_vars(self): @@ -122,8 +123,15 @@ async def send_config(self): async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: # Sends a prompt to the target and returns the response. + convo_id = prompt_request.request_pieces[0].conversation_id if not self._existing_conversation: await self.connect() + self._existing_conversation = {prompt_request.request_pieces[0].conversation_id: self.websocket} + elif convo_id not in self._existing_conversation: + await self.connect() + self._existing_conversation[prompt_request.request_pieces[0].conversation_id] = self.websocket + else: + self.websocket = self._existing_conversation[prompt_request.request_pieces[0].conversation_id] # Validation function self._validate_request(prompt_request=prompt_request) @@ -204,10 +212,16 @@ async def cleanup_target(self): """ Disconnects from the WebSocket server to clean up """ - if self.websocket: - await self.websocket.close() - self.websocket = None - logger.info(f"Disconnected from {self._endpoint}") + for conversation_id, websocket in self._existing_conversation.items(): + if websocket: + await websocket.close() + logger.info(f"Disconnected from {self._endpoint}") + self._existing_conversation = {} + self.websocket = None + # if self.websocket: + # await self.websocket.close() + # self.websocket = None + # logger.info(f"Disconnected from {self._endpoint}") async def send_response_create(self): """ @@ -223,7 +237,7 @@ async def receive_events(self, convo_max) -> list: """ - if self.websocket is None: + if self.websocket is None: # change this to existing_conversation.websocket logger.error("WebSocket connection is not established") raise Exception("WebSocket connection is not established")