diff --git a/examples/using_amazon_bedrock.py b/examples/using_amazon_bedrock.py new file mode 100644 index 0000000..c8aeac4 --- /dev/null +++ b/examples/using_amazon_bedrock.py @@ -0,0 +1,19 @@ +from nano_graphrag import GraphRAG, QueryParam + +graph_func = GraphRAG( + working_dir="../bedrock_example", + using_amazon_bedrock=True, + best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", + cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0", +) + +with open("../tests/mock_data.txt") as f: + graph_func.insert(f.read()) + +prompt = "What are the top themes in this story?" + +# Perform global graphrag search +print(graph_func.query(prompt, param=QueryParam(mode="global"))) + +# Perform local graphrag search (I think is better and more scalable one) +print(graph_func.query(prompt, param=QueryParam(mode="local"))) diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index f658234..974c339 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -1,5 +1,8 @@ +import json import numpy as np +from typing import Optional, List, Any, Callable +import aioboto3 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError from tenacity import ( @@ -15,6 +18,7 @@ global_openai_async_client = None global_azure_openai_async_client = None +global_amazon_bedrock_async_client = None def get_openai_async_client_instance(): @@ -31,6 +35,13 @@ def get_azure_openai_async_client_instance(): return global_azure_openai_async_client +def get_amazon_bedrock_async_client_instance(): + global global_amazon_bedrock_async_client + if global_amazon_bedrock_async_client is None: + global_amazon_bedrock_async_client = aioboto3.Session() + return global_amazon_bedrock_async_client + + @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -64,6 +75,82 @@ async def openai_complete_if_cache( return response.choices[0].message.content +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def amazon_bedrock_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + messages.extend(history_messages) + messages.append({"role": "user", "content": [{"text": prompt}]}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + inference_config = { + "temperature": 0, + "maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"], + } + + async with amazon_bedrock_async_client.client( + "bedrock-runtime", + region_name=os.getenv("AWS_REGION", "us-east-1") + ) as bedrock_runtime: + if system_prompt: + response = await bedrock_runtime.converse( + modelId=model, messages=messages, inferenceConfig=inference_config, + system=[{"text": system_prompt}] + ) + else: + response = await bedrock_runtime.converse( + modelId=model, messages=messages, inferenceConfig=inference_config, + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}} + ) + await hashing_kv.index_done_callback() + return response["output"]["message"]["content"][0]["text"] + + +def create_amazon_bedrock_complete_function(model_id: str) -> Callable: + """ + Factory function to dynamically create completion functions for Amazon Bedrock + + Args: + model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0") + + Returns: + Callable: Generated completion function + """ + async def bedrock_complete( + prompt: str, + system_prompt: Optional[str] = None, + history_messages: List[Any] = [], + **kwargs + ) -> str: + return await amazon_bedrock_complete_if_cache( + model_id, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs + ) + + # Set function name for easier debugging + bedrock_complete.__name__ = f"{model_id}_complete" + + return bedrock_complete + + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -88,6 +175,35 @@ async def gpt_4o_mini_complete( ) +@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray: + amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance() + + async with amazon_bedrock_async_client.client( + "bedrock-runtime", + region_name=os.getenv("AWS_REGION", "us-east-1") + ) as bedrock_runtime: + embeddings = [] + for text in texts: + body = json.dumps( + { + "inputText": text, + "dimensions": 1024, + } + ) + response = await bedrock_runtime.invoke_model( + modelId="amazon.titan-embed-text-v2:0", body=body, + ) + response_body = await response.get("body").read() + embeddings.append(json.loads(response_body)) + return np.array([dp["embedding"] for dp in embeddings]) + + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(5), diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py index 7b5b9a8..2867878 100644 --- a/nano_graphrag/_op.py +++ b/nano_graphrag/_op.py @@ -293,6 +293,7 @@ async def extract_entities( knwoledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, global_config: dict, + using_amazon_bedrock: bool=False, ) -> Union[BaseGraphStorage, None]: use_llm_func: callable = global_config["best_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -320,12 +321,14 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): content = chunk_dp["content"] hint_prompt = entity_extract_prompt.format(**context_base, input_text=content) final_result = await use_llm_func(hint_prompt) + if isinstance(final_result, list): + final_result = final_result[0]["text"] - history = pack_user_ass_to_openai_messages(hint_prompt, final_result) + history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock) for now_glean_index in range(entity_extract_max_gleaning): glean_result = await use_llm_func(continue_prompt, history_messages=history) - history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) + history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock) final_result += glean_result if now_glean_index == entity_extract_max_gleaning - 1: break diff --git a/nano_graphrag/_utils.py b/nano_graphrag/_utils.py index ae772eb..8f76227 100644 --- a/nano_graphrag/_utils.py +++ b/nano_graphrag/_utils.py @@ -162,11 +162,17 @@ def load_json(file_name): # it's dirty to type, so it's a good way to have fun -def pack_user_ass_to_openai_messages(*args: str): - roles = ["user", "assistant"] - return [ - {"role": roles[i % 2], "content": content} for i, content in enumerate(args) - ] +def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool): + if using_amazon_bedrock: + return [ + {"role": "user", "content": [{"text": prompt}]}, + {"role": "assistant", "content": [{"text": generated_content}]}, + ] + else: + return [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": generated_content}, + ] def is_float_regex(value): diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py index 2c9e1be..e60fcb0 100644 --- a/nano_graphrag/graphrag.py +++ b/nano_graphrag/graphrag.py @@ -9,6 +9,8 @@ from ._llm import ( + amazon_bedrock_embedding, + create_amazon_bedrock_complete_function, gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, @@ -107,6 +109,9 @@ class GraphRAG: # LLM using_azure_openai: bool = False + using_amazon_bedrock: bool = False + best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0" + cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0" best_model_func: callable = gpt_4o_complete best_model_max_token_size: int = 32768 best_model_max_async: int = 16 @@ -145,6 +150,14 @@ def __post_init__(self): "Switched the default openai funcs to Azure OpenAI if you didn't set any of it" ) + if self.using_amazon_bedrock: + self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id) + self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id) + self.embedding_func = amazon_bedrock_embedding + logger.info( + "Switched the default openai funcs to Amazon Bedrock" + ) + if not os.path.exists(self.working_dir) and self.always_create_working_dir: logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) @@ -298,6 +311,7 @@ async def ainsert(self, string_or_strings): knwoledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, global_config=asdict(self), + using_amazon_bedrock=self.using_amazon_bedrock, ) if maybe_new_kg is None: logger.warning("No new entities found") diff --git a/readme.md b/readme.md index 112c4a0..2fc470c 100644 --- a/readme.md +++ b/readme.md @@ -73,6 +73,9 @@ pip install nano-graphrag > [!TIP] > If you're using Azure OpenAI API, refer to the [.env.example](./.env.example.azure) to set your azure openai. Then pass `GraphRAG(...,using_azure_openai=True,...)` to enable. +> [!TIP] +> If you're using Amazon Bedrock API, please ensure your credentials are properly set through commands like `aws configure`. Then enable it by configuring like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. Refer to an [example script](./examples/using_amazon_bedrock.py). + > [!TIP] > > If you don't have any key, check out this [example](./examples/no_openai_key_at_all.py) that using `transformers` and `ollama` . If you like to use another LLM or Embedding Model, check [Advances](#Advances). @@ -167,9 +170,11 @@ Below are the components you can use: | Type | What | Where | | :-------------- | :----------------------------------------------------------: | :-----------------------------------------------: | | LLM | OpenAI | Built-in | +| | Amazon Bedrock | Built-in | | | DeepSeek | [examples](./examples) | | | `ollama` | [examples](./examples) | | Embedding | OpenAI | Built-in | +| | Amazon Bedrock | Built-in | | | Sentence-transformers | [examples](./examples) | | Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) | Built-in | | | [`hnswlib`](https://github.com/nmslib/hnswlib) | Built-in, [examples](./examples) | diff --git a/requirements.txt b/requirements.txt index be0e993..7d26a49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ hnswlib xxhash tenacity dspy-ai -neo4j \ No newline at end of file +neo4j +aioboto3 \ No newline at end of file