
Commit 032d041
tests: fix wrong test for clusters node data
gusye1234 committed Sep 25, 2024
1 parent 8f5f6ef commit 032d041
Showing 6 changed files with 22 additions and 13 deletions.
examples/no_openai_key_at_all.py (1 addition, 0 deletions)
@@ -34,6 +34,7 @@ async def ollama_model_if_cache(
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)

     ollama_client = ollama.AsyncClient()
     messages = []
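The same one-line change lands in all three Ollama examples in this commit, so here is the pattern once. nano-graphrag forwards OpenAI-style keyword arguments such as max_tokens and response_format to whatever model function is plugged in, and Ollama's chat API does not accept them, so the examples pop them before the call. A minimal sketch of the pattern (the model name is illustrative, and the dict-style response access assumes the ollama Python client of this era):

import ollama

async def ollama_model_sketch(prompt: str, **kwargs) -> str:
    # Drop OpenAI-only kwargs that ollama's chat() would reject.
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)
    client = ollama.AsyncClient()
    messages = [{"role": "user", "content": prompt}]
    # "llama3" is an illustrative model name, not one used by these examples.
    response = await client.chat(model="llama3", messages=messages, **kwargs)
    return response["message"]["content"]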
examples/using_ollama_as_llm.py (1 addition, 0 deletions)
@@ -18,6 +18,7 @@ async def ollama_model_if_cache(
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)

     ollama_client = ollama.AsyncClient()
     messages = []
examples/using_ollama_as_llm_and_embedding.py (10 additions, 7 deletions)
@@ -20,11 +20,13 @@
 EMBEDDING_MODEL_DIM = 768
 EMBEDDING_MODEL_MAX_TOKENS = 8192

+
 async def ollama_model_if_cache(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)

     ollama_client = ollama.AsyncClient()
     messages = []
@@ -98,20 +100,21 @@ def insert():
     # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
     # rag.insert(FAKE_TEXT[half_len:])

+
 # We're using Ollama to generate embeddings for the BGE model
 @wrap_embedding_func_with_attrs(
-    embedding_dim= EMBEDDING_MODEL_DIM,
-    max_token_size= EMBEDDING_MODEL_MAX_TOKENS,
+    embedding_dim=EMBEDDING_MODEL_DIM,
+    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
 )
-
-async def ollama_embedding(texts :list[str]) -> np.ndarray:
+async def ollama_embedding(texts: list[str]) -> np.ndarray:
     embed_text = []
     for text in texts:
-      data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
-      embed_text.append(data["embedding"])
+        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
+        embed_text.append(data["embedding"])

     return embed_text

+
 if __name__ == "__main__":
     insert()
     query()
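The rest of this file's diff is formatting cleanup around the embedding function. For readers skimming it, a hypothetical sketch of what an attribute-attaching decorator like wrap_embedding_func_with_attrs plausibly does; this is an assumption for illustration, not nano-graphrag's actual implementation:

def wrap_embedding_func_with_attrs(**attrs):
    # Hypothetical: record metadata such as embedding_dim and
    # max_token_size on the function so the vector storage layer
    # can size itself accordingly.
    def decorator(func):
        for name, value in attrs.items():
            setattr(func, name, value)
        return func
    return decorator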
nano_graphrag/_op.py (2 additions, 0 deletions)
@@ -678,6 +678,8 @@ async def _find_most_related_community_from_entities(
 ):
     related_communities = []
     for node_d in node_datas:
+        if "clusters" not in node_d:
+            continue
         related_communities.extend(json.loads(node_d["clusters"]))
     related_community_dup_keys = [
         str(dp["cluster"])
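This guard is the substantive fix the commit title refers to: a node's data is not guaranteed to carry a "clusters" field (for instance, a node upserted before community detection has run), so the unguarded json.loads(node_d["clusters"]) raised a KeyError. A minimal sketch of the failure mode, using illustrative data rather than nano-graphrag's actual fixtures:

import json

node_datas = [
    {"id": "node1", "clusters": '[{"cluster": 0, "level": 0}]'},
    {"id": "node2"},  # e.g. upserted before clustering; no "clusters" key
]

related_communities = []
for node_d in node_datas:
    if "clusters" not in node_d:
        continue  # without this guard, node_d["clusters"] raises KeyError
    related_communities.extend(json.loads(node_d["clusters"]))

print(related_communities)  # [{'cluster': 0, 'level': 0}]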
tests/test_neo4j_storage.py (2 additions, 2 deletions)
@@ -66,7 +66,7 @@ def test_neo4j_storage_init():
 async def test_upsert_and_get_node(neo4j_storage):
     node_id = "node1"
     node_data = {"attr1": "value1", "attr2": "value2"}
-    return_data = {"id": node_id, **node_data}
+    return_data = {"id": node_id, "clusters": "[]", **node_data}

     await neo4j_storage.upsert_node(node_id, node_data)

Expand Down Expand Up @@ -190,7 +190,7 @@ async def test_nonexistent_node_and_edge(neo4j_storage):
assert await neo4j_storage.has_edge("node1", "node2") is False
assert await neo4j_storage.get_node("nonexistent") is None
assert await neo4j_storage.get_edge("node1", "node2") is None
assert await neo4j_storage.get_node_edges("nonexistent") is None
assert await neo4j_storage.get_node_edges("nonexistent") == []
assert await neo4j_storage.node_degree("nonexistent") == 0
assert await neo4j_storage.edge_degree("node1", "node2") == 0

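Both assertion updates encode the storage behavior the clusters fix relies on: upserted nodes come back with a default "clusters": "[]" field, and get_node_edges returns an empty list rather than None for a missing node, so callers can iterate without a None check. A sketch of that contract from the caller's side (storage is assumed to be a Neo4jStorage instance, as in the test fixture):

async def check_storage_contract(storage) -> None:
    await storage.upsert_node("node1", {"attr1": "value1"})
    node = await storage.get_node("node1")
    # Nodes are seeded with an empty JSON-encoded cluster list.
    assert node["clusters"] == "[]"
    # Missing nodes yield an empty edge list, never None.
    assert await storage.get_node_edges("nonexistent") == []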
tests/test_openai.py (6 additions, 4 deletions)
@@ -6,15 +6,17 @@

 @pytest.fixture
 def mock_openai_client():
-    with patch("nano_graphrag._llm.AsyncOpenAI") as mock_openai:
+    with patch("nano_graphrag._llm.get_openai_async_client_instance") as mock_openai:
         mock_client = AsyncMock()
         mock_openai.return_value = mock_client
         yield mock_client


 @pytest.fixture
 def mock_azure_openai_client():
-    with patch("nano_graphrag._llm.AsyncAzureOpenAI") as mock_openai:
+    with patch(
+        "nano_graphrag._llm.get_azure_openai_async_client_instance"
+    ) as mock_openai:
         mock_client = AsyncMock()
         mock_openai.return_value = mock_client
         yield mock_client
@@ -37,7 +39,7 @@ async def test_openai_gpt4o(mock_openai_client):


 @pytest.mark.asyncio
-async def test_openai_gpt4o_mini(mock_openai_client):
+async def test_openai_gpt4omini(mock_openai_client):
     mock_response = AsyncMock()
     mock_response.choices = [Mock(message=Mock(content="1"))]
     messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
@@ -69,7 +71,7 @@ async def test_azure_openai_gpt4o(mock_azure_openai_client):


 @pytest.mark.asyncio
-async def test_azure_openai_gpt4o_mini(mock_azure_openai_client):
+async def test_azure_openai_gpt4omini(mock_azure_openai_client):
     mock_response = AsyncMock()
     mock_response.choices = [Mock(message=Mock(content="1"))]
     messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
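The fixture change is the mocking counterpart of client caching in nano_graphrag._llm: the new patch target suggests the module obtains its client through a cached getter, in which case patching the AsyncOpenAI class only matters on the first construction, while patching the getter always injects the mock. A self-contained sketch of that effect, with toy stand-ins for the assumed module structure:

from unittest.mock import AsyncMock, patch

class AsyncOpenAI:  # toy stand-in for the real client class
    pass

_global_client = None

def get_openai_async_client_instance():
    # Cached getter; assumed to resemble nano_graphrag._llm's internals.
    global _global_client
    if _global_client is None:
        _global_client = AsyncOpenAI()
    return _global_client

def call_llm():
    # Looks the getter up at call time, like a completion helper would.
    return get_openai_async_client_instance()

get_openai_async_client_instance()  # warm the cache

# Once the cache is warm, patching the class changes nothing...
with patch(f"{__name__}.AsyncOpenAI"):
    assert not isinstance(call_llm(), AsyncMock)

# ...but patching the getter always hands callers the mock:
with patch(f"{__name__}.get_openai_async_client_instance") as getter:
    getter.return_value = AsyncMock()
    assert isinstance(call_llm(), AsyncMock)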
