From 340d7a224f51b259cbca9184131dc337b08ef59d Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Mon, 25 Nov 2024 20:46:35 +0000
Subject: [PATCH] Feature/rework kg core (#171)

* Knowledge cores with msgpack

* Put it in the cli package

* Tidy up msgpack dumper

* Created a loader
---
 test-api/test-embeddings-api            |   3 +
 trustgraph-cli/scripts/tg-dump-msgpack  |  34 ++
 trustgraph-cli/scripts/tg-load-kg-core  | 179 +++++++++++
 trustgraph-cli/scripts/tg-save-kg-core  | 190 +++++++++++
 trustgraph-cli/setup.py                 |   4 +
 .../trustgraph/api/gateway/service.py   | 294 +++++++++++++++++-
 6 files changed, 700 insertions(+), 4 deletions(-)
 create mode 100755 trustgraph-cli/scripts/tg-dump-msgpack
 create mode 100755 trustgraph-cli/scripts/tg-load-kg-core
 create mode 100755 trustgraph-cli/scripts/tg-save-kg-core

diff --git a/test-api/test-embeddings-api b/test-api/test-embeddings-api
index ef9ea099..b1defd01 100755
--- a/test-api/test-embeddings-api
+++ b/test-api/test-embeddings-api
@@ -23,3 +23,6 @@
 if "error" in resp:
     print(f"Error: {resp['error']}")
     sys.exit(1)
+print(resp["vectors"])
+
+
diff --git a/trustgraph-cli/scripts/tg-dump-msgpack b/trustgraph-cli/scripts/tg-dump-msgpack
new file mode 100755
index 00000000..9f91394f
--- /dev/null
+++ b/trustgraph-cli/scripts/tg-dump-msgpack
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+
+import msgpack
+import sys
+import argparse
+
+def run(input_file):
+
+    with open(input_file, 'rb') as f:
+
+        unpacker = msgpack.Unpacker(f, raw=False)
+
+        for unpacked in unpacker:
+            print(unpacked)
+
+def main():
+
+    parser = argparse.ArgumentParser(
+        prog='tg-dump-msgpack',
+        description=__doc__,
+    )
+
+    parser.add_argument(
+        '-i', '--input-file',
+        required=True,
+        help=f'Input file'
+    )
+
+    args = parser.parse_args()
+
+    run(**vars(args))
+
+main()
+
diff --git a/trustgraph-cli/scripts/tg-load-kg-core b/trustgraph-cli/scripts/tg-load-kg-core
new file mode 100755
index 00000000..2469772d
--- /dev/null
+++ b/trustgraph-cli/scripts/tg-load-kg-core
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+
+import aiohttp
+import asyncio
+import msgpack
+import json
+import sys
+import argparse
+import os
+
+async def load_ge(queue, url):
+
+    async with aiohttp.ClientSession() as session:
+
+        async with session.ws_connect(f"{url}load/graph-embeddings") as ws:
+
+            while True:
+
+                msg = await queue.get()
+
+                msg = {
+                    "metadata": {
+                        "id": msg["m"]["i"],
+                        "metadata": msg["m"]["m"],
+                        "user": msg["m"]["u"],
+                        "collection": msg["m"]["c"],
+                    },
+                    "vectors": msg["v"],
+                    "entity": msg["e"],
+                }
+
+                await ws.send_json(msg)
+
+async def load_triples(queue, url):
+
+    async with aiohttp.ClientSession() as session:
+
+        async with session.ws_connect(f"{url}load/triples") as ws:
+
+            while True:
+
+                msg = await queue.get()
+
+                msg = {
+                    "metadata": {
+                        "id": msg["m"]["i"],
+                        "metadata": msg["m"]["m"],
+                        "user": msg["m"]["u"],
+                        "collection": msg["m"]["c"],
+                    },
+                    "triples": msg["t"],
+                }
+
+                await ws.send_json(msg)
+
+ge_counts = 0
+t_counts = 0
+
+async def stats():
+
+    global t_counts
+    global ge_counts
+
+    while True:
+        await asyncio.sleep(5)
+        print(
+            f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
+        )
+
+async def loader(ge_queue, t_queue, path, format, user, collection):
+
+    global t_counts
+    global ge_counts
+
+    if format == "json":
+
+        raise RuntimeError("Not implemented")
+
+    else:
+
+        with open(path, "rb") as f:
+
+            unpacker = msgpack.Unpacker(f, raw=False)
+
+            for unpacked in unpacker:
+
+                # Records are [kind, value] pairs as written by
+                # tg-save-kg-core, with metadata under the "m" key
+                if user:
+                    unpacked[1]["m"]["u"] = user
+
+                if collection:
+                    unpacked[1]["m"]["c"] = collection
+
+                if unpacked[0] == "t":
+                    await t_queue.put(unpacked[1])
+                    t_counts += 1
+                else:
+                    if unpacked[0] == "ge":
+                        await ge_queue.put(unpacked[1])
+                        ge_counts += 1
+
+async def run(**args):
+
+    ge_q = asyncio.Queue()
+    t_q = asyncio.Queue()
+
+    load_task = asyncio.create_task(
+        loader(
+            ge_queue=ge_q, t_queue=t_q,
+            path=args["input_file"], format=args["format"],
+            user=args["user"], collection=args["collection"],
+        )
+    )
+
+    ge_task = asyncio.create_task(
+        load_ge(
+            queue=ge_q, url=args["url"] + "api/v1/"
+        )
+    )
+
+    triples_task = asyncio.create_task(
+        load_triples(
+            queue=t_q, url=args["url"] + "api/v1/"
+        )
+    )
+
+    stats_task = asyncio.create_task(stats())
+
+    await load_task
+    await triples_task
+    await ge_task
+    await stats_task
+
+async def main():
+
+    parser = argparse.ArgumentParser(
+        prog='tg-load-kg-core',
+        description=__doc__,
+    )
+
+    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")
+    default_user = "trustgraph"
+    collection = "default"
+
+    parser.add_argument(
+        '-u', '--url',
+        default=default_url,
+        help=f'TrustGraph API URL (default: {default_url})',
+    )
+
+    parser.add_argument(
+        '-i', '--input-file',
+        # Mandatory: the knowledge core file to read
+        required=True,
+        help=f'Input file'
+    )
+
+    parser.add_argument(
+        '--format',
+        default="msgpack",
+        choices=["msgpack", "json"],
+        help=f'Input format (default: msgpack)',
+    )
+
+    parser.add_argument(
+        '--user',
+        help=f'User ID to load as (default: from input)'
+    )
+
+    parser.add_argument(
+        '--collection',
+        help=f'Collection ID to load as (default: from input)'
+    )
+
+    args = parser.parse_args()
+
+    await run(**vars(args))
+
+asyncio.run(main())
+
diff --git a/trustgraph-cli/scripts/tg-save-kg-core b/trustgraph-cli/scripts/tg-save-kg-core
new file mode 100755
index 00000000..feeea1ef
--- /dev/null
+++ b/trustgraph-cli/scripts/tg-save-kg-core
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+
+import aiohttp
+import asyncio
+import msgpack
+import json
+import sys
+import argparse
+import os
+
+async def fetch_ge(queue, user, collection, url):
+
+    async with aiohttp.ClientSession() as session:
+
+        async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:
+
+            async for msg in ws:
+
+                if msg.type == aiohttp.WSMsgType.TEXT:
+
+                    data = msg.json()
+
+                    if user:
+                        if data["metadata"]["user"] != user:
+                            continue
+
+                    if collection:
+                        if data["metadata"]["collection"] != collection:
+                            continue
+
+                    await queue.put([
+                        "ge",
+                        {
+                            "m": {
+                                "i": data["metadata"]["id"],
+                                "m": data["metadata"]["metadata"],
+                                "u": data["metadata"]["user"],
+                                "c": data["metadata"]["collection"],
+                            },
+                            "v": data["vectors"],
+                            "e": data["entity"],
+                        }
+                    ])
+
+                if msg.type == aiohttp.WSMsgType.ERROR:
+                    print("Error")
+                    break
+
+async def fetch_triples(queue, user, collection, url):
+
+    async with aiohttp.ClientSession() as session:
+
+        async with session.ws_connect(f"{url}stream/triples") as ws:
+
+            async for msg in ws:
+
+                if msg.type == aiohttp.WSMsgType.TEXT:
+
+                    data = msg.json()
+
+                    if user:
+                        if data["metadata"]["user"] != user:
+                            continue
+
+                    if collection:
+                        if data["metadata"]["collection"] != collection:
+                            continue
+
+                    await queue.put((
+                        "t",
+                        {
+                            "m": {
+                                "i": data["metadata"]["id"],
+                                "m": data["metadata"]["metadata"],
+                                "u": data["metadata"]["user"],
+                                "c": data["metadata"]["collection"],
+                            },
+                            "t": data["triples"],
+                        }
+                    ))
+
+                if msg.type == aiohttp.WSMsgType.ERROR:
+                    print("Error")
+                    break
+
+ge_counts = 0
+t_counts = 0
+
+async def stats():
+
+    global t_counts
+    global ge_counts
+
+    while True:
+        await asyncio.sleep(5)
+        print(
+            f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
+        )
+
+async def output(queue, path, format):
+
+    global t_counts
+    global ge_counts
+
+    with open(path, "wb") as f:
+
+        while True:
+
+            msg = await queue.get()
+
+            if format == "msgpack":
+                f.write(msgpack.packb(msg, use_bin_type=True))
+            else:
+                f.write(json.dumps(msg).encode("utf-8"))
+
+            if msg[0] == "t":
+                t_counts += 1
+            else:
+                if msg[0] == "ge":
+                    ge_counts += 1
+
+async def run(**args):
+
+    q = asyncio.Queue()
+
+    ge_task = asyncio.create_task(
+        fetch_ge(
+            queue=q, user=args["user"], collection=args["collection"],
+            url=args["url"] + "api/v1/"
+        )
+    )
+
+    triples_task = asyncio.create_task(
+        fetch_triples(
+            queue=q, user=args["user"], collection=args["collection"],
+            url=args["url"] + "api/v1/"
+        )
+    )
+
+    output_task = asyncio.create_task(
+        output(
+            queue=q, path=args["output_file"], format=args["format"],
+        )
+    )
+
+    stats_task = asyncio.create_task(stats())
+
+    await output_task
+    await triples_task
+    await ge_task
+    await stats_task
+
+async def main():
+
+    parser = argparse.ArgumentParser(
+        prog='tg-save-kg-core',
+        description=__doc__,
+    )
+
+    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")
+    default_user = "trustgraph"
+    collection = "default"
+
+    parser.add_argument(
+        '-u', '--url',
+        default=default_url,
+        help=f'TrustGraph API URL (default: {default_url})',
+    )
+
+    parser.add_argument(
+        '-o', '--output-file',
+        # Mandatory, so it is harder to over-write an existing file
+        required=True,
+        help=f'Output file'
+    )
+
+    parser.add_argument(
+        '--format',
+        default="msgpack",
+        choices=["msgpack", "json"],
+        help=f'Output format (default: msgpack)',
+    )
+
+    parser.add_argument(
+        '--user',
+        help=f'User ID to filter on (default: no filter)'
+    )
+
+    parser.add_argument(
+        '--collection',
+        help=f'Collection ID to filter on (default: no filter)'
+    )
+
+    args = parser.parse_args()
+
+    await run(**vars(args))
+
+asyncio.run(main())
+
diff --git a/trustgraph-cli/setup.py b/trustgraph-cli/setup.py
index ec541c8b..1608cfdb 100644
--- a/trustgraph-cli/setup.py
+++ b/trustgraph-cli/setup.py
@@ -39,6 +39,7 @@
         "pulsar-client",
         "rdflib",
         "tabulate",
+        "msgpack",
     ],
     scripts=[
         "scripts/tg-graph-show",
@@ -54,5 +55,8 @@
         "scripts/tg-invoke-agent",
         "scripts/tg-invoke-prompt",
         "scripts/tg-invoke-llm",
+        "scripts/tg-save-kg-core",
+        "scripts/tg-load-kg-core",
+        "scripts/tg-dump-msgpack",
     ]
 )
diff --git a/trustgraph-flow/trustgraph/api/gateway/service.py b/trustgraph-flow/trustgraph/api/gateway/service.py
index 148bc321..6d5f70ce 100755
--- a/trustgraph-flow/trustgraph/api/gateway/service.py
+++ b/trustgraph-flow/trustgraph/api/gateway/service.py
@@ -14,7 +14,7 @@
 import asyncio
 import argparse
 
-from aiohttp import web
+from aiohttp import web, WSMsgType
 import json
 import logging
 import uuid
@@ -47,9 +47,13 @@
 from ... schema import graph_rag_request_queue
 from ... schema import graph_rag_response_queue
 
-from ... schema import TriplesQueryRequest, TriplesQueryResponse
+from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
 from ... schema import triples_request_queue
 from ... schema import triples_response_queue
+from ... schema import triples_store_queue
+
+from ... schema import GraphEmbeddings
+from ... schema import graph_embeddings_store_queue
 
 from ... schema import AgentRequest, AgentResponse
 from ... schema import agent_request_queue
@@ -84,6 +88,11 @@ def to_subgraph(x):
         for t in x
     ]
 
+class Running:
+    def __init__(self): self.running = True
+    def get(self): return self.running
+    def stop(self): self.running = False
+
 class Publisher:
 
     def __init__(self, pulsar_host, topic, schema=None, max_size=10,
@@ -132,6 +141,7 @@ def __init__(self, pulsar_host, topic, subscription, consumer_name,
         self.consumer_name = consumer_name
         self.schema = schema
         self.q = {}
+        self.full = {}
 
     async def run(self):
         while True:
@@ -145,10 +155,19 @@ async def run(self):
                 ) as consumer:
                     while True:
                         msg = await consumer.receive()
-                        id = msg.properties()["id"]
+
+                        try:
+                            id = msg.properties()["id"]
+                        except:
+                            id = None
+
                         value = msg.value()
                         if id in self.q:
                             await self.q[id].put(value)
+
+                        for q in self.full.values():
+                            await q.put(value)
+
             except Exception as e:
                 print("Exception:", e, flush=True)
@@ -164,6 +183,59 @@ async def unsubscribe(self, id):
         if id in self.q:
             del self.q[id]
 
+    async def subscribe_all(self, id):
+        q = asyncio.Queue()
+        self.full[id] = q
+        return q
+
+    async def unsubscribe_all(self, id):
+        if id in self.full:
+            del self.full[id]
+
+def serialize_triples(message):
+    return {
+        "metadata": {
+            "id": message.metadata.id,
+            "metadata": [
+                {
+                    "s": t.s.value,
+                    "p": t.p.value,
+                    "o": t.o.value,
+                }
+                for t in message.metadata.metadata
+            ],
+            "user": message.metadata.user,
+            "collection": message.metadata.collection,
+        },
+        "triples": [
+            {
+                "s": t.s.value,
+                "p": t.p.value,
+                "o": t.o.value,
+            }
+            for t in message.triples
+        ]
+    }
+
+def serialize_graph_embeddings(message):
+    return {
+        "metadata": {
+            "id": message.metadata.id,
+            "metadata": [
+                {
+                    "s": t.s.value,
+                    "p": t.p.value,
+                    "o": t.o.value,
+                }
+                for t in message.metadata.metadata
+            ],
+            "user": message.metadata.user,
+            "collection": message.metadata.collection,
+        },
+        "vectors": message.vectors,
+        "entity": message.entity.value,
+    }
+
 class Api:
 
     def __init__(self, **config):
@@ -243,6 +315,28 @@ def __init__(self, **config):
             JsonSchema(EmbeddingsResponse)
         )
 
+        self.triples_tap = Subscriber(
+            self.pulsar_host, triples_store_queue,
+            "api-gateway", "api-gateway",
+            schema=JsonSchema(Triples)
+        )
+
+        self.triples_pub = Publisher(
+            self.pulsar_host, triples_store_queue,
+            schema=JsonSchema(Triples)
+        )
+
+        self.graph_embeddings_tap = Subscriber(
+            self.pulsar_host, graph_embeddings_store_queue,
+            "api-gateway", "api-gateway",
+            schema=JsonSchema(GraphEmbeddings)
+        )
+
+        self.graph_embeddings_pub = Publisher(
+            self.pulsar_host, graph_embeddings_store_queue,
+            schema=JsonSchema(GraphEmbeddings)
+        )
+
         self.document_out = Publisher(
             self.pulsar_host, document_ingest_queue,
             schema=JsonSchema(Document),
@@ -264,6 +358,20 @@ def __init__(self, **config):
             web.post("/api/v1/embeddings", self.embeddings),
             web.post("/api/v1/load/document", self.load_document),
             web.post("/api/v1/load/text", self.load_text),
+            web.get("/api/v1/ws", self.socket),
+
+            web.get("/api/v1/stream/triples", self.stream_triples),
+            web.get(
+                "/api/v1/stream/graph-embeddings",
+                self.stream_graph_embeddings
+            ),
+
+            web.get("/api/v1/load/triples", self.load_triples),
+            web.get(
+                "/api/v1/load/graph-embeddings",
+                self.load_graph_embeddings
+            ),
+
         ])
 
     async def llm(self, request):
@@ -660,6 +768,169 @@ async def load_text(self, request):
                 { "error": str(e) }
             )
 
+    async def socket(self, request):
+
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for msg in ws:
+            if msg.type == WSMsgType.TEXT:
+                if msg.data == 'close':
+                    await ws.close()
+                else:
+                    await ws.send_str(msg.data + '/answer')
+            elif msg.type == WSMsgType.ERROR:
+                print('ws connection closed with exception %s' %
+                      ws.exception())
+
+        print('websocket connection closed')
+
+        return ws
+
+    async def stream(self, q, ws, running, fn):
+
+        while running.get():
+
+            try:
+                resp = await asyncio.wait_for(q.get(), 0.5)
+                await ws.send_json(fn(resp))
+
+            except TimeoutError:
+                continue
+
+            except Exception as e:
+                print(f"Exception: {str(e)}", flush=True)
+
+    async def stream_triples(self, request):
+
+        id = str(uuid.uuid4())
+
+        q = await self.triples_tap.subscribe_all(id)
+        running = Running()
+
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        tsk = asyncio.create_task(self.stream(
+            q, ws, running, serialize_triples,
+        ))
+
+        async for msg in ws:
+            if msg.type == WSMsgType.ERROR:
+                break
+            else:
+                # Ignore incoming messages
+                pass
+
+        running.stop()
+
+        await self.triples_tap.unsubscribe_all(id)
+        await tsk
+
+        return ws
+
+    async def stream_graph_embeddings(self, request):
+
+        id = str(uuid.uuid4())
+
+        q = await self.graph_embeddings_tap.subscribe_all(id)
+        running = Running()
+
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        tsk = asyncio.create_task(self.stream(
+            q, ws, running, serialize_graph_embeddings,
+        ))
+
+        async for msg in ws:
+            if msg.type == WSMsgType.ERROR:
+                break
+            else:
+                # Ignore incoming messages
+                pass
+
+        running.stop()
+
+        await self.graph_embeddings_tap.unsubscribe_all(id)
+        await tsk
+
+        return ws
+
+    async def load_triples(self, request):
+
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for msg in ws:
+
+            try:
+
+                if msg.type == WSMsgType.TEXT:
+
+                    data = msg.json()
+
+                    elt = Triples(
+                        metadata=Metadata(
+                            id=data["metadata"]["id"],
+                            metadata=to_subgraph(data["metadata"]["metadata"]),
+                            user=data["metadata"]["user"],
+                            collection=data["metadata"]["collection"],
+                        ),
+                        triples=to_subgraph(data["triples"]),
+                    )
+
+                    await self.triples_pub.send(None, elt)
+
+                elif msg.type == WSMsgType.ERROR:
+                    break
+
+            except Exception as e:
+
+                print("Exception:", e)
+
+        return ws
+
+    async def load_graph_embeddings(self, request):
+
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for msg in ws:
+
+            try:
+
+                if msg.type == WSMsgType.TEXT:
+
+                    data = msg.json()
+
+                    elt = GraphEmbeddings(
+                        metadata=Metadata(
+                            id=data["metadata"]["id"],
+                            metadata=to_subgraph(data["metadata"]["metadata"]),
+                            user=data["metadata"]["user"],
+                            collection=data["metadata"]["collection"],
+                        ),
+                        entity=to_value(data["entity"]),
+                        vectors=data["vectors"],
+                    )
+
+                    await self.graph_embeddings_pub.send(None, elt)
+
+                elif msg.type == WSMsgType.ERROR:
+                    break
+
+            except Exception as e:
+
+                print("Exception:", e)
+
+        return ws
+
     async def app_factory(self):
 
         self.llm_pub_task = asyncio.create_task(self.llm_in.run())
@@ -688,6 +959,22 @@ async def app_factory(self):
             self.embeddings_out.run()
         )
 
+        self.triples_tap_task = asyncio.create_task(
+            self.triples_tap.run()
+        )
+
+        self.triples_pub_task = asyncio.create_task(
+            self.triples_pub.run()
+        )
+
+        self.graph_embeddings_tap_task = asyncio.create_task(
+            self.graph_embeddings_tap.run()
+        )
+
+        self.graph_embeddings_pub_task = asyncio.create_task(
+            self.graph_embeddings_pub.run()
+        )
+
         self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
         self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
@@ -699,7 +986,6 @@ def run(self):
 
 def run():
 
-
     parser = argparse.ArgumentParser(
         prog="api-gateway",
         description=__doc__
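
As a minimal sketch of how the new streaming endpoint can be consumed directly (assuming the gateway's default URL http://localhost:8088/ and the message shape produced by serialize_triples above; the supported path is the tg-save-kg-core / tg-load-kg-core scripts added in this patch):

    #!/usr/bin/env python3

    # Minimal sketch of a client for the /api/v1/stream/triples websocket
    # added by this patch.  Assumes the default gateway URL; message keys
    # follow serialize_triples() above.

    import asyncio
    import aiohttp

    async def main():
        url = "http://localhost:8088/api/v1/stream/triples"
        async with aiohttp.ClientSession() as session:
            async with session.ws_connect(url) as ws:
                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        data = msg.json()
                        for t in data["triples"]:
                            print(t["s"], t["p"], t["o"])
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        break

    asyncio.run(main())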