From 4a8672632c32556c4948e36576c0248192a827ff Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Fri, 17 May 2024 22:36:12 -0700 Subject: [PATCH] Address PR feedback --- doc/contributing.md | 5 +++++ pyrit/exceptions/exception_classes.py | 4 ++-- pyrit/models/__init__.py | 2 ++ pyrit/models/data_type_serializer.py | 11 ++++++++++ .../azure_openai_gptv_chat_target.py | 8 +++---- .../test_azure_openai_gptv_chat_target.py | 17 ++++++++++++-- tests/test_data_type_serializer.py | 22 +++++++++++++++++++ 7 files changed, 61 insertions(+), 8 deletions(-) diff --git a/doc/contributing.md b/doc/contributing.md index 4d351fe00..05d6681a7 100644 --- a/doc/contributing.md +++ b/doc/contributing.md @@ -173,6 +173,7 @@ In our PyRIT framework, proper exception handling is crucial for maintaining rob 1. **Centralized Exceptions**: Use the exceptions defined in our centralized exceptions module or create new ones as necessary in the same module. 2. **Inherit `PyritException`**: Ensure any new exceptions inherit from `PyritException` to maintain consistency. 3. **Exception Processing**: Utilize the `process_exception` method to handle exceptions appropriately. +4. **Add Response Entries to Memory**: After handling an exception using the `process_exception` method. While adding response entries to memory, ensure to set the `response_type` and `error` parameter to `error` to help identify the responses in the database for easy filtering. #### Specific Scenarios @@ -195,6 +196,10 @@ In our PyRIT framework, proper exception handling is crucial for maintaining rob - **Action**: Raise the exception itself to allow for proper propagation. - **Future Learning**: Monitor these exceptions to learn and identify patterns for future enhancements and more specific exception handling. +6. **Set Error to Original Value/Converted Value** + + - **Action**: After processing the exception, while adding the response entry to the DB, set the error message to the `original_value` and `converted_value`. These fields of `PromptRequestPiece` should be used to ensure proper testing. + By following these guidelines, we ensure a consistent and robust approach to exception handling across the framework. ### Running tests diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 815b15b98..7fc2ae92f 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -50,7 +50,7 @@ class EmptyResponseException(BadRequestException): def __init__(self, status_code: int = 204, *, message: str = "No Content"): super().__init__(status_code=status_code, message=message) - + def pyrit_retry(func: Callable) -> Callable: """ @@ -71,5 +71,5 @@ def pyrit_retry(func: Callable) -> Callable: retry=retry_if_exception_type(RateLimitError) | retry_if_exception_type(EmptyResponseException), wait=wait_random_exponential(min=RETRY_WAIT_MIN_SECONDS, max=RETRY_WAIT_MAX_SECONDS), after=after_log(logger, logging.INFO), - stop=stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS) + stop=stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS), )(func) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 178732378..913c2fb32 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -8,6 +8,7 @@ DataTypeSerializer, data_serializer_factory, TextDataTypeSerializer, + ErrorDataTypeSerializer, ImagePathDataTypeSerializer, AudioPathDataTypeSerializer, ) @@ -25,6 +26,7 @@ "ChatMessageListContent", "DataTypeSerializer", "data_serializer_factory", + "ErrorDataTypeSerializer", "group_conversation_request_pieces_by_sequence", "Identifier", "ImagePathDataTypeSerializer", diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index d38c0445b..1bd21dbca 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -22,6 +22,8 @@ def data_serializer_factory(*, data_type: PromptDataType, value: str = None, ext return ImagePathDataTypeSerializer(prompt_text=value) elif data_type == "audio_path": return AudioPathDataTypeSerializer(prompt_text=value) + elif data_type == "error": + return ErrorDataTypeSerializer(prompt_text=value) else: raise ValueError(f"Data type {data_type} not supported") else: @@ -168,6 +170,15 @@ def data_on_disk(self) -> bool: return False +class ErrorDataTypeSerializer(DataTypeSerializer): + def __init__(self, *, prompt_text: str): + self.data_type = "error" + self.value = prompt_text + + def data_on_disk(self) -> bool: + return False + + class ImagePathDataTypeSerializer(DataTypeSerializer): def __init__(self, *, prompt_text: str = None, extension: str = None): self.data_type = "image_path" diff --git a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py index f6356593b..3f9a4c859 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py @@ -19,7 +19,7 @@ from pyrit.models.data_type_serializer import data_serializer_factory, DataTypeSerializer from pyrit.prompt_target import PromptChatTarget from pyrit.exceptions import EmptyResponseException, BadRequestException, RateLimitException, pyrit_retry -from pyrit.common.constants import RETRY_WAIT_MIN_SECONDS, RETRY_WAIT_MAX_SECONDS, RETRY_MAX_NUM_ATTEMPTS +from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS logger = logging.getLogger(__name__) @@ -243,14 +243,14 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P bad_request_exception = BadRequestException(bre.status_code, message=bre.message) resp_text = bad_request_exception.process_exception() response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], error="blocked" + request=request, response_text_pieces=[resp_text], response_type="error", error="blocked" ) except RateLimitError as rle: # Handle the rate limit exception after exhausting the maximum number of retries. rate_limit_exception = RateLimitException(rle.status_code, message=rle.message) resp_text = rate_limit_exception.process_exception() response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], error="error" + request=request, response_text_pieces=[resp_text], response_type="error", error="error" ) except EmptyResponseException: # Handle the empty response exception after exhausting the maximum number of retries. @@ -258,7 +258,7 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P empty_response_exception = EmptyResponseException(message=message) resp_text = empty_response_exception.process_exception() response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], error="error" + request=request, response_text_pieces=[resp_text], response_type="error", error="error" ) return response_entry diff --git a/tests/target/test_azure_openai_gptv_chat_target.py b/tests/target/test_azure_openai_gptv_chat_target.py index 1f791e941..2a742f7c0 100644 --- a/tests/target/test_azure_openai_gptv_chat_target.py +++ b/tests/target/test_azure_openai_gptv_chat_target.py @@ -441,6 +441,9 @@ async def test_send_prompt_async_empty_response( '{"status_code": 204, "message": "Empty response from the target even after 5 retries."}' ) assert response.request_pieces[0].converted_value == expected_error_message + assert response.request_pieces[0].converted_value_data_type == "error" + assert response.request_pieces[0].original_value == expected_error_message + assert response.request_pieces[0].original_value_data_type == "error" assert str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value os.remove(tmp_file_name) @@ -455,11 +458,16 @@ async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: Az ) setattr(azure_gptv_chat_engine, "_complete_chat_async", mock_complete_chat_async) prompt_request = PromptRequestResponse( - request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] + request_pieces=[PromptRequestPiece(role="user", conversation_id="12345", original_value="Hello")] ) result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) assert "Rate Limit Reached" in result.request_pieces[0].converted_value + assert "Rate Limit Reached" in result.request_pieces[0].original_value + assert result.request_pieces[0].original_value_data_type == "error" + assert result.request_pieces[0].converted_value_data_type == "error" + expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9" + assert result.request_pieces[0].original_value_sha256 == expected_sha_256 @pytest.mark.asyncio @@ -473,11 +481,16 @@ async def test_send_prompt_async_bad_request_error(azure_gptv_chat_engine: Azure setattr(azure_gptv_chat_engine, "_complete_chat_async", mock_complete_chat_async) prompt_request = PromptRequestResponse( - request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] + request_pieces=[PromptRequestPiece(role="user", conversation_id="1236748", original_value="Hello")] ) result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) assert "Bad Request Error" in result.request_pieces[0].converted_value + assert "Bad Request Error" in result.request_pieces[0].original_value + assert result.request_pieces[0].original_value_data_type == "error" + assert result.request_pieces[0].converted_value_data_type == "error" + expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77" + assert result.request_pieces[0].original_value_sha256 == expected_sha256 def test_parse_chat_completion_successful(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget): diff --git a/tests/test_data_type_serializer.py b/tests/test_data_type_serializer.py index cebe1ccbc..ebaf095c8 100644 --- a/tests/test_data_type_serializer.py +++ b/tests/test_data_type_serializer.py @@ -8,6 +8,7 @@ from pyrit.models import ( ImagePathDataTypeSerializer, TextDataTypeSerializer, + ErrorDataTypeSerializer, DataTypeSerializer, data_serializer_factory, ) @@ -27,6 +28,15 @@ def test_data_serializer_factory_text_with_data(): assert normalizer.data_on_disk() is False +def test_data_serializer_factory_error_with_data(): + normalizer = data_serializer_factory(data_type="error", value="test") + assert isinstance(normalizer, DataTypeSerializer) + assert isinstance(normalizer, ErrorDataTypeSerializer) + assert normalizer.data_type == "error" + assert normalizer.value == "test" + assert normalizer.data_on_disk() is False + + def test_data_serializer_text_read_data_throws(): normalizer = data_serializer_factory(data_type="text", value="test") with pytest.raises(TypeError): @@ -39,6 +49,18 @@ def test_data_serializer_text_save_data_throws(): normalizer.save_data(b"\x00") +def test_data_serializer_error_read_data_throws(): + normalizer = data_serializer_factory(data_type="error", value="test") + with pytest.raises(TypeError): + normalizer.read_data() + + +def test_data_serializer_error_save_data_throws(): + normalizer = data_serializer_factory(data_type="error", value="test") + with pytest.raises(TypeError): + normalizer.save_data(b"\x00") + + def test_image_path_normalizer_factory_prompt_text_raises(): with pytest.raises(FileNotFoundError): data_serializer_factory(data_type="image_path", value="no_real_path.txt")