Skip to content

Commit

Permalink
adding mapping of websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
Bolor-Erdene Jagdagdorj committed Jan 23, 2025
1 parent 2cf9242 commit f084be1
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions pyrit/prompt_target/openai/openai_realtime_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def connect(self):
url = f"{self._endpoint}?{urlencode(query_params)}"

self.websocket = await websockets.connect(url)
# self._existing_conversation = {conversation_id: self.websocket}
logger.info("Successfully connected to AzureOpenAI Realtime API")

def _set_system_prompt_and_config_vars(self):
Expand Down Expand Up @@ -122,8 +123,15 @@ async def send_config(self):
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
# Sends a prompt to the target and returns the response.

convo_id = prompt_request.request_pieces[0].conversation_id
if not self._existing_conversation:
await self.connect()
self._existing_conversation = {prompt_request.request_pieces[0].conversation_id: self.websocket}
elif convo_id not in self._existing_conversation:
await self.connect()
self._existing_conversation[prompt_request.request_pieces[0].conversation_id] = self.websocket
else:
self.websocket = self._existing_conversation[prompt_request.request_pieces[0].conversation_id]

# Validation function
self._validate_request(prompt_request=prompt_request)
Expand Down Expand Up @@ -204,10 +212,16 @@ async def cleanup_target(self):
"""
Disconnects from the WebSocket server to clean up
"""
if self.websocket:
await self.websocket.close()
self.websocket = None
logger.info(f"Disconnected from {self._endpoint}")
for conversation_id, websocket in self._existing_conversation.items():
if websocket:
await websocket.close()
logger.info(f"Disconnected from {self._endpoint}")
self._existing_conversation = {}
self.websocket = None
# if self.websocket:
# await self.websocket.close()
# self.websocket = None
# logger.info(f"Disconnected from {self._endpoint}")

async def send_response_create(self):
"""
Expand All @@ -223,7 +237,7 @@ async def receive_events(self, convo_max) -> list:
"""

if self.websocket is None:
if self.websocket is None: # change this to existing_conversation.websocket
logger.error("WebSocket connection is not established")
raise Exception("WebSocket connection is not established")

Expand Down

0 comments on commit f084be1

Please sign in to comment.