From 7ff213ad89f43bfff558abe6792fe5f383d4a36a Mon Sep 17 00:00:00 2001 From: Wey Gu Date: Wed, 21 Aug 2024 18:42:01 +0800 Subject: [PATCH] wip: feat NebulaGraph as persistent layer part 2 - [x] as BaseGraphStorage - [-] as BaseKVStorage # <--- this commit is working on this --- nano_graphrag/_storage.py | 124 ++++++++++++++++++++++++++++++++++++-- nano_graphrag/base.py | 7 +++ nano_graphrag/graphrag.py | 36 +++++++---- 3 files changed, 150 insertions(+), 17 deletions(-) diff --git a/nano_graphrag/_storage.py b/nano_graphrag/_storage.py index d5c439f..564f997 100644 --- a/nano_graphrag/_storage.py +++ b/nano_graphrag/_storage.py @@ -5,7 +5,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Union, cast, Literal +from typing import Any, Union, cast, Literal, Callable, Dict, Optional import networkx as nx from networkx.classes.reportviews import EdgeView, NodeView @@ -64,6 +64,69 @@ async def drop(self): self._data = {} + +@dataclass +class NebulaGraphIndexStorage(BaseKVStorage): + """ + NebulaGraphIndexStorage is a storage that uses NebulaGraph as the underlying storage. + + GraphRAG Indexed data is natively stored in NebulaGraph, so we map different "namespaces" under the KV abstraction + to different Graph Vertex TAGs and Edge TYPES (when applicable). 
+ + For full_docs, we have TAG "__Document" + For text_chunks, we have TAG "__Chunk" and EDGE type "DOC_WITH_CHUNK" + For community_reports, we have TAG "__Community" and EDGE type "ENTITY_WITHIN_COMMUNITY" + """ + def __post_init__(self): + self.nebula_storage = NebulaGraphStorage( + namespace=self.namespace, + global_config=self.global_config + ) + + async def all_keys(self) -> list[str]: + # Return all keys that are in the NebulaGraph on given namespace + raise NotImplementedError("all_keys() is not implemented for NebulaGraphIndexStorage") + + async def index_done_callback(self): + await self.nebula_storage.index_done_callback() + + async def get_by_id(self, id): + # Return dict for given id, we just need to fetch data from NebulaGraph + # and cast it to dict + raise NotImplementedError("get_by_id() is not implemented for NebulaGraphIndexStorage") + + async def get_by_ids(self, ids, fields=None): + # Return list of dict for given ids, we just need to fetch data from NebulaGraph + # and cast it to dict + # if fields is not None, we need to return dict with only the fields as whitelisted + raise NotImplementedError("get_by_ids() is not implemented for NebulaGraphIndexStorage") + + async def filter_keys(self, data: list[str]) -> set[str]: + # Just return the keys that are not in the NebulaGraph + raise NotImplementedError("filter_keys() is not implemented for NebulaGraphIndexStorage") + + async def upsert(self, data: dict[str, dict]): + if self.namespace == "full_docs": + # Implement full_docs specific logic + for doc_id, doc_data in data.items(): + await self.nebula_storage.upsert_node(doc_id, {"content": doc_data["content"], "type": "full_doc"}) + elif self.namespace == "text_chunks": + # Implement text_chunks specific logic + for chunk_id, chunk_data in data.items(): + await self.nebula_storage.upsert_node(chunk_id, {"content": chunk_data["content"], "type": "text_chunk"}) + await self.nebula_storage.upsert_edge(chunk_data["full_doc_id"], chunk_id, {"type": 
"DOC_WITH_CHUNK"}) + elif self.namespace == "community_reports": + # Implement community_reports specific logic + for report_id, report_data in data.items(): + await self.nebula_storage.upsert_node(report_id, {"content": json.dumps(report_data), "type": "community_report"}) + else: + raise ValueError(f"Unsupported namespace for NebulaGraphIndexStorage: {self.namespace}") + + async def drop(self): + # Implement based on NebulaGraph query + raise NotImplementedError("drop() is not implemented for NebulaGraphIndexStorage") + + @dataclass class NanoVectorDBStorage(BaseVectorStorage): @@ -336,6 +399,7 @@ async def _node2vec_embed(self): @dataclass class NebulaGraphStorage(BaseGraphStorage): + # TODO, implement configration via global_config["addon_params"] # credentials space: str = "nano_graphrag" use_tls: bool = False @@ -789,7 +853,7 @@ async def get_node_edges(self, source_node_id: str) -> list[dict]: except Exception as e: raise RuntimeError(f"Failed to get edges for node {source_node_id}: {e}") from e - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + async def upsert_node(self, node_id: str, node_data: dict[str, str], label: Optional[str] = None): if node_id is None or not isinstance(node_id, str) or node_id == "": raise ValueError(f"Invalid node_id {node_id}") if node_data is None or not isinstance(node_data, dict): @@ -798,7 +862,7 @@ async def upsert_node(self, node_id: str, node_data: dict[str, str]): raise ValueError(f"Invalid node_data {node_data}") from uuid import uuid4 - label = self.INIT_VERTEX_TYPE + label = label or self.INIT_VERTEX_TYPE prop_all_names = list(node_data.keys()) prop_name = ",".join( @@ -826,7 +890,7 @@ async def upsert_node(self, node_id: str, node_data: dict[str, str]): if not result.is_succeeded(): raise RuntimeError(f"Failed to upsert node {node_id}: {result} with query {query}") - async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]): + async def upsert_edge(self, 
source_node_id: str, target_node_id: str, edge_data: dict[str, str], label: Optional[str] = None): if source_node_id is None or not isinstance(source_node_id, str) or source_node_id == "": raise ValueError(f"Invalid source_node_id {source_node_id}") if target_node_id is None or not isinstance(target_node_id, str) or target_node_id == "": @@ -837,7 +901,7 @@ async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: raise ValueError(f"Invalid edge_data {edge_data}") from uuid import uuid4 - label = self.INIT_EDGE_TYPE + label = label or self.INIT_EDGE_TYPE prop_all_names = list(edge_data.keys()) prop_name = ",".join( @@ -995,3 +1059,53 @@ async def _node2vec_embed(self): async def index_done_callback(self): # TODO: introduce cache mechnism, then we could leverage this callback pass + + + +@dataclass +class StorageProfile: + full_docs: Callable + text_chunks: Callable + llm_response_cache: Callable + community_reports: Callable + chunk_entity_relation: Callable + entities: Callable + +@dataclass +class StorageFactory: + STORAGE_PROFILES: Dict[str, StorageProfile] = { + "local": StorageProfile( + full_docs=JsonKVStorage, + text_chunks=JsonKVStorage, + llm_response_cache=JsonKVStorage, + community_reports=JsonKVStorage, + chunk_entity_relation=NetworkXStorage, + entities=MilvusLiteStorge + ), + "nebulagraph": StorageProfile( + full_docs=NebulaGraphIndexStorage, + text_chunks=NebulaGraphIndexStorage, + llm_response_cache=JsonKVStorage, + community_reports=NebulaGraphIndexStorage, + chunk_entity_relation=NebulaGraphStorage, + entities=MilvusLiteStorge + ) + } + + @staticmethod + def get_storage( + namespace: str, + global_config: Dict[str, Any], + knowledge_store: Literal["local", "nebulagraph"] = "local", + **kwargs + ) -> Union[BaseKVStorage, BaseGraphStorage, BaseVectorStorage]: + if knowledge_store not in StorageFactory.STORAGE_PROFILES: + raise ValueError(f"Unsupported knowledge_store: {knowledge_store}") + + profile = 
StorageFactory.STORAGE_PROFILES[knowledge_store] + + storage_class = getattr(profile, namespace, None) + if storage_class is None: + raise ValueError(f"Unsupported namespace: {namespace}") + + return storage_class(namespace=namespace, global_config=global_config, **kwargs) diff --git a/nano_graphrag/base.py b/nano_graphrag/base.py index aa0e34b..74ee4dd 100644 --- a/nano_graphrag/base.py +++ b/nano_graphrag/base.py @@ -84,6 +84,13 @@ async def upsert(self, data: dict[str, dict]): @dataclass class BaseKVStorage(Generic[T], StorageNameSpace): + namespace: Literal[ + "full_docs", # for full docs storage + "text_chunks", # for text chunks storage + "llm_response_cache", # for llm response cache + "community_reports", # for community reports + ] + async def all_keys(self) -> list[str]: raise NotImplementedError diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py index 7c27ffe..3e879bf 100644 --- a/nano_graphrag/graphrag.py +++ b/nano_graphrag/graphrag.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Type, cast +from typing import Type, cast, Literal from ._llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding @@ -18,6 +18,7 @@ JsonKVStorage, NanoVectorDBStorage, NetworkXStorage, + StorageFactory, ) from ._utils import EmbeddingFunc, compute_mdhash_id, limit_async_func_call, logger from .base import ( @@ -100,36 +101,47 @@ def __post_init__(self): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.full_docs = self.key_string_value_json_storage_cls( - namespace="full_docs", global_config=asdict(self) + self.full_docs = StorageFactory.get_storage( + namespace="full_docs", + global_config=asdict(self), + knowledge_store=self.knowledge_store ) - self.text_chunks = self.key_string_value_json_storage_cls( - namespace="text_chunks", global_config=asdict(self) + self.text_chunks = 
StorageFactory.get_storage( + namespace="text_chunks", + global_config=asdict(self), + knowledge_store=self.knowledge_store ) self.llm_response_cache = ( - self.key_string_value_json_storage_cls( - namespace="llm_response_cache", global_config=asdict(self) + StorageFactory.get_storage( + namespace="llm_response_cache", + global_config=asdict(self), + knowledge_store=self.knowledge_store ) if self.enable_llm_cache else None ) - self.community_reports = self.key_string_value_json_storage_cls( - namespace="community_reports", global_config=asdict(self) + self.community_reports = StorageFactory.get_storage( + namespace="community_reports", + global_config=asdict(self), + knowledge_store=self.knowledge_store ) - self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace="chunk_entity_relation", global_config=asdict(self) + self.chunk_entity_relation_graph = StorageFactory.get_storage( + namespace="chunk_entity_relation", + global_config=asdict(self), + knowledge_store=self.knowledge_store ) self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) self.entities_vdb = ( - self.vector_db_storage_cls( + StorageFactory.get_storage( namespace="entities", global_config=asdict(self), + knowledge_store=self.knowledge_store, embedding_func=self.embedding_func, meta_fields={"entity_name"}, )