wip: feat NebulaGraph as persistent layer, part 2
- [x] as BaseGraphStorage
- [-] as BaseKVStorage # <--- this commit is working on this
wey-gu committed Aug 21, 2024
1 parent 502d92f commit 7ff213a
Showing 3 changed files with 150 additions and 17 deletions.
124 changes: 119 additions & 5 deletions nano_graphrag/_storage.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import Any, Union, cast, Literal
+from typing import Any, Union, cast, Literal, Callable, Dict, Optional, ClassVar

import networkx as nx
from networkx.classes.reportviews import EdgeView, NodeView
@@ -64,6 +64,69 @@ async def drop(self):
self._data = {}



@dataclass
class NebulaGraphIndexStorage(BaseKVStorage):
"""
NebulaGraphIndexStorage is a storage that uses NebulaGraph as the underlying storage.
GraphRAG Indexed data is natively stored in NebulaGraph, so we map different "namespaces" under the KV abstraction
to different Graph Vertex TAGs and Edge TYPES(when applicable).
For full_docs, we have TAG "__Document"
For text_chunks, we have TAG "__Chunk" and EDGE type "DOC_WITH_CHUNK"
For community_reports, we have TAG "__Community" and EDGE type "ENTITY_WITHIN_COMMUNITY"
"""
def __post_init__(self):
self.nebula_storage = NebulaGraphStorage(
namespace=self.namespace,
global_config=self.global_config
)

    async def all_keys(self) -> list[str]:
        # Return all keys stored in NebulaGraph for this namespace.
        raise NotImplementedError("all_keys() is not implemented for NebulaGraphIndexStorage")

async def index_done_callback(self):
await self.nebula_storage.index_done_callback()

    async def get_by_id(self, id):
        # Return the dict for the given id: fetch the vertex from NebulaGraph
        # and cast its properties to a dict.
        raise NotImplementedError("get_by_id() is not implemented for NebulaGraphIndexStorage")

    async def get_by_ids(self, ids, fields=None):
        # Return a list of dicts for the given ids, fetched from NebulaGraph and
        # cast to dicts. If fields is not None, keep only the whitelisted fields.
        raise NotImplementedError("get_by_ids() is not implemented for NebulaGraphIndexStorage")

    async def filter_keys(self, data: list[str]) -> set[str]:
        # Return only the keys that are not already present in NebulaGraph.
        raise NotImplementedError("filter_keys() is not implemented for NebulaGraphIndexStorage")

    async def upsert(self, data: dict[str, dict]):
        if self.namespace == "full_docs":
            # full_docs: one "__Document" vertex per document
            for doc_id, doc_data in data.items():
                await self.nebula_storage.upsert_node(
                    doc_id, {"content": doc_data["content"], "type": "full_doc"}, label="__Document"
                )
        elif self.namespace == "text_chunks":
            # text_chunks: one "__Chunk" vertex per chunk, linked to its parent document
            for chunk_id, chunk_data in data.items():
                await self.nebula_storage.upsert_node(
                    chunk_id, {"content": chunk_data["content"], "type": "text_chunk"}, label="__Chunk"
                )
                await self.nebula_storage.upsert_edge(
                    chunk_data["full_doc_id"], chunk_id, {"type": "DOC_WITH_CHUNK"}, label="DOC_WITH_CHUNK"
                )
        elif self.namespace == "community_reports":
            # community_reports: one "__Community" vertex holding the JSON-serialized report
            for report_id, report_data in data.items():
                await self.nebula_storage.upsert_node(
                    report_id, {"content": json.dumps(report_data)}, label="__Community"
                )
        else:
            raise ValueError(f"Unsupported namespace for NebulaGraphIndexStorage: {self.namespace}")

    async def drop(self):
        # Drop all data for this namespace via a NebulaGraph query.
        raise NotImplementedError("drop() is not implemented for NebulaGraphIndexStorage")
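
    # A minimal usage sketch (assumes a reachable NebulaGraph instance and a
    # GraphRAG-style config dict; the ids and payloads are illustrative):
    #
    #     chunks = NebulaGraphIndexStorage(namespace="text_chunks", global_config=cfg)
    #     # creates a "__Chunk" vertex plus a DOC_WITH_CHUNK edge from its "__Document"
    #     await chunks.upsert(
    #         {"chunk-0001": {"content": "first chunk text", "full_doc_id": "doc-0001"}}
    #     )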


@dataclass
class NanoVectorDBStorage(BaseVectorStorage):

@@ -336,6 +399,7 @@ async def _node2vec_embed(self):

@dataclass
class NebulaGraphStorage(BaseGraphStorage):
    # TODO: implement configuration via global_config["addon_params"]
# credentials
space: str = "nano_graphrag"
use_tls: bool = False
@@ -789,7 +853,7 @@ async def get_node_edges(self, source_node_id: str) -> list[dict]:
except Exception as e:
raise RuntimeError(f"Failed to get edges for node {source_node_id}: {e}") from e

-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str], label: Optional[str] = None):
if node_id is None or not isinstance(node_id, str) or node_id == "":
raise ValueError(f"Invalid node_id {node_id}")
if node_data is None or not isinstance(node_data, dict):
@@ -798,7 +862,7 @@ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
raise ValueError(f"Invalid node_data {node_data}")

from uuid import uuid4
-        label = self.INIT_VERTEX_TYPE
+        label = label or self.INIT_VERTEX_TYPE

prop_all_names = list(node_data.keys())
prop_name = ",".join(
@@ -826,7 +890,7 @@ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
if not result.is_succeeded():
raise RuntimeError(f"Failed to upsert node {node_id}: {result} with query {query}")

-    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
+    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str], label: Optional[str] = None):
if source_node_id is None or not isinstance(source_node_id, str) or source_node_id == "":
raise ValueError(f"Invalid source_node_id {source_node_id}")
if target_node_id is None or not isinstance(target_node_id, str) or target_node_id == "":
@@ -837,7 +901,7 @@ async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
raise ValueError(f"Invalid edge_data {edge_data}")

from uuid import uuid4
-        label = self.INIT_EDGE_TYPE
+        label = label or self.INIT_EDGE_TYPE

prop_all_names = list(edge_data.keys())
prop_name = ",".join(
@@ -995,3 +1059,53 @@ async def _node2vec_embed(self):
async def index_done_callback(self):
        # TODO: introduce a cache mechanism, then we could leverage this callback
pass



@dataclass
class StorageProfile:
full_docs: Callable
text_chunks: Callable
llm_response_cache: Callable
community_reports: Callable
chunk_entity_relation: Callable
entities: Callable

@dataclass
class StorageFactory:
    # ClassVar keeps this mapping a class-level constant; a plain Dict field
    # would be rejected by @dataclass as a mutable default.
    STORAGE_PROFILES: ClassVar[Dict[str, StorageProfile]] = {
"local": StorageProfile(
full_docs=JsonKVStorage,
text_chunks=JsonKVStorage,
llm_response_cache=JsonKVStorage,
community_reports=JsonKVStorage,
chunk_entity_relation=NetworkXStorage,
entities=MilvusLiteStorge
),
"nebulagraph": StorageProfile(
full_docs=NebulaGraphIndexStorage,
text_chunks=NebulaGraphIndexStorage,
llm_response_cache=JsonKVStorage,
community_reports=NebulaGraphIndexStorage,
chunk_entity_relation=NebulaGraphStorage,
entities=MilvusLiteStorge
)
}

@staticmethod
def get_storage(
namespace: str,
global_config: Dict[str, Any],
knowledge_store: Literal["local", "nebulagraph"] = "local",
**kwargs
) -> Union[BaseKVStorage, BaseGraphStorage, BaseVectorStorage]:
if knowledge_store not in StorageFactory.STORAGE_PROFILES:
raise ValueError(f"Unsupported knowledge_store: {knowledge_store}")

profile = StorageFactory.STORAGE_PROFILES[knowledge_store]

storage_class = getattr(profile, namespace, None)
if storage_class is None:
raise ValueError(f"Unsupported namespace: {namespace}")

return storage_class(namespace=namespace, global_config=global_config, **kwargs)
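
A minimal usage sketch for the factory, assuming a populated GraphRAG-style config dict (the single config key shown below is illustrative):

cfg = {"working_dir": "./nano_graphrag_cache"}  # illustrative minimal config

# Under the "local" profile, full_docs resolves to JsonKVStorage.
docs_local = StorageFactory.get_storage(
    namespace="full_docs", global_config=cfg, knowledge_store="local"
)

# Under the "nebulagraph" profile, the same namespace resolves to
# NebulaGraphIndexStorage (which in turn needs a reachable NebulaGraph instance).
docs_nebula = StorageFactory.get_storage(
    namespace="full_docs", global_config=cfg, knowledge_store="nebulagraph"
)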
7 changes: 7 additions & 0 deletions nano_graphrag/base.py
@@ -84,6 +84,13 @@ async def upsert(self, data: dict[str, dict]):

@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
namespace: Literal[
"full_docs", # for full docs storage
"text_chunks", # for text chunks storage
"llm_response_cache", # for llm response cache
"community_reports", # for community reports
]

async def all_keys(self) -> list[str]:
raise NotImplementedError

36 changes: 24 additions & 12 deletions nano_graphrag/graphrag.py
@@ -3,7 +3,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Literal


from ._llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
@@ -18,6 +18,7 @@
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
StorageFactory,
)
from ._utils import EmbeddingFunc, compute_mdhash_id, limit_async_func_call, logger
from .base import (
@@ -100,36 +101,47 @@ def __post_init__(self):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)

-        self.full_docs = self.key_string_value_json_storage_cls(
-            namespace="full_docs", global_config=asdict(self)
+        self.full_docs = StorageFactory.get_storage(
+            namespace="full_docs",
+            global_config=asdict(self),
+            knowledge_store=self.knowledge_store
        )

-        self.text_chunks = self.key_string_value_json_storage_cls(
-            namespace="text_chunks", global_config=asdict(self)
+        self.text_chunks = StorageFactory.get_storage(
+            namespace="text_chunks",
+            global_config=asdict(self),
+            knowledge_store=self.knowledge_store
        )

        self.llm_response_cache = (
-            self.key_string_value_json_storage_cls(
-                namespace="llm_response_cache", global_config=asdict(self)
+            StorageFactory.get_storage(
+                namespace="llm_response_cache",
+                global_config=asdict(self),
+                knowledge_store=self.knowledge_store
            )
            if self.enable_llm_cache
            else None
        )

-        self.community_reports = self.key_string_value_json_storage_cls(
-            namespace="community_reports", global_config=asdict(self)
+        self.community_reports = StorageFactory.get_storage(
+            namespace="community_reports",
+            global_config=asdict(self),
+            knowledge_store=self.knowledge_store
        )
-        self.chunk_entity_relation_graph = self.graph_storage_cls(
-            namespace="chunk_entity_relation", global_config=asdict(self)
+        self.chunk_entity_relation_graph = StorageFactory.get_storage(
+            namespace="chunk_entity_relation",
+            global_config=asdict(self),
+            knowledge_store=self.knowledge_store
        )

        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = (
-            self.vector_db_storage_cls(
+            StorageFactory.get_storage(
                namespace="entities",
                global_config=asdict(self),
+                knowledge_store=self.knowledge_store,
                embedding_func=self.embedding_func,
                meta_fields={"entity_name"},
            )
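
Downstream, the intended user-facing switch is a single constructor argument; a hedged sketch, assuming this WIP adds a knowledge_store dataclass field to GraphRAG (the field itself is not visible in this diff):

rag = GraphRAG(
    working_dir="./nano_graphrag_cache",
    knowledge_store="nebulagraph",  # assumed new field; routes index namespaces to NebulaGraph
)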
