-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Knowledge cores with msgpack * Put it in the cli package * Tidy up msgpack dumper * Created a loader
- Loading branch information
1 parent
319f9ac
commit 340d7a2
Showing
6 changed files
with
700 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,3 +23,6 @@ if "error" in resp: | |
print(f"Error: {resp['error']}") | ||
sys.exit(1) | ||
|
||
print(resp["vectors"]) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import msgpack | ||
import sys | ||
import argparse | ||
|
||
def run(input_file):
    """Print every object stored in a msgpack file.

    Args:
        input_file: Path of the file to read. It is opened in binary mode
            and decoded as a stream of msgpack objects.

    Each decoded object is written to standard output with ``print``.
    """
    with open(input_file, 'rb') as stream:
        # raw=False makes the unpacker return str instead of bytes.
        for record in msgpack.Unpacker(stream, raw=False):
            print(record)
|
||
def main():
    """Parse the command line and dump the requested msgpack file.

    Recognised options:
        -i / --input-file (required): path of the msgpack file to dump.
    """

    # prog is left to argparse (basename of sys.argv[0]) so the usage text
    # always shows the name the script was actually invoked as.
    parser = argparse.ArgumentParser(
        description='Dump the objects in a msgpack file to standard output.',
    )

    parser.add_argument(
        '-i', '--input-file',
        required=True,
        help='Input file',
    )

    args = parser.parse_args()

    # The argparse destination (input_file) matches run()'s parameter name.
    run(**vars(args))


# Only run when executed as a script, so the module can be imported safely.
if __name__ == "__main__":
    main()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import aiohttp | ||
import asyncio | ||
import msgpack | ||
import json | ||
import sys | ||
import argparse | ||
import os | ||
|
||
async def load_ge(queue, url):
    """Forward graph-embedding records from a queue to a websocket.

    Args:
        queue: Queue whose ``get()`` is awaited for the next record. Each
            record is a mapping with keys ``m`` (itself holding ``i``,
            ``m``, ``u``, ``c``), ``v`` and ``e``.
        url: API base URL; ``load/graph-embeddings`` is appended to it.

    The loop has no exit condition: the coroutine runs until it is
    cancelled or an exception is raised.
    """
    endpoint = f"{url}load/graph-embeddings"

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(endpoint) as socket:
            while True:
                record = await queue.get()
                meta = record["m"]

                # Expand the compact on-disk keys into the long-form
                # message that is sent over the socket.
                await socket.send_json({
                    "metadata": {
                        "id": meta["i"],
                        "metadata": meta["m"],
                        "user": meta["u"],
                        "collection": meta["c"],
                    },
                    "vectors": record["v"],
                    "entity": record["e"],
                })
|
||
async def load_triples(queue, url):
    """Forward triple records from a queue to a websocket.

    Args:
        queue: Queue whose ``get()`` is awaited for the next record. Each
            record is a mapping with keys ``m`` (itself holding ``i``,
            ``m``, ``u``, ``c``) and ``t``.
        url: API base URL; ``load/triples`` is appended to it.

    The loop has no exit condition: the coroutine runs until it is
    cancelled or an exception is raised.
    """
    endpoint = f"{url}load/triples"

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(endpoint) as socket:
            while True:
                record = await queue.get()
                meta = record["m"]

                # Expand the compact on-disk keys into the long-form
                # message that is sent over the socket.
                await socket.send_json({
                    "metadata": {
                        "id": meta["i"],
                        "metadata": meta["m"],
                        "user": meta["u"],
                        "collection": meta["c"],
                    },
                    "triples": record["t"],
                })
|
||
# Running totals of records queued so far: incremented by loader() and
# reported periodically by stats().
ge_counts = 0
t_counts = 0


async def stats():
    """Print the graph-embedding and triple counters every five seconds.

    The loop has no exit condition; the coroutine runs until cancelled.
    """
    report_interval = 5

    # The counters are only read here, so no ``global`` declaration is needed.
    while True:
        await asyncio.sleep(report_interval)
        print(f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}")
|
||
async def loader(ge_queue, t_queue, path, format, user, collection):
    """Read a msgpack file and route its records onto the two work queues.

    Every object in the file is expected to be a two-element sequence
    ``[tag, payload]``. Payloads tagged ``"t"`` go to ``t_queue`` and
    payloads tagged ``"ge"`` go to ``ge_queue``; any other tag is ignored.
    The module-level counters ``t_counts`` and ``ge_counts`` are incremented
    for each record queued.

    Args:
        ge_queue: Queue receiving graph-embedding payloads.
        t_queue: Queue receiving triple payloads.
        path: Path of the msgpack file to read.
        format: Input format; only the msgpack path is implemented.
        user: If truthy, replaces the user ID in each payload's metadata.
        collection: If truthy, replaces the collection ID in each payload's
            metadata.

    Raises:
        RuntimeError: If ``format`` is ``"json"`` (not implemented).
    """
    global t_counts
    global ge_counts

    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:

        unpacker = msgpack.Unpacker(f, raw=False)

        for unpacked in unpacker:

            kind = unpacked[0]
            if kind not in ("t", "ge"):
                # Unknown record types are skipped silently.
                continue

            payload = unpacked[1]

            # The overrides must land where the senders read them: the
            # payload's metadata lives under "m", with the user under "u"
            # and the collection under "c". The record itself is a
            # sequence, so it cannot be indexed by string key.
            if user:
                payload["m"]["u"] = user

            if collection:
                payload["m"]["c"] = collection

            if kind == "t":
                await t_queue.put(payload)
                t_counts += 1
            else:
                await ge_queue.put(payload)
                ge_counts += 1
|
||
async def run(**args):
    """Start the file reader, both websocket senders and the stats reporter.

    Keyword Args:
        url: API base URL, with or without a trailing slash; ``api/v1/`` is
            appended to it.
        input_file: Path of the file handed to ``loader``.
        format: Input format handed to ``loader``.
        user: Optional user ID override handed to ``loader``.
        collection: Optional collection ID override handed to ``loader``.

    NOTE: the senders and the stats reporter loop forever, so this
    coroutine only finishes when one of the tasks raises or it is
    cancelled.
    """

    # Tolerate a base URL given without its trailing slash.
    base_url = args["url"]
    if not base_url.endswith("/"):
        base_url += "/"
    api_url = base_url + "api/v1/"

    ge_q = asyncio.Queue()
    t_q = asyncio.Queue()

    load_task = asyncio.create_task(
        loader(
            ge_queue=ge_q, t_queue=t_q,
            path=args["input_file"], format=args["format"],
            user=args["user"], collection=args["collection"],
        )
    )

    ge_task = asyncio.create_task(load_ge(queue=ge_q, url=api_url))

    triples_task = asyncio.create_task(load_triples(queue=t_q, url=api_url))

    stats_task = asyncio.create_task(stats())

    # gather() propagates the first exception from ANY task immediately.
    # Awaiting the tasks one after another would hide a failure in a later
    # task behind an earlier task that never returns.
    await asyncio.gather(load_task, triples_task, ge_task, stats_task)
|
||
async def main():
    """Parse the command line and run the loader.

    Recognised options:
        -u / --url: TrustGraph API URL; defaults to the TRUSTGRAPH_API
            environment variable, or http://localhost:8088/ if unset.
        -i / --input-file (required): file to load.
        --format: input format, ``msgpack`` (default) or ``json``.
        --user: user ID to load as; taken from the input when omitted.
        --collection: collection ID to load as; taken from the input when
            omitted.
    """

    # prog is left to argparse (basename of sys.argv[0]) so the usage text
    # always shows the name the script was actually invoked as.
    parser = argparse.ArgumentParser(
        description=(
            'Load graph embeddings and triples from a file into '
            'TrustGraph over its websocket API.'
        ),
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        required=True,
        help='Input file',
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)',
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)',
    )

    args = parser.parse_args()

    await run(**vars(args))


# Only run when executed as a script, so the module can be imported safely.
if __name__ == "__main__":
    asyncio.run(main())
|
Oops, something went wrong.