From 1d24cf2efad3538f824006beb3205a8b56d4113a Mon Sep 17 00:00:00 2001 From: Bolor-Erdene Jagdagdorj Date: Tue, 14 Jan 2025 13:37:20 -0800 Subject: [PATCH] reverting prompt normalizer and changing realtime target to use one response instead of list --- doc/code/targets/realtime_target.ipynb | 162 +++++++++--------- doc/code/targets/realtime_target.py | 31 +++- .../multi_turn/red_teaming_orchestrator.py | 43 +++-- .../multi_turn/tree_of_attacks_node.py | 41 ++--- .../question_answer_benchmark_orchestrator.py | 10 +- .../orchestrator/skeleton_key_orchestrator.py | 2 +- pyrit/prompt_normalizer/prompt_normalizer.py | 70 ++++---- pyrit/prompt_target/realtime_target.py | 71 ++++++-- 8 files changed, 236 insertions(+), 194 deletions(-) diff --git a/doc/code/targets/realtime_target.ipynb b/doc/code/targets/realtime_target.ipynb index d94f16eb9..c701085ed 100644 --- a/doc/code/targets/realtime_target.ipynb +++ b/doc/code/targets/realtime_target.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "0", "metadata": {}, "source": [ "# REALTIME TARGET" @@ -9,6 +10,7 @@ }, { "cell_type": "markdown", + "id": "1", "metadata": {}, "source": [ "## Using PyRIT" @@ -16,7 +18,8 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -30,6 +33,7 @@ }, { "cell_type": "markdown", + "id": "3", "metadata": {}, "source": [ "## Single Turn Audio Conversation" @@ -37,7 +41,8 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -45,8 +50,6 @@ "from pyrit.orchestrator import PromptSendingOrchestrator\n", "from pyrit.prompt_normalizer.normalizer_request import NormalizerRequest, NormalizerRequestPiece\n", "\n", - "# text_prompt_to_send = \"Hi what is 2+2?\"\n", - "\n", "prompt_to_send = \"test_rt_audio1.wav\"\n", "\n", "normalizer_request = NormalizerRequest(\n", @@ -61,38 +64,10 @@ }, { "cell_type": "code", - 
"execution_count": 3, + "execution_count": null, + "id": "5", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Session set up\n", - "\u001b[1m\u001b[34muser: test_rt_audio1.wav\n", - "\u001b[22m\u001b[33massistant: Making rice wine, also known as sake, involves fermenting rice with water, koji (a type of mold), and yeast. Here is a simplified overview:\n", - "\n", - "1. **Washing and Soaking Rice**: Wash the rice thoroughly to remove excess starch. Then, soak the rice in water for a few hours.\n", - "\n", - "2. **Steaming Rice**: Steam the soaked rice until it's cooked, which will make it easier to ferment.\n", - "\n", - "3. **Preparing Koji**: Koji mold is cultivated on steamed rice to convert the starches in the rice to sugars. This process usually takes a couple of days.\n", - "\n", - "4. **Fermentation**: Mix the steamed rice, koji rice, yeast, and water in a fermentation tank. Over the next few weeks, the mixture will ferment, converting the sugars to alcohol.\n", - "\n", - "5. **Pressing and Filtration**: After fermentation, the mixture is pressed to separate the liquid from the solid rice residue. The liquid is then filtered to remove any remaining particles.\n", - "\n", - "6. **Pasteurization**: The filtered sake is often pasteurized to kill any remaining yeast and enzymes, stabilizing the flavor.\n", - "\n", - "7. **Aging and Bottling**: The sake is aged for a few months to develop its flavor before being bottled.\n", - "\n", - "This is a traditional process, and there are many variations and complexities involved in professional sake brewing. 
If you're interested in making it at home, consider starting with a homebrewing kit or recipe to guide you through the process.\n", - "\u001b[22m\u001b[39mConversation ID: c75dafcc-9ab0-4a2e-b52c-d0513da7a391\n", - "\u001b[22m\u001b[33massistant: response_audio.wav\n", - "\u001b[22m\u001b[39mConversation ID: c75dafcc-9ab0-4a2e-b52c-d0513da7a391\n" - ] - } - ], + "outputs": [], "source": [ "await target.connect()\n", "\n", @@ -105,15 +80,52 @@ }, { "cell_type": "markdown", + "id": "6", "metadata": {}, "source": [ - "## Multiturn Text Conversation" + "## Single Turn Text Conversation" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "from pyrit.models.prompt_request_piece import PromptRequestPiece\n", + "from pyrit.orchestrator import PromptSendingOrchestrator\n", + "\n", + "\n", + "await target.connect()\n", + "prompt_to_send = \"Give me an image of a raccoon pirate as a Spanish baker in Spain\"\n", + "\n", + "request = PromptRequestPiece(\n", + " role=\"user\",\n", + " original_value=prompt_to_send,\n", + ").to_prompt_request_response()\n", + "\n", + "\n", + "orchestrator = PromptSendingOrchestrator(objective_target=target)\n", + "response = await orchestrator.send_prompts_async(prompt_list=[prompt_to_send]) # type: ignore\n", + "await orchestrator.print_conversations_async() # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "8", "metadata": {}, + "source": [ + "## Multiturn Text Conversation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "\n", @@ -141,27 +153,15 @@ " prompt_data_type=\"text\",\n", " )\n", " ]\n", - ")\n" + ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "id": "10", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Session set up\n", - 
"\u001b[1m\u001b[34muser: Hi what is 2+2?\n", - "\u001b[22m\u001b[33massistant: 2 + 2 equals 4.\n", - "\u001b[22m\u001b[39mConversation ID: 87b580f3-3373-49fd-98ce-ac4f18f38838\n", - "\u001b[22m\u001b[33massistant: response_audio.wav\n", - "\u001b[22m\u001b[39mConversation ID: 87b580f3-3373-49fd-98ce-ac4f18f38838\n" - ] - } - ], + "outputs": [], "source": [ "orchestrator = PromptSendingOrchestrator(objective_target=target)\n", "\n", @@ -172,22 +172,18 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Session set up\n", - "\u001b[1m\u001b[34muser: Now add 2?\n", - "\u001b[22m\u001b[33massistant: 4 + 2 equals 6.\n", - "\u001b[22m\u001b[39mConversation ID: 5104182c-2392-497a-84ec-0ddb577c21b4\n", - "\u001b[22m\u001b[33massistant: response_audio.wav\n", - "\u001b[22m\u001b[39mConversation ID: 5104182c-2392-497a-84ec-0ddb577c21b4\n" - ] - } - ], + "outputs": [], "source": [ "orchestrator = PromptSendingOrchestrator(objective_target=target)\n", "\n", @@ -198,12 +194,26 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, + "id": "13", "metadata": {}, "outputs": [], "source": [ "await target.disconnect() # type: ignore" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "from pyrit.memory import CentralMemory\n", + "\n", + "memory = CentralMemory.get_memory_instance()\n", + "memory.dispose_engine()" + ] } ], "metadata": { @@ -211,20 +221,8 @@ "display_name": "pyrit2", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - 
"pygments_lexer": "ipython3", - "version": "3.11.9" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 5 } diff --git a/doc/code/targets/realtime_target.py b/doc/code/targets/realtime_target.py index 3372829e0..cb77f1d18 100644 --- a/doc/code/targets/realtime_target.py +++ b/doc/code/targets/realtime_target.py @@ -34,8 +34,6 @@ from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_normalizer.normalizer_request import NormalizerRequest, NormalizerRequestPiece -# text_prompt_to_send = "Hi what is 2+2?" - prompt_to_send = "test_rt_audio1.wav" normalizer_request = NormalizerRequest( @@ -56,6 +54,27 @@ await orchestrator.print_conversations_async() # type: ignore +# %% [markdown] +# ## Single Turn Text Conversation + +# %% +from pyrit.models.prompt_request_piece import PromptRequestPiece +from pyrit.orchestrator import PromptSendingOrchestrator + + +await target.connect() +prompt_to_send = "Give me an image of a raccoon pirate as a Spanish baker in Spain" + +request = PromptRequestPiece( + role="user", + original_value=prompt_to_send, +).to_prompt_request_response() + + +orchestrator = PromptSendingOrchestrator(objective_target=target) +response = await orchestrator.send_prompts_async(prompt_list=[prompt_to_send]) # type: ignore +await orchestrator.print_conversations_async() # type: ignore + # %% [markdown] # ## Multiturn Text Conversation @@ -95,6 +114,8 @@ await orchestrator.print_conversations_async() # type: ignore +# %% + # %% orchestrator = PromptSendingOrchestrator(objective_target=target) @@ -104,3 +125,9 @@ # %% await target.disconnect() # type: ignore + +# %% +from pyrit.memory import CentralMemory + +memory = CentralMemory.get_memory_instance() +memory.dispose_engine() diff --git a/pyrit/orchestrator/multi_turn/red_teaming_orchestrator.py b/pyrit/orchestrator/multi_turn/red_teaming_orchestrator.py index 6774c14a9..ae020c548 100644 --- a/pyrit/orchestrator/multi_turn/red_teaming_orchestrator.py +++ 
b/pyrit/orchestrator/multi_turn/red_teaming_orchestrator.py @@ -10,7 +10,6 @@ from pyrit.common.utils import combine_dict from pyrit.common.path import RED_TEAM_ORCHESTRATOR_PATH from pyrit.models import PromptRequestPiece, Score -from pyrit.models.prompt_request_response import PromptRequestResponse from pyrit.orchestrator import MultiTurnAttackResult, MultiTurnOrchestrator from pyrit.prompt_converter import PromptConverter from pyrit.prompt_normalizer import NormalizerRequest, NormalizerRequestPiece, PromptNormalizer @@ -253,18 +252,16 @@ async def _retrieve_and_send_prompt_async( conversation_id=objective_target_conversation_id, ) - response_piece = await self._prompt_normalizer.send_prompt_async( - normalizer_request=normalizer_request, - target=self._objective_target, - labels=memory_labels, - orchestrator_identifier=self.get_identifier(), - ) - - if isinstance(response_piece, PromptRequestResponse): - return response_piece.request_pieces[0] + response_piece = ( + await self._prompt_normalizer.send_prompt_async( + normalizer_request=normalizer_request, + target=self._objective_target, + labels=memory_labels, + orchestrator_identifier=self.get_identifier(), + ) + ).request_pieces[0] - else: - return response_piece[0].request_pieces[0] + return response_piece async def _check_conversation_complete_async(self, objective_target_conversation_id: str) -> Union[Score, None]: """ @@ -408,19 +405,19 @@ async def _get_prompt_from_adversarial_chat( prompt_text=prompt_text, conversation_id=adversarial_chat_conversation_id ) - response_text_values = await self._prompt_normalizer.send_prompt_async( - normalizer_request=normalizer_request, - target=self._adversarial_chat, - orchestrator_identifier=self.get_identifier(), - labels=memory_labels, + response_text = ( + ( + await self._prompt_normalizer.send_prompt_async( + normalizer_request=normalizer_request, + target=self._adversarial_chat, + orchestrator_identifier=self.get_identifier(), + labels=memory_labels, + ) + ) + 
.request_pieces[0] + .converted_value ) - if isinstance(response_text_values, PromptRequestResponse): - response_text = response_text_values.request_pieces[0].converted_value - - else: - response_text = response_text_values[0].request_pieces[0].converted_value - return response_text def _get_last_objective_target_response(self, objective_target_conversation_id: str) -> PromptRequestPiece | None: diff --git a/pyrit/orchestrator/multi_turn/tree_of_attacks_node.py b/pyrit/orchestrator/multi_turn/tree_of_attacks_node.py index 4b40f1768..c8a18867f 100644 --- a/pyrit/orchestrator/multi_turn/tree_of_attacks_node.py +++ b/pyrit/orchestrator/multi_turn/tree_of_attacks_node.py @@ -11,7 +11,6 @@ from pyrit.exceptions import InvalidJsonException, pyrit_json_retry, remove_markdown_json from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import SeedPrompt -from pyrit.models.prompt_request_response import PromptRequestResponse from pyrit.prompt_converter import PromptConverter from pyrit.prompt_normalizer import NormalizerRequest, NormalizerRequestPiece, PromptNormalizer from pyrit.prompt_target import PromptChatTarget, PromptTarget @@ -106,17 +105,14 @@ async def send_prompt_async(self, objective: str): conversation_id=self.objective_target_conversation_id, ) - normalizer_response = await self._prompt_normalizer.send_prompt_async( - normalizer_request=objective_target_request, - target=self._objective_target, - labels=self._global_memory_labels, - orchestrator_identifier=self._orchestrator_id, - ) - - if isinstance(normalizer_response, PromptRequestResponse): - response = normalizer_response.request_pieces[0] - else: - response = normalizer_response[0].request_pieces[0] + response = ( + await self._prompt_normalizer.send_prompt_async( + normalizer_request=objective_target_request, + target=self._objective_target, + labels=self._global_memory_labels, + orchestrator_identifier=self._orchestrator_id, + ) + ).request_pieces[0] logger.debug(f"saving score with 
prompt_request_response_id: {response.id}") @@ -216,18 +212,19 @@ async def _generate_red_teaming_prompt_async(self, objective) -> str: conversation_id=self.adversarial_chat_conversation_id, ) - adversarial_chat_normalizer_response = await self._prompt_normalizer.send_prompt_async( - normalizer_request=adversarial_chat_request, - target=self._adversarial_chat, - labels=self._global_memory_labels, - orchestrator_identifier=self._orchestrator_id, + adversarial_chat_response = ( + ( + await self._prompt_normalizer.send_prompt_async( + normalizer_request=adversarial_chat_request, + target=self._adversarial_chat, + labels=self._global_memory_labels, + orchestrator_identifier=self._orchestrator_id, + ) + ) + .request_pieces[0] + .converted_value ) - if isinstance(adversarial_chat_normalizer_response, PromptRequestResponse): - adversarial_chat_response = adversarial_chat_normalizer_response.request_pieces[0].converted_value - else: - adversarial_chat_response = adversarial_chat_normalizer_response[0].request_pieces[0].converted_value - return self._parse_red_teaming_response(adversarial_chat_response) def _parse_red_teaming_response(self, red_teaming_response: str) -> str: diff --git a/pyrit/orchestrator/question_answer_benchmark_orchestrator.py b/pyrit/orchestrator/question_answer_benchmark_orchestrator.py index 24dd345d1..67be41cd2 100644 --- a/pyrit/orchestrator/question_answer_benchmark_orchestrator.py +++ b/pyrit/orchestrator/question_answer_benchmark_orchestrator.py @@ -81,19 +81,15 @@ async def evaluate(self) -> None: prompt_text=question_prompt, conversation_id=self._conversation_id ) - responses = await self._normalizer.send_prompt_async( + response = await self._normalizer.send_prompt_async( normalizer_request=request, target=self._chat_model_under_evaluation, labels=self._global_memory_labels, orchestrator_identifier=self.get_identifier(), ) - if not isinstance(responses, list): - responses = [responses] - - for response in responses: - answer = 
response.request_pieces[0].converted_value - curr_score = self._scorer.score_question(question=question_entry, answer=answer) + answer = response.request_pieces[0].converted_value + curr_score = self._scorer.score_question(question=question_entry, answer=answer) if self._verbose: msg = textwrap.dedent( diff --git a/pyrit/orchestrator/skeleton_key_orchestrator.py b/pyrit/orchestrator/skeleton_key_orchestrator.py index 9bac06327..e5cbf2661 100644 --- a/pyrit/orchestrator/skeleton_key_orchestrator.py +++ b/pyrit/orchestrator/skeleton_key_orchestrator.py @@ -74,7 +74,7 @@ async def send_skeleton_key_with_prompt_async( self, *, prompt: str, - ) -> list[PromptRequestResponse]: + ) -> PromptRequestResponse: """ Sends a skeleton key, followed by the attack prompt to the target. diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index ca2a39e38..2f86050d3 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -32,7 +32,7 @@ async def send_prompt_async( sequence: int = -1, labels: Optional[dict[str, str]] = None, orchestrator_identifier: Optional[dict[str, str]] = None, - ) -> list[PromptRequestResponse]: + ) -> PromptRequestResponse: """ Sends a single request to a target. 
@@ -55,10 +55,10 @@ async def send_prompt_async( orchestrator_identifier=orchestrator_identifier, ) - responses: list[PromptRequestResponse] = [] + response = None try: - responses = await target.send_prompt_async(prompt_request=request) + response = await target.send_prompt_async(prompt_request=request) await self._calc_hash_and_add_request_to_memory(request=request) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution @@ -70,7 +70,6 @@ async def send_prompt_async( response_type="text", error="empty", ) - responses = [response] except Exception as ex: # Ensure request to memory before processing exception @@ -86,15 +85,15 @@ async def send_prompt_async( await self._calc_hash_and_add_request_to_memory(request=error_response) raise - if responses is None: + if response is None: return None await self.convert_response_values( - response_converter_configurations=normalizer_request.response_converters, prompt_responses=responses + response_converter_configurations=normalizer_request.response_converters, prompt_response=response ) - await self._calc_hash_and_add_request_to_memory(request=responses) - return responses + await self._calc_hash_and_add_request_to_memory(request=response) + return response async def send_prompt_batch_to_target_async( self, @@ -136,44 +135,33 @@ async def send_prompt_batch_to_target_async( async def convert_response_values( self, response_converter_configurations: list[PromptResponseConverterConfiguration], - prompt_responses: list[PromptRequestResponse] | PromptRequestResponse, + prompt_response: PromptRequestResponse, ): - if not isinstance(prompt_responses, list): - prompt_responses = [prompt_responses] - - for prompt_response in prompt_responses: - for response_piece_index, response_piece in enumerate(prompt_response.request_pieces): - for converter_configuration in response_converter_configurations: - indexes = converter_configuration.indexes_to_apply - data_types = 
converter_configuration.prompt_data_types_to_apply - - if indexes and response_piece_index not in indexes: - continue - if data_types and response_piece.original_value_data_type not in data_types: - continue - - for converter in converter_configuration.converters: - converter_output = await converter.convert_async( - prompt=response_piece.original_value, input_type=response_piece.original_value_data_type - ) - response_piece.converted_value = converter_output.output_text - response_piece.converted_value_data_type = converter_output.output_type - - async def _calc_hash_and_add_request_to_memory( - self, request: PromptRequestResponse | list[PromptRequestResponse] - ) -> None: + + for response_piece_index, response_piece in enumerate(prompt_response.request_pieces): + for converter_configuration in response_converter_configurations: + indexes = converter_configuration.indexes_to_apply + data_types = converter_configuration.prompt_data_types_to_apply + + if indexes and response_piece_index not in indexes: + continue + if data_types and response_piece.original_value_data_type not in data_types: + continue + + for converter in converter_configuration.converters: + converter_output = await converter.convert_async( + prompt=response_piece.original_value, input_type=response_piece.original_value_data_type + ) + response_piece.converted_value = converter_output.output_text + response_piece.converted_value_data_type = converter_output.output_type + + async def _calc_hash_and_add_request_to_memory(self, request: PromptRequestResponse) -> None: """ Adds a request to the memory. 
""" - if not isinstance(request, list): - request = [request] - tasks = [ - asyncio.create_task(piece.set_sha256_values_async()) for req in request for piece in req.request_pieces - ] + tasks = [asyncio.create_task(piece.set_sha256_values_async()) for piece in request.request_pieces] await asyncio.gather(*tasks) - - for req in request: - self._memory.add_request_response_to_memory(request=req) + self._memory.add_request_response_to_memory(request=request) async def _build_prompt_request_response( self, diff --git a/pyrit/prompt_target/realtime_target.py b/pyrit/prompt_target/realtime_target.py index 79492e517..dd1db23a0 100644 --- a/pyrit/prompt_target/realtime_target.py +++ b/pyrit/prompt_target/realtime_target.py @@ -5,22 +5,23 @@ import base64 import json import logging +from typing import Optional import wave import websockets from pyrit.common import default_values from pyrit.models import PromptRequestResponse -from pyrit.models.prompt_request_response import construct_response_from_request +from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.prompt_target import PromptTarget, limit_requests_per_minute logger = logging.getLogger(__name__) class RealtimeTarget(PromptTarget): - ENDPOINT_WS_URL_ENVIRONMENT_VARIABLE = "AZURE_OPENAI_REALTIME_API_WS_URL" - DEPLOYMENT_ENVIRONMENT_VARIABLE = "AZURE_OPENAI_REALTIME_DEPLOYMENT" - API_ENVIRONMENT_VARIABLE = "AZURE_OPENAI_REALTIME_API_KEY" - API_VERSION_ENVIRONMENT_VARIABLE = "AZURE_OPENAI_REALTIME_API_VERSION" + REALTIME_ENDPOINT_WEBSOCKET_URL = "AZURE_OPENAI_REALTIME_API_WS_URL" + REALTIME_DEPLOYMENT = "AZURE_OPENAI_REALTIME_DEPLOYMENT" + REALTIME_API_KEY = "AZURE_OPENAI_REALTIME_API_KEY" + REALTIME_API_VERSION = "AZURE_OPENAI_REALTIME_API_VERSION" def __init__( self, @@ -32,15 +33,15 @@ def __init__( **kwargs, ) -> None: - self.api_key = default_values.get_required_value(env_var_name=self.API_ENVIRONMENT_VARIABLE, passed_value=key) + self.api_key = 
default_values.get_required_value(env_var_name=self.REALTIME_API_KEY, passed_value=key)
         self.url = default_values.get_required_value(
-            env_var_name=self.ENDPOINT_WS_URL_ENVIRONMENT_VARIABLE, passed_value=url
+            env_var_name=self.REALTIME_ENDPOINT_WEBSOCKET_URL, passed_value=url
         )
         self.deployment = default_values.get_required_value(
-            env_var_name=self.DEPLOYMENT_ENVIRONMENT_VARIABLE, passed_value=deployment
+            env_var_name=self.REALTIME_DEPLOYMENT, passed_value=deployment
         )
         self.api_version = default_values.get_required_value(
-            env_var_name=self.API_VERSION_ENVIRONMENT_VARIABLE, passed_value=api_version
+            env_var_name=self.REALTIME_API_VERSION, passed_value=api_version
         )
 
         self.websocket = None
@@ -56,6 +57,7 @@ async def connect(self):
 
         url = f"{self.url}/openai/realtime?api-version={self.api_version}"
         url = f"{url}&deployment={self.deployment}&api-key={self.api_key}"
+        logger.debug("Opening Realtime API websocket connection")
         self.websocket = await websockets.connect(
             url,
             extra_headers=headers,
@@ -63,7 +65,7 @@
         logger.info("Successfully connected to AzureOpenAI Realtime API")
 
     @limit_requests_per_minute
-    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> list[PromptRequestResponse]:
+    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
 
         # Validation function
         self._validate_request(prompt_request=prompt_request)
@@ -103,16 +105,53 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> l
         events = await receive_tasks
 
         output_audio_path = await self.save_audio(events[0])
-        # await self.disconnect()
 
-        response_entry = construct_response_from_request(
-            request=request, response_text_pieces=[events[1]], response_type="text"
+        text_response_piece = PromptRequestPiece(
+            original_value=events[1],
+            original_value_data_type="text",
+            # converted_value=events[1],
+            role="assistant",
+            converted_value_data_type="text",
+        )
+        audio_response_piece = PromptRequestPiece(
+            
original_value=output_audio_path, + original_value_data_type="audio_path", + converted_value=output_audio_path, + role="assistant", + converted_value_data_type="audio_path", ) - audio_response_entry = construct_response_from_request( - request=request, response_text_pieces=[output_audio_path], response_type="audio_path" + response_entry = self.construct_response_from_request( + request=request, response_pieces=[audio_response_piece, text_response_piece] ) - return [response_entry, audio_response_entry] # TODO: can make the transcription a flag to return or not + return response_entry + + def construct_response_from_request( + self, + request: PromptRequestPiece, + response_pieces: list[PromptRequestPiece], + prompt_metadata: Optional[str] = None, + ) -> PromptRequestResponse: + """ + Constructs a response entry from a request. + """ + return PromptRequestResponse( + request_pieces=[ + PromptRequestPiece( + role="assistant", + original_value=resp_piece.original_value, + conversation_id=request.conversation_id, + labels=request.labels, + prompt_target_identifier=request.prompt_target_identifier, + orchestrator_identifier=request.orchestrator_identifier, + original_value_data_type=resp_piece.original_value_data_type, + converted_value_data_type=resp_piece.converted_value_data_type, + converted_value=resp_piece.converted_value, + prompt_metadata=prompt_metadata, + ) + for resp_piece in response_pieces + ] + ) async def save_audio( self, audio_bytes: bytes, num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000