Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
rdheekonda committed May 18, 2024
1 parent a58fa5b commit 4a86726
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 8 deletions.
5 changes: 5 additions & 0 deletions doc/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ In our PyRIT framework, proper exception handling is crucial for maintaining rob
1. **Centralized Exceptions**: Use the exceptions defined in our centralized exceptions module or create new ones as necessary in the same module.
2. **Inherit `PyritException`**: Ensure any new exceptions inherit from `PyritException` to maintain consistency.
3. **Exception Processing**: Utilize the `process_exception` method to handle exceptions appropriately.
4. **Add Response Entries to Memory**: After handling an exception using the `process_exception` method. While adding response entries to memory, ensure to set the `response_type` and `error` parameter to `error` to help identify the responses in the database for easy filtering.

#### Specific Scenarios

Expand All @@ -195,6 +196,10 @@ In our PyRIT framework, proper exception handling is crucial for maintaining rob
- **Action**: Raise the exception itself to allow for proper propagation.
- **Future Learning**: Monitor these exceptions to learn and identify patterns for future enhancements and more specific exception handling.

6. **Set Error to Original Value/Converted Value**

- **Action**: After processing the exception, while adding the response entry to the DB, set the error message to the `original_value` and `converted_value`. These fields of `PromptRequestPiece` should be used to ensure proper testing.

By following these guidelines, we ensure a consistent and robust approach to exception handling across the framework.

### Running tests
Expand Down
4 changes: 2 additions & 2 deletions pyrit/exceptions/exception_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class EmptyResponseException(BadRequestException):

def __init__(self, status_code: int = 204, *, message: str = "No Content"):
super().__init__(status_code=status_code, message=message)


def pyrit_retry(func: Callable) -> Callable:
"""
Expand All @@ -71,5 +71,5 @@ def pyrit_retry(func: Callable) -> Callable:
retry=retry_if_exception_type(RateLimitError) | retry_if_exception_type(EmptyResponseException),
wait=wait_random_exponential(min=RETRY_WAIT_MIN_SECONDS, max=RETRY_WAIT_MAX_SECONDS),
after=after_log(logger, logging.INFO),
stop=stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS)
stop=stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS),
)(func)
2 changes: 2 additions & 0 deletions pyrit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DataTypeSerializer,
data_serializer_factory,
TextDataTypeSerializer,
ErrorDataTypeSerializer,
ImagePathDataTypeSerializer,
AudioPathDataTypeSerializer,
)
Expand All @@ -25,6 +26,7 @@
"ChatMessageListContent",
"DataTypeSerializer",
"data_serializer_factory",
"ErrorDataTypeSerializer",
"group_conversation_request_pieces_by_sequence",
"Identifier",
"ImagePathDataTypeSerializer",
Expand Down
11 changes: 11 additions & 0 deletions pyrit/models/data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def data_serializer_factory(*, data_type: PromptDataType, value: str = None, ext
return ImagePathDataTypeSerializer(prompt_text=value)
elif data_type == "audio_path":
return AudioPathDataTypeSerializer(prompt_text=value)
elif data_type == "error":
return ErrorDataTypeSerializer(prompt_text=value)
else:
raise ValueError(f"Data type {data_type} not supported")
else:
Expand Down Expand Up @@ -168,6 +170,15 @@ def data_on_disk(self) -> bool:
return False


class ErrorDataTypeSerializer(DataTypeSerializer):
def __init__(self, *, prompt_text: str):
self.data_type = "error"
self.value = prompt_text

def data_on_disk(self) -> bool:
return False


class ImagePathDataTypeSerializer(DataTypeSerializer):
def __init__(self, *, prompt_text: str = None, extension: str = None):
self.data_type = "image_path"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pyrit.models.data_type_serializer import data_serializer_factory, DataTypeSerializer
from pyrit.prompt_target import PromptChatTarget
from pyrit.exceptions import EmptyResponseException, BadRequestException, RateLimitException, pyrit_retry
from pyrit.common.constants import RETRY_WAIT_MIN_SECONDS, RETRY_WAIT_MAX_SECONDS, RETRY_MAX_NUM_ATTEMPTS
from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -243,22 +243,22 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
bad_request_exception = BadRequestException(bre.status_code, message=bre.message)
resp_text = bad_request_exception.process_exception()
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], error="blocked"
request=request, response_text_pieces=[resp_text], response_type="error", error="blocked"
)
except RateLimitError as rle:
# Handle the rate limit exception after exhausting the maximum number of retries.
rate_limit_exception = RateLimitException(rle.status_code, message=rle.message)
resp_text = rate_limit_exception.process_exception()
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], error="error"
request=request, response_text_pieces=[resp_text], response_type="error", error="error"
)
except EmptyResponseException:
# Handle the empty response exception after exhausting the maximum number of retries.
message = f"Empty response from the target even after {RETRY_MAX_NUM_ATTEMPTS} retries."
empty_response_exception = EmptyResponseException(message=message)
resp_text = empty_response_exception.process_exception()
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], error="error"
request=request, response_text_pieces=[resp_text], response_type="error", error="error"
)

return response_entry
Expand Down
17 changes: 15 additions & 2 deletions tests/target/test_azure_openai_gptv_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ async def test_send_prompt_async_empty_response(
'{"status_code": 204, "message": "Empty response from the target even after 5 retries."}'
)
assert response.request_pieces[0].converted_value == expected_error_message
assert response.request_pieces[0].converted_value_data_type == "error"
assert response.request_pieces[0].original_value == expected_error_message
assert response.request_pieces[0].original_value_data_type == "error"
assert str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value
os.remove(tmp_file_name)

Expand All @@ -455,11 +458,16 @@ async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: Az
)
setattr(azure_gptv_chat_engine, "_complete_chat_async", mock_complete_chat_async)
prompt_request = PromptRequestResponse(
request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")]
request_pieces=[PromptRequestPiece(role="user", conversation_id="12345", original_value="Hello")]
)

result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert "Rate Limit Reached" in result.request_pieces[0].converted_value
assert "Rate Limit Reached" in result.request_pieces[0].original_value
assert result.request_pieces[0].original_value_data_type == "error"
assert result.request_pieces[0].converted_value_data_type == "error"
expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9"
assert result.request_pieces[0].original_value_sha256 == expected_sha_256


@pytest.mark.asyncio
Expand All @@ -473,11 +481,16 @@ async def test_send_prompt_async_bad_request_error(azure_gptv_chat_engine: Azure
setattr(azure_gptv_chat_engine, "_complete_chat_async", mock_complete_chat_async)

prompt_request = PromptRequestResponse(
request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")]
request_pieces=[PromptRequestPiece(role="user", conversation_id="1236748", original_value="Hello")]
)

result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert "Bad Request Error" in result.request_pieces[0].converted_value
assert "Bad Request Error" in result.request_pieces[0].original_value
assert result.request_pieces[0].original_value_data_type == "error"
assert result.request_pieces[0].converted_value_data_type == "error"
expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77"
assert result.request_pieces[0].original_value_sha256 == expected_sha256


def test_parse_chat_completion_successful(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyrit.models import (
ImagePathDataTypeSerializer,
TextDataTypeSerializer,
ErrorDataTypeSerializer,
DataTypeSerializer,
data_serializer_factory,
)
Expand All @@ -27,6 +28,15 @@ def test_data_serializer_factory_text_with_data():
assert normalizer.data_on_disk() is False


def test_data_serializer_factory_error_with_data():
normalizer = data_serializer_factory(data_type="error", value="test")
assert isinstance(normalizer, DataTypeSerializer)
assert isinstance(normalizer, ErrorDataTypeSerializer)
assert normalizer.data_type == "error"
assert normalizer.value == "test"
assert normalizer.data_on_disk() is False


def test_data_serializer_text_read_data_throws():
normalizer = data_serializer_factory(data_type="text", value="test")
with pytest.raises(TypeError):
Expand All @@ -39,6 +49,18 @@ def test_data_serializer_text_save_data_throws():
normalizer.save_data(b"\x00")


def test_data_serializer_error_read_data_throws():
normalizer = data_serializer_factory(data_type="error", value="test")
with pytest.raises(TypeError):
normalizer.read_data()


def test_data_serializer_error_save_data_throws():
normalizer = data_serializer_factory(data_type="error", value="test")
with pytest.raises(TypeError):
normalizer.save_data(b"\x00")


def test_image_path_normalizer_factory_prompt_text_raises():
with pytest.raises(FileNotFoundError):
data_serializer_factory(data_type="image_path", value="no_real_path.txt")
Expand Down

0 comments on commit 4a86726

Please sign in to comment.