From 9f75e522efbe14b2c4327820b38b3f9592530203 Mon Sep 17 00:00:00 2001 From: rlundeen2 Date: Wed, 15 May 2024 16:39:25 -0700 Subject: [PATCH] fixing build --- pyrit/memory/duckdb_memory.py | 4 +-- pyrit/memory/memory_interface.py | 12 +++---- pyrit/orchestrator/scoring_orchestrator.py | 3 +- tests/memory/test_duckdb_memory.py | 33 ++++++++----------- tests/mocks.py | 4 +++ .../orchestrator/test_scoring_orchestrator.py | 4 +-- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 0e13701b7..8fb5b8aec 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -112,12 +112,12 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") return [] - def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[int]) -> list[PromptRequestPiece]: + def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]: """ Retrieves a list of PromptRequestPiece objects that have the specified prompt ids. Args: - prompt_ids (list[int]): The prompt IDs to match. + prompt_ids (list[str]): The prompt IDs to match. Returns: list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 57cd8d406..3015f6691 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -79,15 +79,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[Pr """ @abc.abstractmethod - def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None: + def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None: """ - Inserts embedding data into memory storage + Inserts a list of prompt request pieces into the memory storage. """ @abc.abstractmethod - def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None: + def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None: """ - Inserts a list of prompt request pieces into the memory storage. + Inserts embedding data into memory storage """ @abc.abstractmethod @@ -116,7 +116,7 @@ def get_conversation(self, *, conversation_id: str) -> list[PromptRequestRespons return group_conversation_request_pieces_by_sequence(request_pieces=request_pieces) @abc.abstractmethod - def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[int]) -> list[PromptRequestPiece]: + def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]: """ Retrieves a list of PromptRequestPiece objects that have the specified prompt ids. @@ -127,7 +127,7 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[int]) -> list[Prom list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID. """ - def d(self, *, orchestrator_id: int) -> list[PromptRequestPiece]: + def get_prompt_request_piece_by_orchestrator_id(self, *, orchestrator_id: int) -> list[PromptRequestPiece]: """ Retrieves a list of PromptRequestPiece objects that have the specified orchestrator ID. diff --git a/pyrit/orchestrator/scoring_orchestrator.py b/pyrit/orchestrator/scoring_orchestrator.py index f5a90fb50..c5a390bc2 100644 --- a/pyrit/orchestrator/scoring_orchestrator.py +++ b/pyrit/orchestrator/scoring_orchestrator.py @@ -82,7 +82,8 @@ async def _score_prompts_batch_async(self, prompts: list[PromptRequestPiece], sc batch_results = await asyncio.gather(*tasks) results.extend(batch_results) - return results + # results is a list[list[str]] and needs to be flattened + return [score for sublist in results for score in sublist] def _chunked_prompts(self, prompts, size): for i in range(0, len(prompts), size): diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 04af80e3c..9862896ee 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -13,19 +13,19 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql.sqltypes import NullType -from pyrit.memory.memory_interface import MemoryInterface +from pyrit.memory.duckdb_memory import DuckDBMemory from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.models import PromptRequestPiece, Score from pyrit.orchestrator.orchestrator_class import Orchestrator from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget -from tests.mocks import get_memory_interface +from tests.mocks import get_duckdb_memory from tests.mocks import get_sample_conversation_entries @pytest.fixture -def memory_interface() -> Generator[MemoryInterface, None, None]: - yield from get_memory_interface() +def memory_interface() -> Generator[DuckDBMemory, None, None]: + yield from get_duckdb_memory() @pytest.fixture @@ -387,7 +387,7 @@ def test_get_memories_with_json_properties(memory_interface): assert labels["normalizer_id"] == "id1" -def test_get_memories_with_orchestrator_id(memory_interface: MemoryInterface): +def test_get_memories_with_orchestrator_id(memory_interface: DuckDBMemory): # Define a specific normalizer_id orchestrator1 = Orchestrator() orchestrator2 = Orchestrator() @@ -423,23 +423,18 @@ def test_get_memories_with_orchestrator_id(memory_interface: MemoryInterface): ), ] - # Insert the ConversationData entries using the _insert_entries method within a session - with memory_interface.get_session() as session: - memory_interface._insert_entries(entries=entries) - session.commit() # Ensure all entries are committed to the database + memory_interface._insert_entries(entries=entries) - orchestrator1_id = orchestrator1.get_identifier()["id"] + orchestrator1_id = int(orchestrator1.get_identifier()["id"]) - # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id - retrieved_entries = memory_interface.get_prompt_request_piece_by_orchestrator_id( - orchestrator_id=orchestrator1_id - ) + # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id + retrieved_entries = memory_interface.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=orchestrator1_id) - # Verify that the retrieved entries match the expected normalizer_id - assert len(retrieved_entries) == 2 # Two entries should have the specific normalizer_id - for retrieved_entry in retrieved_entries: - assert retrieved_entry.orchestrator_identifier["id"] == str(orchestrator1_id) - assert "Hello" in retrieved_entry.original_value # Basic check to ensure content is as expected + # Verify that the retrieved entries match the expected normalizer_id + assert len(retrieved_entries) == 2 # Two entries should have the specific normalizer_id + for retrieved_entry in retrieved_entries: + assert retrieved_entry.orchestrator_identifier["id"] == str(orchestrator1_id) + assert "Hello" in retrieved_entry.original_value # Basic check to ensure content is as expected def test_update_entries_by_conversation_id(memory_interface, sample_conversation_entries): diff --git a/tests/mocks.py b/tests/mocks.py index bf540ba36..4378fa57a 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -89,6 +89,10 @@ def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: def get_memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_duckdb_memory() + + +def get_duckdb_memory() -> Generator[DuckDBMemory, None, None]: # Create an in-memory DuckDB engine duckdb_memory = DuckDBMemory(db_path=":memory:") diff --git a/tests/orchestrator/test_scoring_orchestrator.py b/tests/orchestrator/test_scoring_orchestrator.py index 704b71fa1..c19bdad65 100644 --- a/tests/orchestrator/test_scoring_orchestrator.py +++ b/tests/orchestrator/test_scoring_orchestrator.py @@ -39,7 +39,7 @@ async def test_score_prompts_by_orchestrator_only_responses(sample_conversations orchestrator = ScoringOrchestrator(memory=memory) with patch.object(orchestrator, "_score_prompts_batch_async", new_callable=AsyncMock) as mock_score: - await orchestrator.score_prompts_by_orchestrator_id_async(scorer=MagicMock(), orchestrator_ids=["id1"]) + await orchestrator.score_prompts_by_orchestrator_id_async(scorer=MagicMock(), orchestrator_ids=[123]) mock_score.assert_called_once() _, called_kwargs = mock_score.call_args @@ -57,7 +57,7 @@ async def test_score_prompts_by_orchestrator_includes_requests(sample_conversati with patch.object(orchestrator, "_score_prompts_batch_async", new_callable=AsyncMock) as mock_score: await orchestrator.score_prompts_by_orchestrator_id_async( - scorer=MagicMock(), orchestrator_ids=["id1"], responses_only=False + scorer=MagicMock(), orchestrator_ids=[123], responses_only=False ) mock_score.assert_called_once()