forked from gusye1234/nano-graphrag
Merge pull request #1 from gusye1234/main
MERGE Master
Showing 25 changed files with 17,315 additions and 917 deletions.
examples/finetune_entity_relationship_dspy.ipynb: 14,525 changes (14,146 additions, 379 deletions). Large diffs are not rendered by default.
@@ -1,68 +1,43 @@
 from nano_graphrag._utils import encode_string_by_tiktoken
 from nano_graphrag.base import QueryParam
 from nano_graphrag.graphrag import GraphRAG
+from nano_graphrag._op import chunking_by_seperators


-def chunking_by_specific_separators(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o",
+def chunking_by_token_size(
+    tokens_list: list[list[int]],  # nano-graphrag may pass a batch of docs' tokens
+    doc_keys: list[str],  # nano-graphrag may pass a batch of docs' key ids
+    tiktoken_model,  # a tiktoken model
+    overlap_token_size=128,
+    max_token_size=1024,
 ):
-    from langchain_text_splitters import RecursiveCharacterTextSplitter
-
-    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-        chunk_size=max_token_size,
-        chunk_overlap=overlap_token_size,
-        # length_function=lambda x: len(encode_string_by_tiktoken(x)),
-        model_name=tiktoken_model,
-        is_separator_regex=False,
-        separators=[
-            # Paragraph separators
-            "\n\n",
-            "\r\n\r\n",
-            # Line breaks
-            "\n",
-            "\r\n",
-            # Sentence-ending punctuation
-            "。",  # Chinese period
-            "．",  # Full-width dot
-            ".",  # English period
-            "！",  # Chinese exclamation mark
-            "!",  # English exclamation mark
-            "？",  # Chinese question mark
-            "?",  # English question mark
-            # Whitespace characters
-            " ",  # Space
-            "\t",  # Tab
-            "\u3000",  # Full-width space
-            # Special characters
-            "\u200b",  # Zero-width space (used in some Asian languages)
-            # Final fallback
-            "",
-        ],
-    )
-    texts = text_splitter.split_text(content)
-
     results = []
-    for index, chunk_content in enumerate(texts):
-        results.append(
-            {
-                # "tokens": None,
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = []
+        lengths = []
+        # Slide over the tokens with a stride of (max_token_size - overlap_token_size)
+        # so that consecutive chunks overlap by overlap_token_size tokens.
+        for start in range(0, len(tokens), max_token_size - overlap_token_size):
+            chunk_token.append(tokens[start : start + max_token_size])
+            lengths.append(min(max_token_size, len(tokens) - start))
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
     return results


 WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
 rag = GraphRAG(
     working_dir=WORKING_DIR,
-    chunk_func=chunking_by_specific_separators,
+    chunk_func=chunking_by_seperators,
 )

 with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:
     FAKE_TEXT = f.read()

 # rag.insert(FAKE_TEXT)
 print(rag.query("What the main theme of this story?", param=QueryParam(mode="local")))
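
For reference, here is a minimal standalone sketch of how the new chunking_by_token_size above could be exercised outside GraphRAG. It is a hypothetical example, not part of the commit: it assumes the tiktoken package is installed, and the encoding choice, document strings, and doc keys are placeholders.

import tiktoken

# Hypothetical smoke test for the token-based chunker defined above.
enc = tiktoken.encoding_for_model("gpt-4o")
docs = [
    "First mock document about knights and dragons.",
    "Second mock document about graphs and retrieval.",
]
tokens_list = [enc.encode(d) for d in docs]  # one token list per document
chunks = chunking_by_token_size(
    tokens_list,
    doc_keys=["doc-0", "doc-1"],
    tiktoken_model=enc,  # the encoder object itself, not a model-name string
    overlap_token_size=16,
    max_token_size=64,
)
for c in chunks:
    print(c["full_doc_id"], c["chunk_order_index"], c["tokens"])

With these numbers the stride is 64 - 16 = 48 tokens, so a 100-token document would produce chunks starting at tokens 0, 48, and 96, each overlapping its predecessor by 16 tokens.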
@@ -0,0 +1,122 @@
import os
import logging
import ollama
import numpy as np
from openai import AsyncOpenAI
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)

# Assumed LLM settings
LLM_BASE_URL = "https://your.api.url"
LLM_API_KEY = "your_api_key"
MODEL = "your_model_name"

# Assumed embedding model settings
EMBEDDING_MODEL = "nomic-embed-text"
EMBEDDING_MODEL_DIM = 768
EMBEDDING_MODEL_MAX_TOKENS = 8192


async def llm_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    openai_async_client = AsyncOpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Return the cached response if one exists -----------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------------------

    response = await openai_async_client.chat.completions.create(
        model=MODEL, messages=messages, **kwargs
    )

    # Cache the response if caching is enabled ------------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
        )
    # -----------------------------------------------------------------
    return response.choices[0].message.content


def remove_if_exist(file):
    if os.path.exists(file):
        os.remove(file)


WORKING_DIR = "./nano_graphrag_cache_llm_TEST"


def query():
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=llm_model_if_cache,
        cheap_model_func=llm_model_if_cache,
        embedding_func=ollama_embedding,
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )


def insert():
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=llm_model_if_cache,
        cheap_model_func=llm_model_if_cache,
        embedding_func=ollama_embedding,
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
    # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
    # rag.insert(FAKE_TEXT[half_len:])


# We're using Ollama to generate embeddings with the nomic-embed-text model
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBEDDING_MODEL_DIM,
    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
)
async def ollama_embedding(texts: list[str]) -> np.ndarray:
    embed_text = []
    for text in texts:
        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
        embed_text.append(data["embedding"])
    return np.array(embed_text)


if __name__ == "__main__":
    insert()
    query()
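
Before wiring everything together, it can help to sanity-check the embedding path on its own. The following is a minimal sketch under stated assumptions: a local Ollama server is running, the nomic-embed-text model has been pulled, and the wrapped ollama_embedding above is awaited directly.

import asyncio

# Hypothetical smoke test: the returned matrix should be (n_texts, EMBEDDING_MODEL_DIM).
async def check_embedding_dim():
    vecs = await ollama_embedding(["hello world", "graph rag"])
    assert vecs.shape == (2, EMBEDDING_MODEL_DIM), vecs.shape
    print("embedding matrix shape:", vecs.shape)

asyncio.run(check_embedding_dim())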