From 02ed3e5022248ca947255a46bdf5a48614572c55 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Fri, 17 May 2024 10:59:39 -0400 Subject: [PATCH] MAINT: Deprecated send_prompt methods (#204) Co-authored-by: Martin Pouliot --- doc/code/targets/azure_ml_chat.ipynb | 2 +- doc/code/targets/azure_ml_chat.py | 2 +- doc/code/targets/prompt_targets.ipynb | 4 +- doc/code/targets/prompt_targets.py | 4 +- doc/how_to_guide.ipynb | 6 +- doc/how_to_guide.py | 6 +- .../azure_openai_completion_target.py | 4 +- .../azure_ml_chat_target.py | 76 +++---------------- .../prompt_chat_target/ollama_chat_target.py | 47 +++--------- .../prompt_chat_target/openai_chat_target.py | 9 +-- .../prompt_chat_target/prompt_chat_target.py | 6 +- tests/target/test_aml_online_endpoint_chat.py | 30 ++++---- tests/target/test_openai_chat_target.py | 28 ------- tests/target/test_prompt_target.py | 23 +++--- tests/test_ollama_chat_target.py | 21 ----- 15 files changed, 69 insertions(+), 199 deletions(-) diff --git a/doc/code/targets/azure_ml_chat.ipynb b/doc/code/targets/azure_ml_chat.ipynb index cc8379350..09f5e9dfe 100644 --- a/doc/code/targets/azure_ml_chat.ipynb +++ b/doc/code/targets/azure_ml_chat.ipynb @@ -74,7 +74,7 @@ ").to_prompt_request_response()\n", "\n", "with AzureMLChatTarget() as azure_ml_chat_target:\n", - " print(azure_ml_chat_target.send_prompt(prompt_request=request))" + " print(await azure_ml_chat_target.send_prompt_async(prompt_request=request))" ] }, { diff --git a/doc/code/targets/azure_ml_chat.py b/doc/code/targets/azure_ml_chat.py index d25a9b27c..1e48d5662 100644 --- a/doc/code/targets/azure_ml_chat.py +++ b/doc/code/targets/azure_ml_chat.py @@ -39,7 +39,7 @@ ).to_prompt_request_response() with AzureMLChatTarget() as azure_ml_chat_target: - print(azure_ml_chat_target.send_prompt(prompt_request=request)) + print(await azure_ml_chat_target.send_prompt_async(prompt_request=request)) # type: ignore # %% [markdown] # diff --git a/doc/code/targets/prompt_targets.ipynb b/doc/code/targets/prompt_targets.ipynb index 039abc6d9..605d3de7c 100644 --- a/doc/code/targets/prompt_targets.ipynb +++ b/doc/code/targets/prompt_targets.ipynb @@ -103,7 +103,7 @@ "# When `use_aad_auth=True`, ensure the user has 'Cognitive Service OpenAI User' role assigned on the AOAI Resource\n", "# and `az login` is used to authenticate with the correct identity\n", "with AzureOpenAIChatTarget(use_aad_auth=False) as azure_openai_chat_target:\n", - " print(azure_openai_chat_target.send_prompt(prompt_request=request))" + " print(await azure_openai_chat_target.send_prompt_async(prompt_request=request))" ] }, { @@ -162,7 +162,7 @@ " sas_token=os.environ.get(\"AZURE_STORAGE_ACCOUNT_SAS_TOKEN\"),\n", ") as abs_prompt_target:\n", "\n", - " print(abs_prompt_target.send_prompt(prompt_request=request))" + " print(await abs_prompt_target.send_prompt_async(prompt_request=request))" ] } ], diff --git a/doc/code/targets/prompt_targets.py b/doc/code/targets/prompt_targets.py index 832bb8423..4f7b86b3d 100644 --- a/doc/code/targets/prompt_targets.py +++ b/doc/code/targets/prompt_targets.py @@ -47,7 +47,7 @@ # When `use_aad_auth=True`, ensure the user has 'Cognitive Service OpenAI User' role assigned on the AOAI Resource # and `az login` is used to authenticate with the correct identity with AzureOpenAIChatTarget(use_aad_auth=False) as azure_openai_chat_target: - print(azure_openai_chat_target.send_prompt(prompt_request=request)) + print(await azure_openai_chat_target.send_prompt_async(prompt_request=request)) # type: ignore # %% [markdown] # The 
`AzureBlobStorageTarget` inherits from `PromptTarget`, meaning it has functionality to send prompts. In contrast to `PromptChatTarget`s, `PromptTarget`s do not interact with chat assistants. @@ -79,4 +79,4 @@ sas_token=os.environ.get("AZURE_STORAGE_ACCOUNT_SAS_TOKEN"), ) as abs_prompt_target: - print(abs_prompt_target.send_prompt(prompt_request=request)) + print(await abs_prompt_target.send_prompt_async(prompt_request=request)) # type: ignore diff --git a/doc/how_to_guide.ipynb b/doc/how_to_guide.ipynb index b33295810..10ffb7e83 100644 --- a/doc/how_to_guide.ipynb +++ b/doc/how_to_guide.ipynb @@ -69,7 +69,7 @@ " role=\"user\",\n", " original_value=\"this is a test prompt\",\n", " ).to_prompt_request_response()\n", - " target_llm.send_prompt(prompt_request=request)" + " await target_llm.send_prompt_async(prompt_request=request)" ] }, { @@ -164,7 +164,7 @@ "- run a single turn of the attack strategy or\n", "- try to achieve the goal as specified in the attack strategy which may take multiple turns.\n", "\n", - "The single turn is executed with the `send_prompt()` method. It generates the prompt using the red\n", + "The single turn is executed with the `send_prompt_async()` method. It generates the prompt using the red\n", "teaming LLM and sends it to the target.\n", "The full execution of the attack strategy over potentially multiple turns requires a mechanism\n", "to determine if the goal has been achieved.\n", @@ -569,7 +569,7 @@ " # or the maximum number of turns is reached.\n", " red_teaming_orchestrator.apply_attack_strategy_until_completion(max_turns=5)\n", "\n", - " # Alternatively, use send_prompt() to generate just a single turn of the attack strategy." + " # Alternatively, use send_prompt_async() to generate just a single turn of the attack strategy." ] }, { diff --git a/doc/how_to_guide.py b/doc/how_to_guide.py index a6033a98f..8289dc7f5 100644 --- a/doc/how_to_guide.py +++ b/doc/how_to_guide.py @@ -49,7 +49,7 @@ role="user", original_value="this is a test prompt", ).to_prompt_request_response() - target_llm.send_prompt(prompt_request=request) + await target_llm.send_prompt_async(prompt_request=request) # type: ignore # %% [markdown] # To expand to a wider variety of harms, it may be beneficial to write prompt templates instead of the @@ -101,7 +101,7 @@ # - run a single turn of the attack strategy or # - try to achieve the goal as specified in the attack strategy which may take multiple turns. # -# The single turn is executed with the `send_prompt()` method. It generates the prompt using the red +# The single turn is executed with the `send_prompt_async()` method. It generates the prompt using the red # teaming LLM and sends it to the target. # The full execution of the attack strategy over potentially multiple turns requires a mechanism # to determine if the goal has been achieved. @@ -158,7 +158,7 @@ # or the maximum number of turns is reached. red_teaming_orchestrator.apply_attack_strategy_until_completion(max_turns=5) - # Alternatively, use send_prompt() to generate just a single turn of the attack strategy. + # Alternatively, use send_prompt_async() to generate just a single turn of the attack strategy. 
# %% [markdown] # Going a step further, we can generalize the attack strategy into templates as mentioned in an earlier diff --git a/pyrit/prompt_target/azure_openai_completion_target.py b/pyrit/prompt_target/azure_openai_completion_target.py index 0f733b03d..e35e28336 100644 --- a/pyrit/prompt_target/azure_openai_completion_target.py +++ b/pyrit/prompt_target/azure_openai_completion_target.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures import logging + from openai import AsyncAzureOpenAI from openai.types.completion import Completion @@ -14,7 +15,6 @@ from pyrit.models.prompt_request_response import PromptRequestResponse from pyrit.prompt_target.prompt_target import PromptTarget - logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ def __init__( def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: """ - Sends a normalized prompt to the prompt target and adds the request and response to memory + Deprecated. Sends a normalized prompt to the prompt target and adds the request and response to memory """ pool = concurrent.futures.ThreadPoolExecutor() return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result() diff --git a/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py index 4ad89507f..e7ec11577 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import asyncio +import concurrent.futures import logging -from pyrit.chat_message_normalizer import ChatMessageNormalizer, ChatMessageNop +from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer from pyrit.common import default_values, net_utility from pyrit.memory import MemoryInterface -from pyrit.models import PromptRequestResponse -from pyrit.models import ChatMessage +from pyrit.models import ChatMessage, PromptRequestResponse from pyrit.prompt_target import PromptChatTarget logger = logging.getLogger(__name__) @@ -66,32 +67,11 @@ def __init__( self._repetition_penalty = repetition_penalty def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: - - self._validate_request(prompt_request=prompt_request) - request = prompt_request.request_pieces[0] - - messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id) - messages.append(request.to_chat_message()) - - self._memory.add_request_response_to_memory(request=prompt_request) - - logger.info(f"Sending the following prompt to the prompt target: {request}") - - resp_text = self._complete_chat( - messages=messages, - temperature=self._temperature, - top_p=self._top_p, - repetition_penalty=self._repetition_penalty, - ) - - if not resp_text: - raise ValueError("The chat returned an empty response.") - - logger.info(f'Received the following response from the prompt target "{resp_text}"') - - response_entry = self._memory.add_response_entries_to_memory(request=request, response_text_pieces=[resp_text]) - - return response_entry + """ + Deprecated. Use send_prompt_async instead. 
+ """ + pool = concurrent.futures.ThreadPoolExecutor() + return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result() async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: @@ -121,44 +101,6 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P return response_entry - def _complete_chat( - self, - messages: list[ChatMessage], - max_tokens: int = 400, - temperature: float = 1.0, - top_p: int = 1, - repetition_penalty: float = 1.2, - ) -> str: - """ - Completes a chat interaction by generating a response to the given input prompt. - - This is a synchronous wrapper for the asynchronous _generate_and_extract_response method. - - Args: - messages (list[ChatMessage]): The chat messages objects containing the role and content. - max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 400. - temperature (float, optional): Controls randomness in the response generation. - Defaults to 1.0. 1 is more random, 0 is less. - top_p (int, optional): Controls diversity of the response generation. Defaults to 1. - 1 is more random, 0 is less. - repetition_penalty (float, optional): Controls repetition in the response generation. - Defaults to 1.2. - - Raises: - Exception: For any errors during the process. - - Returns: - str: The generated response message. - """ - headers = self._get_headers() - payload = self._construct_http_body(messages, max_tokens, temperature, top_p, repetition_penalty) - - response = net_utility.make_request_and_raise_if_error( - endpoint_uri=self.endpoint_uri, method="POST", request_body=payload, headers=headers - ) - - return response.json()["output"] - async def _complete_chat_async( self, messages: list[ChatMessage], diff --git a/pyrit/prompt_target/prompt_chat_target/ollama_chat_target.py b/pyrit/prompt_target/prompt_chat_target/ollama_chat_target.py index a3cde5e0d..51b3bf4e7 100644 --- a/pyrit/prompt_target/prompt_chat_target/ollama_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/ollama_chat_target.py @@ -1,12 +1,13 @@ # Copyright (c) Adriano Maia # Licensed under the MIT license. 
+import asyncio +import concurrent.futures import logging -from pyrit.chat_message_normalizer import ChatMessageNormalizer, ChatMessageNop +from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer from pyrit.common import default_values, net_utility from pyrit.memory import MemoryInterface -from pyrit.models import ChatMessage -from pyrit.models import PromptRequestPiece, PromptRequestResponse +from pyrit.models import ChatMessage, PromptRequestPiece, PromptRequestResponse from pyrit.prompt_target import PromptChatTarget logger = logging.getLogger(__name__) @@ -36,28 +37,11 @@ def __init__( self.chat_message_normalizer = chat_message_normalizer def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: - self._validate_request(prompt_request=prompt_request) - request: PromptRequestPiece = prompt_request.request_pieces[0] - - messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id) - messages.append(request.to_chat_message()) - - logger.info(f"Sending the following prompt to the prompt target: {self} {request}") - - self._memory.add_request_response_to_memory(request=prompt_request) - - resp = self._complete_chat( - messages=messages, - ) - - if not resp: - raise ValueError("The chat returned an empty response.") - - logger.info(f'Received the following response from the prompt target "{resp}"') - - response_entry = self._memory.add_response_entries_to_memory(request=request, response_text_pieces=[resp]) - - return response_entry + """ + Deprecated. Use send_prompt_async instead. + """ + pool = concurrent.futures.ThreadPoolExecutor() + return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result() async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: @@ -84,19 +68,6 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P return response_entry - def _complete_chat( - self, - messages: list[ChatMessage], - ) -> str: - headers = self._get_headers() - payload = self._construct_http_body(messages) - - response = net_utility.make_request_and_raise_if_error( - endpoint_uri=self.endpoint_uri, method="POST", request_body=payload, headers=headers - ) - - return response.json()["message"]["content"] - async def _complete_chat_async( self, messages: list[ChatMessage], diff --git a/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py b/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py index 43020cb08..2906b3946 100644 --- a/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from abc import abstractmethod -import logging import json +import logging +from abc import abstractmethod from typing import Optional from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI @@ -12,8 +12,7 @@ from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential from pyrit.common import default_values from pyrit.memory import MemoryInterface -from pyrit.models import ChatMessage -from pyrit.models import PromptRequestResponse, PromptRequestPiece +from pyrit.models import ChatMessage, PromptRequestPiece, PromptRequestResponse from pyrit.prompt_target import PromptChatTarget logger = logging.getLogger(__name__) @@ -37,7 +36,7 @@ def __init__(self) -> None: pass def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: - + self._validate_request(prompt_request=prompt_request) request: PromptRequestPiece = prompt_request.request_pieces[0] messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id) diff --git a/pyrit/prompt_target/prompt_chat_target/prompt_chat_target.py b/pyrit/prompt_target/prompt_chat_target/prompt_chat_target.py index 9932d064e..b202ed3ae 100644 --- a/pyrit/prompt_target/prompt_chat_target/prompt_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/prompt_chat_target.py @@ -1,10 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. - from typing import Optional -from pyrit.models import PromptRequestResponse, PromptRequestPiece -from pyrit.prompt_target import PromptTarget + from pyrit.memory import MemoryInterface +from pyrit.models import PromptRequestPiece, PromptRequestResponse +from pyrit.prompt_target import PromptTarget class PromptChatTarget(PromptTarget): diff --git a/tests/target/test_aml_online_endpoint_chat.py b/tests/target/test_aml_online_endpoint_chat.py index 7c630c790..c23cb0e9b 100644 --- a/tests/target/test_aml_online_endpoint_chat.py +++ b/tests/target/test_aml_online_endpoint_chat.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import os import pytest @@ -54,23 +54,25 @@ def test_get_headers_with_valid_api_key(aml_online_chat: AzureMLChatTarget): assert aml_online_chat._get_headers() == expected_headers -def test_complete_chat(aml_online_chat: AzureMLChatTarget): +@pytest.mark.asyncio +async def test_complete_chat_async(aml_online_chat: AzureMLChatTarget): messages = [ ChatMessage(role="user", content="user content"), ] - with patch("pyrit.common.net_utility.make_request_and_raise_if_error") as mock: + with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async") as mock: mock_response = Mock() mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response - response = aml_online_chat._complete_chat(messages) + response = await aml_online_chat._complete_chat_async(messages) assert response == "extracted response" mock.assert_called_once() # The None parameter checks the default is the same as ChatMessageNop +@pytest.mark.asyncio @pytest.mark.parametrize("message_normalizer", [None, ChatMessageNop()]) -def test_complete_chat_with_nop_normalizer( +async def test_complete_chat_async_with_nop_normalizer( aml_online_chat: AzureMLChatTarget, message_normalizer: ChatMessageNormalizer ): if message_normalizer: @@ -81,11 +83,11 @@ def test_complete_chat_with_nop_normalizer( ChatMessage(role="user", content="user content"), ] - with patch("pyrit.common.net_utility.make_request_and_raise_if_error") as mock: + with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock) as mock: mock_response = Mock() mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response - response = aml_online_chat._complete_chat(messages) + response = await aml_online_chat._complete_chat_async(messages) assert response == "extracted response" args, kwargs = mock.call_args @@ -96,7 +98,8 @@ def test_complete_chat_with_nop_normalizer( assert body["input_data"]["input_string"][0]["role"] == "system" -def test_complete_chat_with_squashnormalizer(aml_online_chat: AzureMLChatTarget): +@pytest.mark.asyncio +async def test_complete_chat_async_with_squashnormalizer(aml_online_chat: AzureMLChatTarget): aml_online_chat.chat_message_normalizer = GenericSystemSquash() messages = [ @@ -104,11 +107,11 @@ def test_complete_chat_with_squashnormalizer(aml_online_chat: AzureMLChatTarget) ChatMessage(role="user", content="user content"), ] - with patch("pyrit.common.net_utility.make_request_and_raise_if_error") as mock: + with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock) as mock: mock_response = Mock() mock_response.json.return_value = {"output": "extracted response"} mock.return_value = mock_response - response = aml_online_chat._complete_chat(messages) + response = await aml_online_chat._complete_chat_async(messages) assert response == "extracted response" args, kwargs = mock.call_args @@ -119,17 +122,18 @@ def test_complete_chat_with_squashnormalizer(aml_online_chat: AzureMLChatTarget) assert body["input_data"]["input_string"][0]["role"] == "user" -def test_complete_chat_bad_json_response(aml_online_chat: AzureMLChatTarget): +@pytest.mark.asyncio +async def test_complete_chat_async_bad_json_response(aml_online_chat: AzureMLChatTarget): messages = [ ChatMessage(role="user", content="user content"), ] - with patch("pyrit.common.net_utility.make_request_and_raise_if_error") as mock: + with 
patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock) as mock: mock_response = Mock() mock_response.json.return_value = {"bad response"} mock.return_value = mock_response with pytest.raises(TypeError): - aml_online_chat._complete_chat(messages) + await aml_online_chat._complete_chat_async(messages) @pytest.mark.asyncio diff --git a/tests/target/test_openai_chat_target.py b/tests/target/test_openai_chat_target.py index e1deb1545..6e41670c6 100644 --- a/tests/target/test_openai_chat_target.py +++ b/tests/target/test_openai_chat_target.py @@ -75,18 +75,6 @@ def prompt_request_response() -> PromptRequestResponse: ) -def execute_openai_send_prompt( - target: OpenAIChatInterface, - prompt_request_response: PromptRequestResponse, - mock_return: ChatCompletion, -): - with patch("openai.resources.chat.Completions.create") as mock_create: - mock_create.return_value = mock_return - response: PromptRequestResponse = target.send_prompt(prompt_request=prompt_request_response) - assert len(response.request_pieces) == 1 - assert response.request_pieces[0].converted_value == "hi" - - async def execute_openai_send_prompt_async( target: OpenAIChatInterface, prompt_request_response: PromptRequestResponse, @@ -117,22 +105,6 @@ async def test_openai_complete_chat_async_return( await execute_openai_send_prompt_async(openai_chat_target, prompt_request_response, openai_mock_return) -def test_azure_complete_chat_return( - openai_mock_return: ChatCompletion, - azure_chat_target: AzureOpenAIChatTarget, - prompt_request_response: PromptRequestResponse, -): - execute_openai_send_prompt(azure_chat_target, prompt_request_response, openai_mock_return) - - -def test_openai_complete_chat_return( - openai_mock_return: ChatCompletion, - openai_chat_target: OpenAIChatTarget, - prompt_request_response: PromptRequestResponse, -): - execute_openai_send_prompt(openai_chat_target, prompt_request_response, openai_mock_return) - - @pytest.mark.asyncio async def test_openai_validate_request_length( openai_chat_target: OpenAIChatTarget, sample_conversations: list[PromptRequestPiece] diff --git a/tests/target/test_prompt_target.py b/tests/target/test_prompt_target.py index fe81d6dee..d1095091f 100644 --- a/tests/target/test_prompt_target.py +++ b/tests/target/test_prompt_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from typing import Generator -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest from openai.types.chat import ChatCompletion, ChatCompletionMessage @@ -75,19 +75,20 @@ def test_set_system_prompt(azure_openai_target: AzureOpenAIChatTarget): assert chats[0].converted_value == "system prompt" -def test_send_prompt_user_no_system( +@pytest.mark.asyncio +async def test_send_prompt_user_no_system( azure_openai_target: AzureOpenAIChatTarget, openai_mock_return: ChatCompletion, sample_entries: list[PromptRequestPiece], ): - with patch("openai.resources.chat.Completions.create") as mock: + with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock: mock.return_value = openai_mock_return request = sample_entries[0] request.converted_value = "hi, I am a victim chatbot, how can I help?" 
- azure_openai_target.send_prompt(prompt_request=PromptRequestResponse(request_pieces=[request])) + await azure_openai_target.send_prompt_async(prompt_request=PromptRequestResponse(request_pieces=[request])) chats = azure_openai_target._memory._get_prompt_pieces_with_conversation_id( conversation_id=request.conversation_id @@ -97,13 +98,14 @@ def test_send_prompt_user_no_system( assert chats[1].role == "assistant" -def test_send_prompt_with_system( +@pytest.mark.asyncio +async def test_send_prompt_with_system( azure_openai_target: AzureOpenAIChatTarget, openai_mock_return: ChatCompletion, sample_entries: list[PromptRequestPiece], ): - with patch("openai.resources.chat.Completions.create") as mock: + with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock: mock.return_value = openai_mock_return azure_openai_target.set_system_prompt( @@ -117,7 +119,7 @@ def test_send_prompt_with_system( request.converted_value = "hi, I am a victim chatbot, how can I help?" request.conversation_id = "1" - azure_openai_target.send_prompt(prompt_request=PromptRequestResponse(request_pieces=[request])) + await azure_openai_target.send_prompt_async(prompt_request=PromptRequestResponse(request_pieces=[request])) chats = azure_openai_target._memory._get_prompt_pieces_with_conversation_id(conversation_id="1") assert len(chats) == 3, f"Expected 3 chats, got {len(chats)}" @@ -125,13 +127,14 @@ def test_send_prompt_with_system( assert chats[1].role == "user" -def test_send_prompt_with_system_calls_chat_complete( +@pytest.mark.asyncio +async def test_send_prompt_with_system_calls_chat_complete( azure_openai_target: AzureOpenAIChatTarget, openai_mock_return: ChatCompletion, sample_entries: list[PromptRequestPiece], ): - with patch("openai.resources.chat.Completions.create") as mock: + with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock: mock.return_value = openai_mock_return azure_openai_target.set_system_prompt( @@ -145,6 +148,6 @@ def test_send_prompt_with_system_calls_chat_complete( request.converted_value = "hi, I am a victim chatbot, how can I help?" request.conversation_id = "1" - azure_openai_target.send_prompt(prompt_request=PromptRequestResponse(request_pieces=[request])) + await azure_openai_target.send_prompt_async(prompt_request=PromptRequestResponse(request_pieces=[request])) mock.assert_called_once() diff --git a/tests/test_ollama_chat_target.py b/tests/test_ollama_chat_target.py index 08e0c871f..63cc99b1e 100644 --- a/tests/test_ollama_chat_target.py +++ b/tests/test_ollama_chat_target.py @@ -64,27 +64,6 @@ async def test_ollama_complete_chat_async_return(ollama_chat_engine: OllamaChatT assert ret == " Hello." -def test_ollama_complete_chat_return(ollama_chat_engine: OllamaChatTarget): - with patch("pyrit.common.net_utility.make_request_and_raise_if_error") as mock_create: - mock_create.return_value = httpx.Response( - 200, - json={ - "model": "mistral", - "created_at": "2024-04-13T16:14:52.69602Z", - "message": {"role": "assistant", "content": " Hello."}, - "done": True, - "total_duration": 254579625, - "load_duration": 276542, - "prompt_eval_count": 20, - "prompt_eval_duration": 222911000, - "eval_count": 3, - "eval_duration": 30879000, - }, - ) - ret = ollama_chat_engine._complete_chat(messages=[ChatMessage(role="user", content="hello")]) - assert ret == " Hello." - - def test_ollama_invalid_model_raises(): os.environ[OllamaChatTarget.MODEL_NAME_ENVIRONMENT_VARIABLE] = "" with pytest.raises(ValueError):
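
The deprecated `send_prompt` overrides introduced above all reduce to the same pattern: keep the synchronous method as a thin wrapper that drives `send_prompt_async` to completion on a worker thread, so callers that are already inside a running event loop do not deadlock. A minimal standalone sketch of that pattern follows; `ExampleTarget` and its return values are illustrative, not part of PyRIT.

```python
import asyncio
import concurrent.futures


class ExampleTarget:
    """Illustrative stand-in for a PyRIT prompt target; not a real PyRIT class."""

    async def send_prompt_async(self, *, prompt_request: str) -> str:
        await asyncio.sleep(0)  # stand-in for the real network round trip
        return f"response to: {prompt_request}"

    def send_prompt(self, *, prompt_request: str) -> str:
        """
        Deprecated. Use send_prompt_async instead.
        """
        # Run the coroutine to completion on a worker thread via asyncio.run,
        # mirroring the wrapper this patch adds to each target. Submitting to a
        # fresh executor avoids "event loop already running" errors when the
        # caller itself sits inside an event loop (e.g. a Jupyter kernel).
        pool = concurrent.futures.ThreadPoolExecutor()
        return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result()


print(ExampleTarget().send_prompt(prompt_request="hello"))
```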
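
Call sites migrate as the doc updates show: `await` the async method directly in a notebook cell (the `# type: ignore` comments keep linters quiet about top-level `await` in the paired jupytext `.py` files), or wrap the call in `asyncio.run` in a plain script. A sketch of the script form, assuming an Azure ML endpoint configured via environment variables as in the docs:

```python
import asyncio

from pyrit.models import PromptRequestPiece
from pyrit.prompt_target import AzureMLChatTarget


async def main() -> None:
    request = PromptRequestPiece(
        role="user",
        original_value="this is a test prompt",
    ).to_prompt_request_response()

    with AzureMLChatTarget() as target:
        # send_prompt() still works but is deprecated by this patch.
        print(await target.send_prompt_async(prompt_request=request))


asyncio.run(main())
```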
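
The test migration likewise follows one recipe: mark the test `@pytest.mark.asyncio`, patch the async helper with `AsyncMock`, and await the code under test. A self-contained sketch assuming pytest-asyncio is installed; `fetch_output` and `make_request_async` are hypothetical stand-ins, not PyRIT APIs:

```python
from unittest.mock import AsyncMock, Mock, patch

import pytest


async def make_request_async(url: str):
    raise NotImplementedError("replaced by AsyncMock in the test")


async def fetch_output(url: str) -> str:
    # Looks up make_request_async at call time, so the patch below takes effect.
    response = await make_request_async(url)
    return response.json()["output"]


@pytest.mark.asyncio
async def test_fetch_output() -> None:
    with patch(f"{__name__}.make_request_async", new_callable=AsyncMock) as mock:
        mock_response = Mock()
        mock_response.json.return_value = {"output": "extracted response"}
        mock.return_value = mock_response

        assert await fetch_output("https://example.test") == "extracted response"
        mock.assert_called_once()
```

On Python 3.8+, `patch` auto-selects `AsyncMock` when the patched attribute is an async function, which is why the first migrated test in this patch works without passing `new_callable` explicitly; the later tests spell it out for clarity.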