
Commit

Updated metrics to compute all relationships at once, updated prompt instructions that work for qwen2-7b
NumberChiffre committed Sep 21, 2024
1 parent da9812f commit e0f1a6d
Showing 7 changed files with 197 additions and 512 deletions.
45 changes: 34 additions & 11 deletions nano_graphrag/entity_extraction/extract.py
@@ -1,6 +1,7 @@
from typing import Union
import pickle
import asyncio
from openai import BadRequestError
from collections import defaultdict
import dspy
from nano_graphrag._storage import BaseGraphStorage
@@ -11,7 +12,7 @@
)
from nano_graphrag.prompt import PROMPTS
from nano_graphrag._utils import logger, compute_mdhash_id
from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor
from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert


@@ -20,31 +21,53 @@ async def generate_dataset(
filepath: str,
save_dataset: bool = True
) -> list[dspy.Example]:
entity_extractor = EntityRelationshipExtractor()
entity_extractor = TypedEntityRelationshipExtractor()
ordered_chunks = list(chunks.items())
already_processed = 0
already_entities = 0
already_relations = 0

async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]) -> dspy.Example:
nonlocal already_processed, already_entities, already_relations
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
prediction = await asyncio.to_thread(
entity_extractor, input_text=content
)
try:
prediction = await asyncio.to_thread(
entity_extractor, input_text=content
)
entities, relationships = prediction.entities, prediction.relationships
except BadRequestError as e:
logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
entities, relationships = [], []
example = dspy.Example(
input_text=content,
entities=prediction.entities,
relationships=prediction.relationships
entities=entities,
relationships=relationships
).with_inputs("input_text")
already_entities += len(entities)
already_relations += len(relationships)
already_processed += 1
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return example

examples = await asyncio.gather(
*[_process_single_content(c) for c in ordered_chunks]
)
filtered_examples = [example for example in examples if len(example.entities) > 0 and len(example.relationships) > 0]
num_filtered_examples = len(examples) - len(filtered_examples)
if save_dataset:
with open(filepath, 'wb') as f:
pickle.dump(examples, f)
logger.info(f"Saved {len(examples)} examples with keys: {examples[0].keys()}")
pickle.dump(filtered_examples, f)
logger.info(f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples")

return examples
return filtered_examples


async def extract_entities_dspy(
@@ -53,7 +76,7 @@ async def extract_entities_dspy(
entity_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
entity_extractor = EntityRelationshipExtractor()
entity_extractor = TypedEntityRelationshipExtractor()

if global_config.get("use_compiled_dspy_entity_relationship", False):
entity_extractor.load(global_config["entity_relationship_module_path"])
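For context, a minimal sketch of how the updated generate_dataset could be driven end to end; the first parameter name (chunks), the chunk-id keys, and the output path are assumptions for illustration, and it presumes dspy.settings has already been configured with a language model (e.g. qwen2-7b behind an OpenAI-compatible endpoint):

import asyncio
from nano_graphrag.entity_extraction.extract import generate_dataset

# Hypothetical chunks keyed by chunk id; "content" is the only field the
# changed code above reads from each chunk.
chunks = {
    "chunk-0": {"content": "Alice founded Acme Corp in 2010 and hired Bob as CTO."},
    "chunk-1": {"content": "Acme Corp acquired Widget Inc. in 2015."},
}

# Extracts entities/relationships per chunk with TypedEntityRelationshipExtractor,
# drops chunks where extraction returned nothing (or raised BadRequestError),
# and pickles the surviving dspy.Example objects to the given path.
examples = asyncio.run(
    generate_dataset(chunks, filepath="entity_extraction_examples.pkl")
)
print(f"kept {len(examples)} examples")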
89 changes: 26 additions & 63 deletions nano_graphrag/entity_extraction/metric.py
@@ -1,74 +1,37 @@
import dspy
import numpy as np
from nano_graphrag.entity_extraction.module import Relationship


class AssessRelationship(dspy.Signature):
class AssessRelationships(dspy.Signature):
"""
Crucial considerations when assessing the similarity of two relationships:
- Take the "src_id" and "tgt_id" fields into account as the source and target entities are crucial for assessing the relationship similarity.
- Take the "description" field into account as it contains detailed information about the relationship.
Assess the similarity between gold and predicted relationships:
1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
2. For matched pairs, compare:
a) Description similarity (semantic meaning)
b) Weight similarity
c) Order similarity
3. Consider unmatched relationships as penalties.
4. Aggregate scores, accounting for precision and recall.
5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).
Key considerations:
- Prioritize matching based on entity pairs over exact string matches.
- Use semantic similarity for descriptions rather than exact matches.
- Weight the importance of different aspects (e.g., entity matching, description, weight, order).
- Balance the impact of matched and unmatched relationships in the final score.
"""

gold_relationship = dspy.InputField(
desc="""
The gold-standard relationship to compare against.
gold_relationships: list[Relationship] = dspy.InputField(desc="The gold-standard relationships to compare against.")
predicted_relationships: list[Relationship] = dspy.InputField(desc="The predicted relationships to compare against the gold-standard relationships.")
similarity_score: float = dspy.OutputField(desc="Similarity score between 0 and 1, with 1 being the highest similarity.")

Format:
{
"relationships": [
{
"src_id": "SOURCE ENTITY",
"tgt_id": "TARGET ENTITY",
"description": "Detailed description of the relationship",
"weight": "Weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.",
"order": "Order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order, etc."
}
]
}
"""
)
predicted_relationship = dspy.InputField(
desc="""
The predicted relationship to compare against.

Format:
{
"relationships": [
{
"src_id": "SOURCE ENTITY",
"tgt_id": "TARGET ENTITY",
"description": "Detailed description of the relationship",
"weight": "Weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.",
"order": "Order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order, etc."
}
]
}
"""
)
similarity_score = dspy.OutputField(
desc="""
Similarity score of the predicted relationship to the gold-standard relationship between 0 and 1, 1 being the highest similarity
"""
)


def relationship_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
similarity_scores = []
model = dspy.ChainOfThought(AssessRelationship)

for gold_rel, pred_rel in zip(gold['relationships'], pred['relationships']):
assessment = model(
gold_relationship=gold_rel,
predicted_relationship=pred_rel
)

try:
score = float(assessment.similarity_score)
similarity_scores.append(score)
except ValueError:
similarity_scores.append(0.0)

return np.mean(similarity_scores) if similarity_scores else 0.0
def relationships_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
model = dspy.TypedChainOfThought(AssessRelationships)
gold_relationships = [Relationship(**item) for item in gold['relationships']]
predicted_relationships = [Relationship(**item) for item in pred['relationships']]
similarity_score = float(model(gold_relationships=gold_relationships, predicted_relationships=predicted_relationships).similarity_score)
return similarity_score


def entity_recall_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
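And a sketch of how the new batch relationship metric might plug into a standard DSPy evaluation loop over the pickled dataset; the dataset path and evaluation settings are placeholders, and an LM is again assumed to be configured via dspy.settings before this runs:

import pickle
from dspy.evaluate import Evaluate
from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
from nano_graphrag.entity_extraction.metric import relationships_similarity_metric

# Hypothetical: reuse the examples pickled by generate_dataset.
with open("entity_extraction_examples.pkl", "rb") as f:
    devset = pickle.load(f)

extractor = TypedEntityRelationshipExtractor()

# With this commit, a single TypedChainOfThought call scores the full gold and
# predicted relationship lists of an example, rather than one call per pair.
evaluate = Evaluate(
    devset=devset,
    metric=relationships_similarity_metric,
    num_threads=1,
    display_progress=True,
)
score = evaluate(extractor)
print(score)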
