From 032d0419671930395082323028f6a1261a9d311c Mon Sep 17 00:00:00 2001
From: Gus
Date: Wed, 25 Sep 2024 20:02:09 +0800
Subject: [PATCH] tests: fix wrong expectations for clusters node data

---
 examples/no_openai_key_at_all.py              |  1 +
 examples/using_ollama_as_llm.py               |  1 +
 examples/using_ollama_as_llm_and_embedding.py | 17 ++++++++++-------
 nano_graphrag/_op.py                          |  2 ++
 tests/test_neo4j_storage.py                   |  4 ++--
 tests/test_openai.py                          | 10 ++++++----
 6 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/examples/no_openai_key_at_all.py b/examples/no_openai_key_at_all.py
index f1624e0..1fce788 100644
--- a/examples/no_openai_key_at_all.py
+++ b/examples/no_openai_key_at_all.py
@@ -34,6 +34,7 @@ async def ollama_model_if_cache(
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
     ollama_client = ollama.AsyncClient()
     messages = []
 
diff --git a/examples/using_ollama_as_llm.py b/examples/using_ollama_as_llm.py
index a358b28..e067212 100644
--- a/examples/using_ollama_as_llm.py
+++ b/examples/using_ollama_as_llm.py
@@ -18,6 +18,7 @@ async def ollama_model_if_cache(
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
     ollama_client = ollama.AsyncClient()
     messages = []
 
diff --git a/examples/using_ollama_as_llm_and_embedding.py b/examples/using_ollama_as_llm_and_embedding.py
index 97614b9..44d669d 100644
--- a/examples/using_ollama_as_llm_and_embedding.py
+++ b/examples/using_ollama_as_llm_and_embedding.py
@@ -20,11 +20,13 @@
 EMBEDDING_MODEL_DIM = 768
 EMBEDDING_MODEL_MAX_TOKENS = 8192
 
+
 async def ollama_model_if_cache(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
     ollama_client = ollama.AsyncClient()
     messages = []
 
@@ -98,20 +100,21 @@ def insert():
     # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
     # rag.insert(FAKE_TEXT[half_len:])
 
+
 # We're using Ollama to generate embeddings for the BGE model
 @wrap_embedding_func_with_attrs(
-    embedding_dim= EMBEDDING_MODEL_DIM,
-    max_token_size= EMBEDDING_MODEL_MAX_TOKENS,
+    embedding_dim=EMBEDDING_MODEL_DIM,
+    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
 )
-
-async def ollama_embedding(texts :list[str]) -> np.ndarray:
+async def ollama_embedding(texts: list[str]) -> np.ndarray:
     embed_text = []
     for text in texts:
-      data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
-      embed_text.append(data["embedding"])
-      
+        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
+        embed_text.append(data["embedding"])
+
     return embed_text
 
+
 if __name__ == "__main__":
     insert()
     query()
diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py
index 12745cd..f06b3c5 100644
--- a/nano_graphrag/_op.py
+++ b/nano_graphrag/_op.py
@@ -678,6 +678,8 @@ async def _find_most_related_community_from_entities(
 ):
     related_communities = []
     for node_d in node_datas:
+        if "clusters" not in node_d:
+            continue
         related_communities.extend(json.loads(node_d["clusters"]))
     related_community_dup_keys = [
         str(dp["cluster"])
diff --git a/tests/test_neo4j_storage.py b/tests/test_neo4j_storage.py
index 488e813..8a87c3a 100644
--- a/tests/test_neo4j_storage.py
+++ b/tests/test_neo4j_storage.py
@@ -66,7 +66,7 @@ def test_neo4j_storage_init():
 async def test_upsert_and_get_node(neo4j_storage):
     node_id = "node1"
     node_data = {"attr1": "value1", "attr2": "value2"}
-    return_data = {"id": node_id, **node_data}
+    return_data = {"id": node_id, "clusters": "[]", **node_data}
 
     await neo4j_storage.upsert_node(node_id, node_data)
 
@@ -190,7 +190,7 @@ async def test_nonexistent_node_and_edge(neo4j_storage):
     assert await neo4j_storage.has_edge("node1", "node2") is False
     assert await neo4j_storage.get_node("nonexistent") is None
     assert await neo4j_storage.get_edge("node1", "node2") is None
-    assert await neo4j_storage.get_node_edges("nonexistent") is None
+    assert await neo4j_storage.get_node_edges("nonexistent") == []
     assert await neo4j_storage.node_degree("nonexistent") == 0
     assert await neo4j_storage.edge_degree("node1", "node2") == 0
diff --git a/tests/test_openai.py b/tests/test_openai.py
index afc5eea..109fd50 100644
--- a/tests/test_openai.py
+++ b/tests/test_openai.py
@@ -6,7 +6,7 @@
 
 @pytest.fixture
 def mock_openai_client():
-    with patch("nano_graphrag._llm.AsyncOpenAI") as mock_openai:
+    with patch("nano_graphrag._llm.get_openai_async_client_instance") as mock_openai:
         mock_client = AsyncMock()
         mock_openai.return_value = mock_client
         yield mock_client
@@ -14,7 +14,9 @@ def mock_openai_client():
 
 @pytest.fixture
 def mock_azure_openai_client():
-    with patch("nano_graphrag._llm.AsyncAzureOpenAI") as mock_openai:
+    with patch(
+        "nano_graphrag._llm.get_azure_openai_async_client_instance"
+    ) as mock_openai:
         mock_client = AsyncMock()
         mock_openai.return_value = mock_client
         yield mock_client
@@ -37,7 +39,7 @@
 
 
 @pytest.mark.asyncio
-async def test_openai_gpt4o_mini(mock_openai_client):
+async def test_openai_gpt4omini(mock_openai_client):
     mock_response = AsyncMock()
     mock_response.choices = [Mock(message=Mock(content="1"))]
     messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
@@ -69,7 +71,7 @@ async def test_azure_openai_gpt4o(mock_azure_openai_client):
 
 
 @pytest.mark.asyncio
-async def test_azure_openai_gpt4o_mini(mock_azure_openai_client):
+async def test_azure_openai_gpt4omini(mock_azure_openai_client):
     mock_response = AsyncMock()
     mock_response.choices = [Mock(message=Mock(content="1"))]
     messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
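
Note, outside the patch proper: the guard added in nano_graphrag/_op.py exists because node records read back from graph storage may lack a "clusters" field until clustering has populated it, in which case node_d["clusters"] would raise KeyError. Below is a minimal runnable sketch of the guarded loop; the node records are hypothetical, and the cluster-JSON shape is inferred from the str(dp["cluster"]) context line in the hunk above.

    import json

    # Hypothetical node records: "node2" was upserted but never clustered,
    # so it carries no "clusters" field.
    node_datas = [
        {"id": "node1", "clusters": '[{"cluster": 0}]'},
        {"id": "node2"},
    ]

    related_communities = []
    for node_d in node_datas:
        if "clusters" not in node_d:
            continue  # skip unclustered nodes instead of raising KeyError
        related_communities.extend(json.loads(node_d["clusters"]))

    print(related_communities)  # [{'cluster': 0}]

The same convention shows up in the test_neo4j_storage.py change: freshly upserted nodes are expected to come back with a default "clusters" value of "[]", which json.loads turns into an empty list.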