Skip to content

Commit

Permalink
Feat: Add Amazon Bedrock support (#97)
Browse files Browse the repository at this point in the history
* Add Amazon Bedrock support

* add sample script to test amazon bedrock integration

* add the latest Claude 3.5 Sonnet v1&v2 model

* Add a factory function for bedrock completion instead of creating one for each model

* update README.md to explain the Bedrock option.

* clean up
  • Loading branch information
kmotohas authored Nov 23, 2024
1 parent a8043a6 commit 18fa3a4
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 8 deletions.
19 changes: 19 additions & 0 deletions examples/using_amazon_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from nano_graphrag import GraphRAG, QueryParam

graph_func = GraphRAG(
working_dir="../bedrock_example",
using_amazon_bedrock=True,
best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0",
cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
)

with open("../tests/mock_data.txt") as f:
graph_func.insert(f.read())

prompt = "What are the top themes in this story?"

# Perform global graphrag search
print(graph_func.query(prompt, param=QueryParam(mode="global")))

# Perform local graphrag search (I think is better and more scalable one)
print(graph_func.query(prompt, param=QueryParam(mode="local")))
116 changes: 116 additions & 0 deletions nano_graphrag/_llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import json
import numpy as np
from typing import Optional, List, Any, Callable

import aioboto3
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError

from tenacity import (
Expand All @@ -15,6 +18,7 @@

global_openai_async_client = None
global_azure_openai_async_client = None
global_amazon_bedrock_async_client = None


def get_openai_async_client_instance():
Expand All @@ -31,6 +35,13 @@ def get_azure_openai_async_client_instance():
return global_azure_openai_async_client


def get_amazon_bedrock_async_client_instance():
global global_amazon_bedrock_async_client
if global_amazon_bedrock_async_client is None:
global_amazon_bedrock_async_client = aioboto3.Session()
return global_amazon_bedrock_async_client


@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
Expand Down Expand Up @@ -64,6 +75,82 @@ async def openai_complete_if_cache(
return response.choices[0].message.content


@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
messages.extend(history_messages)
messages.append({"role": "user", "content": [{"text": prompt}]})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]

inference_config = {
"temperature": 0,
"maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"],
}

async with amazon_bedrock_async_client.client(
"bedrock-runtime",
region_name=os.getenv("AWS_REGION", "us-east-1")
) as bedrock_runtime:
if system_prompt:
response = await bedrock_runtime.converse(
modelId=model, messages=messages, inferenceConfig=inference_config,
system=[{"text": system_prompt}]
)
else:
response = await bedrock_runtime.converse(
modelId=model, messages=messages, inferenceConfig=inference_config,
)

if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}}
)
await hashing_kv.index_done_callback()
return response["output"]["message"]["content"][0]["text"]


def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
"""
Factory function to dynamically create completion functions for Amazon Bedrock
Args:
model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")
Returns:
Callable: Generated completion function
"""
async def bedrock_complete(
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[Any] = [],
**kwargs
) -> str:
return await amazon_bedrock_complete_if_cache(
model_id,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs
)

# Set function name for easier debugging
bedrock_complete.__name__ = f"{model_id}_complete"

return bedrock_complete


async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
Expand All @@ -88,6 +175,35 @@ async def gpt_4o_mini_complete(
)


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()

async with amazon_bedrock_async_client.client(
"bedrock-runtime",
region_name=os.getenv("AWS_REGION", "us-east-1")
) as bedrock_runtime:
embeddings = []
for text in texts:
body = json.dumps(
{
"inputText": text,
"dimensions": 1024,
}
)
response = await bedrock_runtime.invoke_model(
modelId="amazon.titan-embed-text-v2:0", body=body,
)
response_body = await response.get("body").read()
embeddings.append(json.loads(response_body))
return np.array([dp["embedding"] for dp in embeddings])


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
Expand Down
7 changes: 5 additions & 2 deletions nano_graphrag/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ async def extract_entities(
knwoledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
global_config: dict,
using_amazon_bedrock: bool=False,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["best_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
Expand Down Expand Up @@ -320,12 +321,14 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
content = chunk_dp["content"]
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
if isinstance(final_result, list):
final_result = final_result[0]["text"]

history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)

history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock)
final_result += glean_result
if now_glean_index == entity_extract_max_gleaning - 1:
break
Expand Down
16 changes: 11 additions & 5 deletions nano_graphrag/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,17 @@ def load_json(file_name):


# it's dirty to type, so it's a good way to have fun
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
if using_amazon_bedrock:
return [
{"role": "user", "content": [{"text": prompt}]},
{"role": "assistant", "content": [{"text": generated_content}]},
]
else:
return [
{"role": "user", "content": prompt},
{"role": "assistant", "content": generated_content},
]


def is_float_regex(value):
Expand Down
14 changes: 14 additions & 0 deletions nano_graphrag/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


from ._llm import (
amazon_bedrock_embedding,
create_amazon_bedrock_complete_function,
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
Expand Down Expand Up @@ -107,6 +109,9 @@ class GraphRAG:

# LLM
using_azure_openai: bool = False
using_amazon_bedrock: bool = False
best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
best_model_func: callable = gpt_4o_complete
best_model_max_token_size: int = 32768
best_model_max_async: int = 16
Expand Down Expand Up @@ -145,6 +150,14 @@ def __post_init__(self):
"Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
)

if self.using_amazon_bedrock:
self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
self.embedding_func = amazon_bedrock_embedding
logger.info(
"Switched the default openai funcs to Amazon Bedrock"
)

if not os.path.exists(self.working_dir) and self.always_create_working_dir:
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
Expand Down Expand Up @@ -298,6 +311,7 @@ async def ainsert(self, string_or_strings):
knwoledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
global_config=asdict(self),
using_amazon_bedrock=self.using_amazon_bedrock,
)
if maybe_new_kg is None:
logger.warning("No new entities found")
Expand Down
5 changes: 5 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ pip install nano-graphrag
> [!TIP]
> If you're using Azure OpenAI API, refer to the [.env.example](./.env.example.azure) to set your azure openai. Then pass `GraphRAG(...,using_azure_openai=True,...)` to enable.
> [!TIP]
> If you're using Amazon Bedrock API, please ensure your credentials are properly set through commands like `aws configure`. Then enable it by configuring like this: `GraphRAG(...,using_amazon_bedrock=True, best_model_id="us.anthropic.claude-3-sonnet-20240229-v1:0", cheap_model_id="us.anthropic.claude-3-haiku-20240307-v1:0",...)`. Refer to an [example script](./examples/using_amazon_bedrock.py).
> [!TIP]
>
> If you don't have any key, check out this [example](./examples/no_openai_key_at_all.py) that using `transformers` and `ollama` . If you like to use another LLM or Embedding Model, check [Advances](#Advances).
Expand Down Expand Up @@ -167,9 +170,11 @@ Below are the components you can use:
| Type | What | Where |
| :-------------- | :----------------------------------------------------------: | :-----------------------------------------------: |
| LLM | OpenAI | Built-in |
| | Amazon Bedrock | Built-in |
| | DeepSeek | [examples](./examples) |
| | `ollama` | [examples](./examples) |
| Embedding | OpenAI | Built-in |
| | Amazon Bedrock | Built-in |
| | Sentence-transformers | [examples](./examples) |
| Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) | Built-in |
| | [`hnswlib`](https://github.com/nmslib/hnswlib) | Built-in, [examples](./examples) |
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ hnswlib
xxhash
tenacity
dspy-ai
neo4j
neo4j
aioboto3

0 comments on commit 18fa3a4

Please sign in to comment.