Merge branch 'main' into users/rdheekonda/add-exception-handling
rdheekonda committed May 18, 2024
2 parents e31bdf7 + 02ed3e5 commit a58fa5b
Showing 15 changed files with 69 additions and 199 deletions.
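
The changes merged in from main replace synchronous `send_prompt()` call sites in the docs with `await send_prompt_async(...)`, and turn the remaining synchronous `send_prompt()` implementations into deprecated wrappers around their async counterparts. A minimal before/after sketch of the call-site change, using the target and request names from the diffs below:

```python
# Before: synchronous send
with AzureMLChatTarget() as azure_ml_chat_target:
    print(azure_ml_chat_target.send_prompt(prompt_request=request))

# After: asynchronous send (the notebooks rely on top-level await)
with AzureMLChatTarget() as azure_ml_chat_target:
    print(await azure_ml_chat_target.send_prompt_async(prompt_request=request))
```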
2 changes: 1 addition & 1 deletion doc/code/targets/azure_ml_chat.ipynb
@@ -74,7 +74,7 @@
").to_prompt_request_response()\n",
"\n",
"with AzureMLChatTarget() as azure_ml_chat_target:\n",
" print(azure_ml_chat_target.send_prompt(prompt_request=request))"
" print(await azure_ml_chat_target.send_prompt_async(prompt_request=request))"
]
},
{
2 changes: 1 addition & 1 deletion doc/code/targets/azure_ml_chat.py
@@ -39,7 +39,7 @@
).to_prompt_request_response()

with AzureMLChatTarget() as azure_ml_chat_target:
-    print(azure_ml_chat_target.send_prompt(prompt_request=request))
+    print(await azure_ml_chat_target.send_prompt_async(prompt_request=request))  # type: ignore

# %% [markdown]
#
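
The paired `.py` doc scripts use top-level `await`, which IPython/Jupyter supports but plain Python does not; the `# type: ignore` markers keep type checkers quiet about it. In an ordinary script the same call would be driven by `asyncio.run`. A sketch, assuming the `request` object built earlier in this file:

```python
import asyncio

from pyrit.prompt_target import AzureMLChatTarget


async def main() -> None:
    # `request` is the prompt request built above via
    # PromptRequestPiece(...).to_prompt_request_response()
    with AzureMLChatTarget() as azure_ml_chat_target:
        print(await azure_ml_chat_target.send_prompt_async(prompt_request=request))


asyncio.run(main())
```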
4 changes: 2 additions & 2 deletions doc/code/targets/prompt_targets.ipynb
@@ -103,7 +103,7 @@
"# When `use_aad_auth=True`, ensure the user has 'Cognitive Service OpenAI User' role assigned on the AOAI Resource\n",
"# and `az login` is used to authenticate with the correct identity\n",
"with AzureOpenAIChatTarget(use_aad_auth=False) as azure_openai_chat_target:\n",
" print(azure_openai_chat_target.send_prompt(prompt_request=request))"
" print(await azure_openai_chat_target.send_prompt_async(prompt_request=request))"
]
},
{
@@ -162,7 +162,7 @@
" sas_token=os.environ.get(\"AZURE_STORAGE_ACCOUNT_SAS_TOKEN\"),\n",
") as abs_prompt_target:\n",
"\n",
" print(abs_prompt_target.send_prompt(prompt_request=request))"
" print(await abs_prompt_target.send_prompt_async(prompt_request=request))"
]
}
],
4 changes: 2 additions & 2 deletions doc/code/targets/prompt_targets.py
@@ -47,7 +47,7 @@
# When `use_aad_auth=True`, ensure the user has 'Cognitive Service OpenAI User' role assigned on the AOAI Resource
# and `az login` is used to authenticate with the correct identity
with AzureOpenAIChatTarget(use_aad_auth=False) as azure_openai_chat_target:
-    print(azure_openai_chat_target.send_prompt(prompt_request=request))
+    print(await azure_openai_chat_target.send_prompt_async(prompt_request=request))  # type: ignore

# %% [markdown]
# The `AzureBlobStorageTarget` inherits from `PromptTarget`, meaning it has functionality to send prompts. In contrast to `PromptChatTarget`s, `PromptTarget`s do not interact with chat assistants.
@@ -79,4 +79,4 @@
sas_token=os.environ.get("AZURE_STORAGE_ACCOUNT_SAS_TOKEN"),
) as abs_prompt_target:

-    print(abs_prompt_target.send_prompt(prompt_request=request))
+    print(await abs_prompt_target.send_prompt_async(prompt_request=request))  # type: ignore
6 changes: 3 additions & 3 deletions doc/how_to_guide.ipynb
@@ -69,7 +69,7 @@
" role=\"user\",\n",
" original_value=\"this is a test prompt\",\n",
" ).to_prompt_request_response()\n",
" target_llm.send_prompt(prompt_request=request)"
" await target_llm.send_prompt_async(prompt_request=request)"
]
},
{
@@ -164,7 +164,7 @@
"- run a single turn of the attack strategy or\n",
"- try to achieve the goal as specified in the attack strategy which may take multiple turns.\n",
"\n",
"The single turn is executed with the `send_prompt()` method. It generates the prompt using the red\n",
"The single turn is executed with the `send_prompt_async()` method. It generates the prompt using the red\n",
"teaming LLM and sends it to the target.\n",
"The full execution of the attack strategy over potentially multiple turns requires a mechanism\n",
"to determine if the goal has been achieved.\n",
@@ -569,7 +569,7 @@
" # or the maximum number of turns is reached.\n",
" red_teaming_orchestrator.apply_attack_strategy_until_completion(max_turns=5)\n",
"\n",
" # Alternatively, use send_prompt() to generate just a single turn of the attack strategy."
" # Alternatively, use send_prompt_async() to generate just a single turn of the attack strategy."
]
},
{
6 changes: 3 additions & 3 deletions doc/how_to_guide.py
@@ -49,7 +49,7 @@
role="user",
original_value="this is a test prompt",
).to_prompt_request_response()
-    target_llm.send_prompt(prompt_request=request)
+    await target_llm.send_prompt_async(prompt_request=request)  # type: ignore

# %% [markdown]
# To expand to a wider variety of harms, it may be beneficial to write prompt templates instead of the
@@ -101,7 +101,7 @@
# - run a single turn of the attack strategy or
# - try to achieve the goal as specified in the attack strategy which may take multiple turns.
#
-# The single turn is executed with the `send_prompt()` method. It generates the prompt using the red
+# The single turn is executed with the `send_prompt_async()` method. It generates the prompt using the red
# teaming LLM and sends it to the target.
# The full execution of the attack strategy over potentially multiple turns requires a mechanism
# to determine if the goal has been achieved.
@@ -158,7 +158,7 @@
# or the maximum number of turns is reached.
red_teaming_orchestrator.apply_attack_strategy_until_completion(max_turns=5)

-    # Alternatively, use send_prompt() to generate just a single turn of the attack strategy.
+    # Alternatively, use send_prompt_async() to generate just a single turn of the attack strategy.

# %% [markdown]
# Going a step further, we can generalize the attack strategy into templates as mentioned in an earlier
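
The single-turn versus run-to-completion distinction described above reduces to a simple loop. A PyRIT-independent sketch of that control flow, with all names illustrative rather than taken from the library:

```python
import asyncio


async def single_turn(turn: int) -> bool:
    """Stand-in for one send_prompt_async() round trip plus scoring;
    returns True once the attack objective is judged achieved."""
    print(f"turn {turn}: prompt generated, sent, and scored")
    return turn >= 3  # pretend the objective is met on turn 3


async def apply_until_completion(max_turns: int = 5) -> bool:
    """Loop single turns until the objective is met or the turn budget is spent."""
    for turn in range(1, max_turns + 1):
        if await single_turn(turn):
            return True
    return False


print(asyncio.run(apply_until_completion(max_turns=5)))  # True after three turns
```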
4 changes: 2 additions & 2 deletions pyrit/prompt_target/azure_openai_completion_target.py
@@ -4,6 +4,7 @@
 import asyncio
 import concurrent.futures
 import logging
+
 from openai import AsyncAzureOpenAI
 from openai.types.completion import Completion
 
@@ -14,7 +15,6 @@
 from pyrit.models.prompt_request_response import PromptRequestResponse
 from pyrit.prompt_target.prompt_target import PromptTarget
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -106,7 +106,7 @@ def __init__(
 
     def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
         """
-        Sends a normalized prompt to the prompt target and adds the request and response to memory
+        Deprecated. Sends a normalized prompt to the prompt target and adds the request and response to memory
         """
         pool = concurrent.futures.ThreadPoolExecutor()
         return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result()
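
The deprecated `send_prompt()` above, like the identical wrappers added to the Azure ML and Ollama targets below, runs the async implementation to completion on a fresh worker thread, presumably so `asyncio.run` is never called on a thread that already has a running event loop (as in Jupyter). A self-contained sketch of the pattern; the class here is illustrative, not a PyRIT type:

```python
import asyncio
import concurrent.futures


class EchoTarget:
    async def send_prompt_async(self, *, prompt: str) -> str:
        await asyncio.sleep(0)  # stand-in for the real network call
        return f"echo: {prompt}"

    def send_prompt(self, *, prompt: str) -> str:
        """Deprecated. Drives the async implementation to completion on a worker thread."""
        pool = concurrent.futures.ThreadPoolExecutor()
        return pool.submit(asyncio.run, self.send_prompt_async(prompt=prompt)).result()


print(EchoTarget().send_prompt(prompt="hello"))  # -> echo: hello
```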
76 changes: 9 additions & 67 deletions pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py
@@ -1,12 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+import asyncio
+import concurrent.futures
 import logging
 
-from pyrit.chat_message_normalizer import ChatMessageNormalizer, ChatMessageNop
+from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
 from pyrit.common import default_values, net_utility
 from pyrit.memory import MemoryInterface
-from pyrit.models import PromptRequestResponse
-from pyrit.models import ChatMessage
+from pyrit.models import ChatMessage, PromptRequestResponse
 from pyrit.prompt_target import PromptChatTarget
 
 logger = logging.getLogger(__name__)
@@ -66,32 +67,11 @@ def __init__(
         self._repetition_penalty = repetition_penalty
 
     def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
-
-        self._validate_request(prompt_request=prompt_request)
-        request = prompt_request.request_pieces[0]
-
-        messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
-        messages.append(request.to_chat_message())
-
-        self._memory.add_request_response_to_memory(request=prompt_request)
-
-        logger.info(f"Sending the following prompt to the prompt target: {request}")
-
-        resp_text = self._complete_chat(
-            messages=messages,
-            temperature=self._temperature,
-            top_p=self._top_p,
-            repetition_penalty=self._repetition_penalty,
-        )
-
-        if not resp_text:
-            raise ValueError("The chat returned an empty response.")
-
-        logger.info(f'Received the following response from the prompt target "{resp_text}"')
-
-        response_entry = self._memory.add_response_entries_to_memory(request=request, response_text_pieces=[resp_text])
-
-        return response_entry
+        """
+        Deprecated. Use send_prompt_async instead.
+        """
+        pool = concurrent.futures.ThreadPoolExecutor()
+        return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result()
 
     async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:

@@ -121,44 +101,6 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
 
         return response_entry
 
-    def _complete_chat(
-        self,
-        messages: list[ChatMessage],
-        max_tokens: int = 400,
-        temperature: float = 1.0,
-        top_p: int = 1,
-        repetition_penalty: float = 1.2,
-    ) -> str:
-        """
-        Completes a chat interaction by generating a response to the given input prompt.
-        This is a synchronous wrapper for the asynchronous _generate_and_extract_response method.
-        Args:
-            messages (list[ChatMessage]): The chat messages objects containing the role and content.
-            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 400.
-            temperature (float, optional): Controls randomness in the response generation.
-                Defaults to 1.0. 1 is more random, 0 is less.
-            top_p (int, optional): Controls diversity of the response generation. Defaults to 1.
-                1 is more random, 0 is less.
-            repetition_penalty (float, optional): Controls repetition in the response generation.
-                Defaults to 1.2.
-        Raises:
-            Exception: For any errors during the process.
-        Returns:
-            str: The generated response message.
-        """
-        headers = self._get_headers()
-        payload = self._construct_http_body(messages, max_tokens, temperature, top_p, repetition_penalty)
-
-        response = net_utility.make_request_and_raise_if_error(
-            endpoint_uri=self.endpoint_uri, method="POST", request_body=payload, headers=headers
-        )
-
-        return response.json()["output"]
-
     async def _complete_chat_async(
         self,
         messages: list[ChatMessage],
47 changes: 9 additions & 38 deletions pyrit/prompt_target/prompt_chat_target/ollama_chat_target.py
@@ -1,12 +1,13 @@
 # Copyright (c) Adriano Maia <[email protected]>
 # Licensed under the MIT license.
+import asyncio
+import concurrent.futures
 import logging
 
-from pyrit.chat_message_normalizer import ChatMessageNormalizer, ChatMessageNop
+from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
 from pyrit.common import default_values, net_utility
 from pyrit.memory import MemoryInterface
-from pyrit.models import ChatMessage
-from pyrit.models import PromptRequestPiece, PromptRequestResponse
+from pyrit.models import ChatMessage, PromptRequestPiece, PromptRequestResponse
 from pyrit.prompt_target import PromptChatTarget
 
 logger = logging.getLogger(__name__)
@@ -36,28 +37,11 @@ def __init__(
         self.chat_message_normalizer = chat_message_normalizer
 
     def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
-        self._validate_request(prompt_request=prompt_request)
-        request: PromptRequestPiece = prompt_request.request_pieces[0]
-
-        messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
-        messages.append(request.to_chat_message())
-
-        logger.info(f"Sending the following prompt to the prompt target: {self} {request}")
-
-        self._memory.add_request_response_to_memory(request=prompt_request)
-
-        resp = self._complete_chat(
-            messages=messages,
-        )
-
-        if not resp:
-            raise ValueError("The chat returned an empty response.")
-
-        logger.info(f'Received the following response from the prompt target "{resp}"')
-
-        response_entry = self._memory.add_response_entries_to_memory(request=request, response_text_pieces=[resp])
-
-        return response_entry
+        """
+        Deprecated. Use send_prompt_async instead.
+        """
+        pool = concurrent.futures.ThreadPoolExecutor()
+        return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result()
 
     async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
 
@@ -84,19 +68,6 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
 
         return response_entry
 
-    def _complete_chat(
-        self,
-        messages: list[ChatMessage],
-    ) -> str:
-        headers = self._get_headers()
-        payload = self._construct_http_body(messages)
-
-        response = net_utility.make_request_and_raise_if_error(
-            endpoint_uri=self.endpoint_uri, method="POST", request_body=payload, headers=headers
-        )
-
-        return response.json()["message"]["content"]
-
     async def _complete_chat_async(
         self,
         messages: list[ChatMessage],
9 changes: 4 additions & 5 deletions pyrit/prompt_target/prompt_chat_target/openai_chat_target.py
@@ -1,9 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from abc import abstractmethod
-import logging
 import json
+import logging
+from abc import abstractmethod
 from typing import Optional
 
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
@@ -12,8 +12,7 @@
 from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential
 from pyrit.common import default_values
 from pyrit.memory import MemoryInterface
-from pyrit.models import ChatMessage
-from pyrit.models import PromptRequestResponse, PromptRequestPiece
+from pyrit.models import ChatMessage, PromptRequestPiece, PromptRequestResponse
 from pyrit.prompt_target import PromptChatTarget
 
 logger = logging.getLogger(__name__)
@@ -37,7 +36,7 @@ def __init__(self) -> None:
         pass
 
     def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
-
         self._validate_request(prompt_request=prompt_request)
         request: PromptRequestPiece = prompt_request.request_pieces[0]
 
         messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
6 changes: 3 additions & 3 deletions pyrit/prompt_target/prompt_chat_target/prompt_chat_target.py
@@ -1,10 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
 from typing import Optional
-from pyrit.models import PromptRequestResponse, PromptRequestPiece
-from pyrit.prompt_target import PromptTarget
 
 from pyrit.memory import MemoryInterface
+from pyrit.models import PromptRequestPiece, PromptRequestResponse
+from pyrit.prompt_target import PromptTarget
 
 
 class PromptChatTarget(PromptTarget):