
fixing build
rlundeen2 committed May 15, 2024
1 parent 2cf5b7d commit 9f75e52
Showing 6 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions pyrit/memory/duckdb_memory.py
@@ -112,12 +112,12 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
             logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
             return []

-    def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[int]) -> list[PromptRequestPiece]:
+    def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]:
         """
         Retrieves a list of PromptRequestPiece objects that have the specified prompt ids.
         Args:
-            prompt_ids (list[int]): The prompt IDs to match.
+            prompt_ids (list[str]): The prompt IDs to match.
         Returns:
             list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
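
For context on the signature change above, a minimal usage sketch in which prompt IDs are passed as strings rather than ints. It assumes PromptRequestPiece accepts role and original_value keyword arguments and that its id attribute is a UUID whose string form matches the stored prompt ID; only the method names and imports are taken from this diff.

from pyrit.memory.duckdb_memory import DuckDBMemory
from pyrit.models import PromptRequestPiece

memory = DuckDBMemory(db_path=":memory:")

# Assumption: PromptRequestPiece can be built from role/original_value keywords.
piece = PromptRequestPiece(role="user", original_value="Hello")
memory.add_request_pieces_to_memory(request_pieces=[piece])

# The id is a UUID; stringify it to match the new list[str] signature.
retrieved = memory.get_prompt_request_pieces_by_id(prompt_ids=[str(piece.id)])
assert retrieved and retrieved[0].original_value == "Hello"
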
12 changes: 6 additions & 6 deletions pyrit/memory/memory_interface.py
@@ -79,15 +79,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[Pr
         """

     @abc.abstractmethod
-    def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
+    def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
         """
-        Inserts embedding data into memory storage
+        Inserts a list of prompt request pieces into the memory storage.
         """

     @abc.abstractmethod
-    def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
+    def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
         """
-        Inserts a list of prompt request pieces into the memory storage.
+        Inserts embedding data into memory storage
         """

     @abc.abstractmethod
@@ -116,7 +116,7 @@ def get_conversation(self, *, conversation_id: str) -> list[PromptRequestRespons
         return group_conversation_request_pieces_by_sequence(request_pieces=request_pieces)

     @abc.abstractmethod
-    def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[int]) -> list[PromptRequestPiece]:
+    def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]:
         """
         Retrieves a list of PromptRequestPiece objects that have the specified prompt ids.
@@ -127,7 +127,7 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[int]) -> list[Prom
             list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
         """

-    def d(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
+    def get_prompt_request_piece_by_orchestrator_id(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
         """
         Retrieves a list of PromptRequestPiece objects that have the specified orchestrator ID.
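
A short caller sketch for the renamed method, using only names that appear in this diff; the empty-result assertion is an assumption about a freshly created in-memory store, not something this commit checks.

from pyrit.memory.duckdb_memory import DuckDBMemory
from pyrit.orchestrator.orchestrator_class import Orchestrator

memory = DuckDBMemory(db_path=":memory:")
orchestrator = Orchestrator()

# Orchestrator identifiers carry a string "id"; the memory API takes it as an int.
orchestrator_id = int(orchestrator.get_identifier()["id"])
pieces = memory.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=orchestrator_id)
assert pieces == []  # assumption: nothing stored yet for this orchestrator
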
3 changes: 2 additions & 1 deletion pyrit/orchestrator/scoring_orchestrator.py
@@ -82,7 +82,8 @@ async def _score_prompts_batch_async(self, prompts: list[PromptRequestPiece], sc
             batch_results = await asyncio.gather(*tasks)
             results.extend(batch_results)

-        return results
+        # results is a list[list[str]] and needs to be flattened
+        return [score for sublist in results for score in sublist]

     def _chunked_prompts(self, prompts, size):
         for i in range(0, len(prompts), size):
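
A self-contained sketch of why the flattening is needed. The scorer below is a stand-in, not PyRIT code, and plain strings stand in for Score objects; the key point is that each per-prompt task itself returns a list, so gathering and extending produces a list of lists.

import asyncio


async def score_async(prompt: str) -> list[str]:
    # Stand-in for Scorer.score_async, which returns a *list* of scores per prompt.
    return [f"score:{prompt}"]


async def main() -> None:
    results = []
    prompts = ["p1", "p2", "p3"]
    tasks = [score_async(prompt) for prompt in prompts]
    batch_results = await asyncio.gather(*tasks)
    results.extend(batch_results)
    # Each task returned a list, so results is a list of lists and must be flattened.
    flat = [score for sublist in results for score in sublist]
    assert flat == ["score:p1", "score:p2", "score:p3"]


asyncio.run(main())
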
33 changes: 14 additions & 19 deletions tests/memory/test_duckdb_memory.py
@@ -13,19 +13,19 @@
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.sql.sqltypes import NullType

-from pyrit.memory.memory_interface import MemoryInterface
+from pyrit.memory.duckdb_memory import DuckDBMemory
 from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData
 from pyrit.models import PromptRequestPiece, Score
 from pyrit.orchestrator.orchestrator_class import Orchestrator
 from pyrit.prompt_converter.base64_converter import Base64Converter
 from pyrit.prompt_target.text_target import TextTarget
-from tests.mocks import get_memory_interface
+from tests.mocks import get_duckdb_memory
 from tests.mocks import get_sample_conversation_entries


 @pytest.fixture
-def memory_interface() -> Generator[MemoryInterface, None, None]:
-    yield from get_memory_interface()
+def memory_interface() -> Generator[DuckDBMemory, None, None]:
+    yield from get_duckdb_memory()


 @pytest.fixture
@@ -387,7 +387,7 @@ def test_get_memories_with_json_properties(memory_interface):
     assert labels["normalizer_id"] == "id1"


-def test_get_memories_with_orchestrator_id(memory_interface: MemoryInterface):
+def test_get_memories_with_orchestrator_id(memory_interface: DuckDBMemory):
     # Define a specific normalizer_id
     orchestrator1 = Orchestrator()
     orchestrator2 = Orchestrator()
@@ -423,23 +423,18 @@ def test_get_memories_with_orchestrator_id(memory_interface: MemoryInterface):
         ),
     ]

-    # Insert the ConversationData entries using the _insert_entries method within a session
-    with memory_interface.get_session() as session:
-        memory_interface._insert_entries(entries=entries)
-        session.commit()  # Ensure all entries are committed to the database
+    memory_interface._insert_entries(entries=entries)

-        orchestrator1_id = orchestrator1.get_identifier()["id"]
+    orchestrator1_id = int(orchestrator1.get_identifier()["id"])

-        # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id
-        retrieved_entries = memory_interface.get_prompt_request_piece_by_orchestrator_id(
-            orchestrator_id=orchestrator1_id
-        )
+    # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id
+    retrieved_entries = memory_interface.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=orchestrator1_id)

-        # Verify that the retrieved entries match the expected normalizer_id
-        assert len(retrieved_entries) == 2  # Two entries should have the specific normalizer_id
-        for retrieved_entry in retrieved_entries:
-            assert retrieved_entry.orchestrator_identifier["id"] == str(orchestrator1_id)
-            assert "Hello" in retrieved_entry.original_value  # Basic check to ensure content is as expected
+    # Verify that the retrieved entries match the expected normalizer_id
+    assert len(retrieved_entries) == 2  # Two entries should have the specific normalizer_id
+    for retrieved_entry in retrieved_entries:
+        assert retrieved_entry.orchestrator_identifier["id"] == str(orchestrator1_id)
+        assert "Hello" in retrieved_entry.original_value  # Basic check to ensure content is as expected


 def test_update_entries_by_conversation_id(memory_interface, sample_conversation_entries):
4 changes: 4 additions & 0 deletions tests/mocks.py
@@ -89,6 +89,10 @@ def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:


 def get_memory_interface() -> Generator[MemoryInterface, None, None]:
+    yield from get_duckdb_memory()
+
+
+def get_duckdb_memory() -> Generator[DuckDBMemory, None, None]:
     # Create an in-memory DuckDB engine
     duckdb_memory = DuckDBMemory(db_path=":memory:")

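
The generic helper now delegates with yield from, which keeps any teardown code after the yield working. A self-contained sketch of that pattern follows; DummyMemory and its dispose_engine method are placeholders rather than the real DuckDBMemory API.

from typing import Generator


class DummyMemory:
    """Placeholder for DuckDBMemory(db_path=":memory:") to keep the sketch self-contained."""

    def dispose_engine(self) -> None:
        print("engine disposed")


def get_dummy_memory() -> Generator[DummyMemory, None, None]:
    memory = DummyMemory()
    yield memory
    # Teardown still runs when the delegating generator below is exhausted or closed.
    memory.dispose_engine()


def get_memory_interface() -> Generator[DummyMemory, None, None]:
    # Same shape as tests/mocks.py after this commit: plain delegation.
    yield from get_dummy_memory()


for memory in get_memory_interface():
    pass  # prints "engine disposed" once the generator finishes
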
4 changes: 2 additions & 2 deletions tests/orchestrator/test_scoring_orchestrator.py
@@ -39,7 +39,7 @@ async def test_score_prompts_by_orchestrator_only_responses(sample_conversations
     orchestrator = ScoringOrchestrator(memory=memory)

     with patch.object(orchestrator, "_score_prompts_batch_async", new_callable=AsyncMock) as mock_score:
-        await orchestrator.score_prompts_by_orchestrator_id_async(scorer=MagicMock(), orchestrator_ids=["id1"])
+        await orchestrator.score_prompts_by_orchestrator_id_async(scorer=MagicMock(), orchestrator_ids=[123])

         mock_score.assert_called_once()
         _, called_kwargs = mock_score.call_args
@@ -57,7 +57,7 @@ async def test_score_prompts_by_orchestrator_includes_requests(sample_conversati

     with patch.object(orchestrator, "_score_prompts_batch_async", new_callable=AsyncMock) as mock_score:
         await orchestrator.score_prompts_by_orchestrator_id_async(
-            scorer=MagicMock(), orchestrator_ids=["id1"], responses_only=False
+            scorer=MagicMock(), orchestrator_ids=[123], responses_only=False
         )

         mock_score.assert_called_once()
