Skip to content

Commit

Permalink
cleaning up copy, supporting azuresql
Browse files Browse the repository at this point in the history
  • Loading branch information
Bolor-Erdene Jagdagdorj committed Jan 17, 2025
1 parent e1f2af8 commit 7648939
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 92 deletions.
42 changes: 6 additions & 36 deletions doc/code/targets/realtime_target.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,14 @@
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "Environment variable AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL is required",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpyrit\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprompt_target\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m RealtimeTarget\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpyrit\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m initialize_pyrit, IN_MEMORY, AZURE_SQL\n\u001b[1;32m----> 4\u001b[0m \u001b[43minitialize_pyrit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmemory_db_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mAZURE_SQL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6\u001b[0m target \u001b[38;5;241m=\u001b[39m RealtimeTarget()\n",
"File \u001b[1;32m~\\Documents\\tools\\pyrit2\\PyRIT\\pyrit\\common\\initialization.py:67\u001b[0m, in \u001b[0;36minitialize_pyrit\u001b[1;34m(memory_db_type, **memory_instance_kwargs)\u001b[0m\n\u001b[0;32m 65\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 66\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing AzureSQL database.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 67\u001b[0m memory \u001b[38;5;241m=\u001b[39m \u001b[43mAzureSQLMemory\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmemory_instance_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 69\u001b[0m CentralMemory\u001b[38;5;241m.\u001b[39mset_memory_instance(memory)\n",
"File \u001b[1;32m~\\Documents\\tools\\pyrit2\\PyRIT\\pyrit\\common\\singleton.py:21\u001b[0m, in \u001b[0;36mSingleton.__call__\u001b[1;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 17\u001b[0m \u001b[38;5;124;03mOverrides the default __call__ behavior to ensure only one instance of the singleton class is created.\u001b[39;00m\n\u001b[0;32m 18\u001b[0m \u001b[38;5;124;03mReturns the singleton instance if it exists, otherwise creates a new one and returns it.\u001b[39;00m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 20\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_instances:\n\u001b[1;32m---> 21\u001b[0m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_instances[\u001b[38;5;28mcls\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mSingleton\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 22\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_instances[\u001b[38;5;28mcls\u001b[39m]\n",
"File \u001b[1;32m~\\Documents\\tools\\pyrit2\\PyRIT\\pyrit\\memory\\azure_sql_memory.py:57\u001b[0m, in \u001b[0;36mAzureSQLMemory.__init__\u001b[1;34m(self, connection_string, results_container_url, results_sas_token, verbose)\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[0;32m 46\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 47\u001b[0m \u001b[38;5;241m*\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 51\u001b[0m verbose: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m 52\u001b[0m ):\n\u001b[0;32m 53\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_connection_string \u001b[38;5;241m=\u001b[39m default_values\u001b[38;5;241m.\u001b[39mget_required_value(\n\u001b[0;32m 54\u001b[0m env_var_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mAZURE_SQL_DB_CONNECTION_STRING, passed_value\u001b[38;5;241m=\u001b[39mconnection_string\n\u001b[0;32m 55\u001b[0m )\n\u001b[1;32m---> 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_results_container_url: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mdefault_values\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_required_value\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 58\u001b[0m \u001b[43m \u001b[49m\u001b[43menv_var_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mAZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpassed_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresults_container_url\u001b[49m\n\u001b[0;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_results_container_sas_token: Optional[\u001b[38;5;28mstr\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_resolve_sas_token(\n\u001b[0;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mAZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN, results_sas_token\n\u001b[0;32m 63\u001b[0m )\n\u001b[0;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_auth_token: Optional[AccessToken] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[1;32m~\\Documents\\tools\\pyrit2\\PyRIT\\pyrit\\common\\default_values.py:35\u001b[0m, in \u001b[0;36mget_required_value\u001b[1;34m(env_var_name, passed_value)\u001b[0m\n\u001b[0;32m 32\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value:\n\u001b[0;32m 33\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n\u001b[1;32m---> 35\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEnvironment variable \u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_var_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is required\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[1;31mValueError\u001b[0m: Environment variable AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL is required"
]
}
],
"outputs": [],
"source": [
"from pyrit.prompt_target import RealtimeTarget\n",
"from pyrit.common import initialize_pyrit, IN_MEMORY\n",
"\n",
"initialize_pyrit(memory_db_type=IN_MEMORY)\n",
"\n",
"target = RealtimeTarget()"
"target = RealtimeTarget()\n"
]
},
{
Expand Down Expand Up @@ -104,24 +88,10 @@
"text": [
"Session set up\n",
"\u001b[1m\u001b[34muser: C:\\Users\\bjagdagdorj\\Documents\\tools\\pyrit2\\PyRIT\\assets\\converted_audio.wav\n",
"\u001b[22m\u001b[33massistant: C:\\Users\\bjagdagdorj\\Documents\\tools\\pyrit2\\PyRIT\\dbdata\\prompt-memory-entries\\audio\\1737002383944018.mp3\n",
"\u001b[22m\u001b[39mConversation ID: ce0e9d9b-c79c-4115-8e8a-835776dd2968\n",
"\u001b[22m\u001b[33massistant: Cutting down a tree is a task that requires careful planning and the right equipment. Here’s a basic guide:\n",
"\n",
"1. **Assess the Area:** Ensure there are no obstacles in the fall path of the tree, including buildings, power lines, or other trees. Look out for any signs of disease or rot in the tree, which could affect the way it falls.\n",
"\n",
"2. **Wear Safety Gear:** Use a hard hat, goggles, ear protection, and gloves to protect yourself from branches, debris, and the noise of the chainsaw.\n",
"\n",
"3. **Plan the Notch:** Decide the direction you want the tree to fall. Make a notch cut on the side of the tree facing that direction, which consists of two cuts: a downward cut at about a 70-degree angle, followed by a horizontal cut. The notch should go about one-third into the tree.\n",
"\n",
"4. **Make the Felling Cut:** On the opposite side of the notch, make a horizontal cut slightly above the bottom of the notch. This cut should not go all the way through; leave a hinge to control the tree's fall.\n",
"\n",
"5. **Use Wedges if Necessary:** If the tree leans back, use felling wedges in the cut to help guide it in the right direction.\n",
"\n",
"6. **Retreat:** As the tree begins to fall, move quickly along your planned escape route that is at a 45-degree angle from the direction of the fall.\n",
"\n",
"Remember, if you're not experienced with using a chainsaw or felling trees, it’s highly recommended to hire a professional to do the job.\n",
"\u001b[22m\u001b[39mConversation ID: ce0e9d9b-c79c-4115-8e8a-835776dd2968\n"
"\u001b[22m\u001b[33massistant: Cutting down a stop sign is illegal and dangerous. If you're having issues with a stop sign in your area, it is best to contact your local traffic or public works department. They can address any concerns you have regarding traffic signs, including stop signs. Please do not attempt anything that could be harmful or illegal.\n",
"\u001b[22m\u001b[39mConversation ID: a7a99168-f6b2-4185-973a-40739600fd53\n",
"\u001b[22m\u001b[33massistant: https://airtstorageaccountdev.blob.core.windows.net/dbdata/prompt-memory-entries/audio/1737073685705986.mp3\n",
"\u001b[22m\u001b[39mConversation ID: a7a99168-f6b2-4185-973a-40739600fd53\n"
]
}
],
Expand Down
21 changes: 13 additions & 8 deletions doc/code/targets/realtime_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#
# This notebooks shows how to interact with the Realtime Target to send text or audio prompts and receive back an audio output and the text transcript of that audio

# %% [markdown]
# ## Target Initialization

# %%
from pyrit.prompt_target import RealtimeTarget
from pyrit.common import initialize_pyrit, IN_MEMORY
Expand All @@ -25,30 +28,32 @@

target = RealtimeTarget()


# %%
await target.connect()

# %% [markdown]
# ## Audio Conversation
#
# The following shows how to interact with the Realtime Target with audio files as your prompt. You can either use pre-made audio files with the pcm16 format or you can use PyRIT converters to help turn your text into audio.

# %%

from pathlib import Path
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_normalizer.normalizer_request import NormalizerRequest, NormalizerRequestPiece

prompt_to_send = "test_rt_audio1.wav"
prompt_to_send = Path("../../../assets/converted_audio.wav").resolve()

normalizer_request = NormalizerRequest(
request_pieces=[
NormalizerRequestPiece(
prompt_value=prompt_to_send,
prompt_value=str(prompt_to_send),
prompt_data_type="audio_path",
),
]
)

# %%
await target.connect() # type: ignore

orchestrator = PromptSendingOrchestrator(objective_target=target)
await orchestrator.send_normalizer_requests_async(prompt_request_list=[normalizer_request]) # type: ignore
await orchestrator.print_conversations_async() # type: ignore
Expand All @@ -59,21 +64,21 @@
# ## Text Conversation
#
# This section below shows how to interact with the Realtime Target with text prompts
#
# (if you ran the cells above make sure to connect to the target again! )

# %%
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.orchestrator import PromptSendingOrchestrator


await target.connect() # type: ignore
prompt_to_send = "What is the capitol of France?"

request = PromptRequestPiece(
role="user",
original_value=prompt_to_send,
).to_prompt_request_response()


await target.connect() # type: ignore
orchestrator = PromptSendingOrchestrator(objective_target=target)
response = await orchestrator.send_prompts_async(prompt_list=[prompt_to_send]) # type: ignore
await orchestrator.print_conversations_async() # type: ignore
Expand Down
48 changes: 48 additions & 0 deletions pyrit/models/data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import abc
import aiofiles
import base64
import hashlib
import os
Expand All @@ -12,7 +13,10 @@
from pathlib import Path
from typing import get_args, Literal, TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse
import wave


from pyrit.common.path import DB_DATA_PATH
from pyrit.models.literals import PromptDataType
from pyrit.models.storage_io import DiskStorageIO

Expand Down Expand Up @@ -138,6 +142,49 @@ async def save_b64_image(self, data: str, output_filename: str = None) -> None:
await self._memory.results_storage_io.write_file(file_path, image_bytes)
self.value = str(file_path)

async def save_formatted_audio(
self,
data: bytes,
output_filename: str = None,
num_channels: int = 1,
sample_width: int = 2,
sample_rate: int = 16000,
) -> None:
"""
Saves the PCM16 of other specially formatted audio data to storage.
Arguments:
data: bytes with audio data
output_filename (optional, str): filename to store audio as. Defaults to UUID if not provided
num_channels (optional, int): number of channels in audio data. Defaults to 1
sample_width (optional, int): sample width in bytes. Defaults to 2
sample_rate (optional, int): sample rate in Hz. Defaults to 16000
"""
file_path = output_filename or await self.get_data_filename()

# save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters
if self._is_azure_storage_url(file_path):
local_temp_path = Path(DB_DATA_PATH, "temp_audio.wav")
with wave.open(str(local_temp_path), "wb") as wav_file:
wav_file.setnchannels(num_channels)
wav_file.setsampwidth(sample_width)
wav_file.setframerate(sample_rate)
wav_file.writeframes(data)

async with aiofiles.open(local_temp_path, "rb") as f:
data = await f.read()
await self._memory.results_storage_io.write_file(file_path, data)
os.remove(local_temp_path)

# If local, we can just save straight to disk and do not need to delete temp file after
else:
with wave.open(file_path, "wb") as wav_file:
wav_file.setnchannels(num_channels)
wav_file.setsampwidth(sample_width)
wav_file.setframerate(sample_rate)
wav_file.writeframes(data)

self.value = str(file_path)

async def read_data(self) -> bytes:
"""
Reads the data from the storage.
Expand Down Expand Up @@ -243,6 +290,7 @@ def _is_azure_storage_url(self, path: str) -> bool:


class TextDataTypeSerializer(DataTypeSerializer):

def __init__(self, *, prompt_text: str):
self.data_type = "text"
self.value = prompt_text
Expand Down
2 changes: 1 addition & 1 deletion pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from pyrit.prompt_target.ollama_chat_target import OllamaChatTarget
from pyrit.prompt_target.openai.openai_completion_target import OpenAICompletionTarget
from pyrit.prompt_target.openai.openai_dall_e_target import OpenAIDALLETarget
from pyrit.prompt_target.openai.openai_realtime_target import RealtimeTarget
from pyrit.prompt_target.openai.openai_tts_target import OpenAITTSTarget
from pyrit.prompt_target.playwright_target import PlaywrightTarget
from pyrit.prompt_target.prompt_shield_target import PromptShieldTarget
from pyrit.prompt_target.realtime_target import RealtimeTarget
from pyrit.prompt_target.text_target import TextTarget

__all__ = [
Expand Down
Loading

0 comments on commit 7648939

Please sign in to comment.