-
Notifications
You must be signed in to change notification settings - Fork 401
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into users/rdheekonda/add-exception-handling
- Loading branch information
Showing
15 changed files
with
69 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
# Copyright (c) Adriano Maia <[email protected]> | ||
# Licensed under the MIT license. | ||
import asyncio | ||
import concurrent.futures | ||
import logging | ||
|
||
from pyrit.chat_message_normalizer import ChatMessageNormalizer, ChatMessageNop | ||
from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer | ||
from pyrit.common import default_values, net_utility | ||
from pyrit.memory import MemoryInterface | ||
from pyrit.models import ChatMessage | ||
from pyrit.models import PromptRequestPiece, PromptRequestResponse | ||
from pyrit.models import ChatMessage, PromptRequestPiece, PromptRequestResponse | ||
from pyrit.prompt_target import PromptChatTarget | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -36,28 +37,11 @@ def __init__( | |
self.chat_message_normalizer = chat_message_normalizer | ||
|
||
def send_prompt(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: | ||
self._validate_request(prompt_request=prompt_request) | ||
request: PromptRequestPiece = prompt_request.request_pieces[0] | ||
|
||
messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id) | ||
messages.append(request.to_chat_message()) | ||
|
||
logger.info(f"Sending the following prompt to the prompt target: {self} {request}") | ||
|
||
self._memory.add_request_response_to_memory(request=prompt_request) | ||
|
||
resp = self._complete_chat( | ||
messages=messages, | ||
) | ||
|
||
if not resp: | ||
raise ValueError("The chat returned an empty response.") | ||
|
||
logger.info(f'Received the following response from the prompt target "{resp}"') | ||
|
||
response_entry = self._memory.add_response_entries_to_memory(request=request, response_text_pieces=[resp]) | ||
|
||
return response_entry | ||
""" | ||
Deprecated. Use send_prompt_async instead. | ||
""" | ||
pool = concurrent.futures.ThreadPoolExecutor() | ||
return pool.submit(asyncio.run, self.send_prompt_async(prompt_request=prompt_request)).result() | ||
|
||
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: | ||
|
||
|
@@ -84,19 +68,6 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P | |
|
||
return response_entry | ||
|
||
def _complete_chat( | ||
self, | ||
messages: list[ChatMessage], | ||
) -> str: | ||
headers = self._get_headers() | ||
payload = self._construct_http_body(messages) | ||
|
||
response = net_utility.make_request_and_raise_if_error( | ||
endpoint_uri=self.endpoint_uri, method="POST", request_body=payload, headers=headers | ||
) | ||
|
||
return response.json()["message"]["content"] | ||
|
||
async def _complete_chat_async( | ||
self, | ||
messages: list[ChatMessage], | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.