Skip to content

Commit

Permalink
API supports doc & text load (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
cybermaggedon authored Nov 21, 2024
1 parent a1e0edd commit dc0f54f
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 17 deletions.
1 change: 0 additions & 1 deletion trustgraph-cli/scripts/tg-load-text
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Loads a text document into TrustGraph processing.

import pulsar
from pulsar.schema import JsonSchema
import base64
import hashlib
import argparse
import os
Expand Down
159 changes: 143 additions & 16 deletions trustgraph-flow/trustgraph/api/gateway/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import uuid
import os
import base64

import pulsar
from pulsar.asyncio import Client
Expand All @@ -32,6 +33,8 @@
from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient

from ... schema import Value, Metadata, Document, TextDocument, Triple

from ... schema import TextCompletionRequest, TextCompletionResponse
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
Expand All @@ -44,7 +47,7 @@
from ... schema import graph_rag_request_queue
from ... schema import graph_rag_response_queue

from ... schema import TriplesQueryRequest, TriplesQueryResponse, Value
from ... schema import TriplesQueryRequest, TriplesQueryResponse
from ... schema import triples_request_queue
from ... schema import triples_response_queue

Expand All @@ -56,20 +59,40 @@
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue

from ... schema import document_ingest_queue, text_ingest_queue

logger = logging.getLogger("api")
logger.setLevel(logging.INFO)

default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600
default_port = 8088

def to_value(x):
    """Wrap a string in a schema ``Value``, marking whether it is a URI.

    Args:
        x: The string to wrap. Must support ``str.startswith``.

    Returns:
        ``Value(value=x, is_uri=True)`` when ``x`` begins with ``http:`` or
        ``https:``; otherwise ``Value(value=x, is_uri=False)``.
    """
    # Anything without an http(s) scheme prefix is a plain literal, so it
    # must not be flagged as a URI.
    is_uri = x.startswith(("http:", "https:"))
    return Value(value=x, is_uri=is_uri)

def to_subgraph(x):
    """Convert an iterable of s/p/o mappings into a list of ``Triple``.

    Args:
        x: Iterable of mappings, each providing the keys ``"s"``, ``"p"``
            and ``"o"``. Each value is passed through ``to_value``.

    Returns:
        A list with one ``Triple`` per entry of ``x``, in the same order.

    Raises:
        KeyError: If an entry lacks one of the ``"s"``/``"p"``/``"o"`` keys.
    """
    triples = []
    for entry in x:
        subject = to_value(entry["s"])
        predicate = to_value(entry["p"])
        obj = to_value(entry["o"])
        triples.append(Triple(s=subject, p=predicate, o=obj))
    return triples

class Publisher:

def __init__(self, pulsar_host, topic, schema=None, max_size=10):
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
chunking_enabled=False):
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled

async def run(self):

Expand All @@ -80,10 +103,16 @@ async def run(self):
async with client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer:
while True:
id, item = await self.q.get()
await producer.send(item, { "id": id })

if id:
await producer.send(item, { "id": id })
else:
await producer.send(item)

except Exception as e:
print("Exception:", e, flush=True)

Expand Down Expand Up @@ -139,7 +168,10 @@ class Api:

def __init__(self, **config):

self.app = web.Application(middlewares=[])
self.app = web.Application(
middlewares=[],
client_max_size=256 * 1024 * 1024
)

self.port = int(config.get("port", default_port))
self.timeout = int(config.get("timeout", default_timeout))
Expand Down Expand Up @@ -211,13 +243,27 @@ def __init__(self, **config):
JsonSchema(EmbeddingsResponse)
)

self.document_out = Publisher(
self.pulsar_host, document_ingest_queue,
schema=JsonSchema(Document),
chunking_enabled=True,
)

self.text_out = Publisher(
self.pulsar_host, text_ingest_queue,
schema=JsonSchema(TextDocument),
chunking_enabled=True,
)

self.app.add_routes([
web.post("/api/v1/text-completion", self.llm),
web.post("/api/v1/prompt", self.prompt),
web.post("/api/v1/graph-rag", self.graph_rag),
web.post("/api/v1/triples-query", self.triples_query),
web.post("/api/v1/agent", self.agent),
web.post("/api/v1/embeddings", self.embeddings),
web.post("/api/v1/load/document", self.load_document),
web.post("/api/v1/load/text", self.load_text),
])

async def llm(self, request):
Expand Down Expand Up @@ -368,26 +414,17 @@ async def triples_query(self, request):
q = await self.triples_query_in.subscribe(id)

if "s" in data:
if data["s"].startswith("http:") or data["s"].startswith("https:"):
s = Value(value=data["s"], is_uri=True)
else:
s = Value(value=data["s"], is_uri=True)
s = to_value(data["s"])
else:
s = None

if "p" in data:
if data["p"].startswith("http:") or data["p"].startswith("https:"):
p = Value(value=data["p"], is_uri=True)
else:
p = Value(value=data["p"], is_uri=True)
p = to_value(data["p"])
else:
p = None

if "o" in data:
if data["o"].startswith("http:") or data["o"].startswith("https:"):
o = Value(value=data["o"], is_uri=True)
else:
o = Value(value=data["o"], is_uri=True)
o = to_value(data["o"])
else:
o = None

Expand Down Expand Up @@ -537,6 +574,92 @@ async def embeddings(self, request):
finally:
await self.embeddings_in.unsubscribe(id)

async def load_document(self, request):
    """Handle ``POST /api/v1/load/document``: queue a document for ingest.

    The JSON request body is read for these keys:
      * ``data`` (required): base64-encoded document content.
      * ``metadata`` (optional): list of s/p/o entries, converted with
        ``to_subgraph``; defaults to an empty list.
      * ``id`` (optional): document identifier, ``None`` if absent.
      * ``user`` (optional): defaults to ``"trustgraph"``.
      * ``collection`` (optional): defaults to ``"default"``.

    Returns:
        A JSON response: ``{}`` on success, or ``{"error": <message>}`` if
        any exception was raised while handling the request.
    """
    try:

        body = await request.json()

        subgraph = to_subgraph(body["metadata"]) if "metadata" in body else []

        # Decode and re-encode so the payload forwarded downstream is
        # normalised base64 and a malformed payload fails here.
        # NOTE(review): b64decode in its default (non-validating) mode
        # discards characters outside the base64 alphabet rather than
        # rejecting them -- confirm that is acceptable.
        raw = base64.b64decode(body["data"])

        doc_metadata = Metadata(
            id=body.get("id"),
            metadata=subgraph,
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

        message = Document(
            metadata=doc_metadata,
            data=base64.b64encode(raw).decode("utf-8"),
        )

        # No message id is attached to ingest messages.
        await self.document_out.send(None, message)

        print("Document loaded.")

        return web.json_response({})

    except Exception as e:
        logging.error(f"Exception: {e}")

        return web.json_response({"error": str(e)})

async def load_text(self, request):
    """Handle ``POST /api/v1/load/text``: queue a text document for ingest.

    The JSON request body is read for these keys:
      * ``text`` (required): base64-encoded text content.
      * ``charset`` (optional): encoding used to decode the bytes into a
        string; defaults to ``"utf-8"``.
      * ``metadata`` (optional): list of s/p/o entries, converted with
        ``to_subgraph``; defaults to an empty list.
      * ``id`` (optional): document identifier, ``None`` if absent.
      * ``user`` (optional): defaults to ``"trustgraph"``.
      * ``collection`` (optional): defaults to ``"default"``.

    Returns:
        A JSON response: ``{}`` on success, or ``{"error": <message>}`` if
        any exception was raised while handling the request.
    """
    try:

        body = await request.json()

        subgraph = to_subgraph(body["metadata"]) if "metadata" in body else []

        charset = body["charset"] if "charset" in body else "utf-8"

        # The text arrives base64-encoded; decode to bytes, then to str
        # using the requested charset.
        decoded = base64.b64decode(body["text"]).decode(charset)

        doc_metadata = Metadata(
            id=body.get("id"),
            metadata=subgraph,
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

        message = TextDocument(metadata=doc_metadata, text=decoded)

        # No message id is attached to ingest messages.
        await self.text_out.send(None, message)

        print("Text document loaded.")

        return web.json_response({})

    except Exception as e:
        logging.error(f"Exception: {e}")

        return web.json_response({"error": str(e)})

async def app_factory(self):

self.llm_pub_task = asyncio.create_task(self.llm_in.run())
Expand Down Expand Up @@ -565,6 +688,10 @@ async def app_factory(self):
self.embeddings_out.run()
)

self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())

self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())

return self.app

def run(self):
Expand Down

0 comments on commit dc0f54f

Please sign in to comment.