Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/async problem #190

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion trustgraph-flow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
"ibis",
"jsonschema",
"aiohttp",
"aiopulsar-py",
"pinecone[grpc]",
],
scripts=[
Expand Down
44 changes: 17 additions & 27 deletions trustgraph-flow/trustgraph/api/gateway/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,10 @@ def __init__(

self.operation = "service"

async def start(self, client):
async def start(self):

self.pub_task = asyncio.create_task(self.pub.run(client))
self.sub_task = asyncio.create_task(self.sub.run(client))

async def join(self):

await self.pub_task
await self.sub_task
self.pub.start()
self.sub.start()

def add_routes(self, app):

Expand Down Expand Up @@ -87,29 +82,25 @@ async def handle(self, request):

print(data)

q = await self.sub.subscribe(id)
q = self.sub.subscribe(id)

await self.pub.send(
id,
self.to_request(data),
await asyncio.to_thread(
self.pub.send, id, self.to_request(data)
)
print("Request sent")

try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout")

print("Response got")
print(resp)

if resp.error:
print("Error")
return web.json_response(
{ "error": resp.error.message }
)

print("Send response")

return web.json_response(
self.from_response(resp)
)
Expand All @@ -122,7 +113,7 @@ async def handle(self, request):
)

finally:
await self.sub.unsubscribe(id)
self.sub.unsubscribe(id)


class MultiResponseServiceEndpoint(ServiceEndpoint):
Expand All @@ -135,20 +126,19 @@ async def handle(self, request):

data = await request.json()

q = await self.sub.subscribe(id)
q = self.sub.subscribe(id)

await self.pub.send(
id,
self.to_request(data),
await asyncio.to_thread(
self.pub.send, id, self.to_request(data)
)

# Keeps looking at responses...

while True:

try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout waiting for response")

if resp.error:
Expand All @@ -173,4 +163,4 @@ async def handle(self, request):
)

finally:
await self.sub.unsubscribe(id)
self.sub.unsubscribe(id)
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ def __init__(
schema=JsonSchema(GraphEmbeddings)
)

async def start(self, client):
async def start(self):

self.task = asyncio.create_task(
self.publisher.run(client)
)
self.publisher.start()

async def listener(self, ws, running):

Expand All @@ -56,7 +54,7 @@ async def listener(self, ws, running):
vectors=data["vectors"],
)

await self.publisher.send(None, elt)
self.publisher.send(None, elt)


running.stop()
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid

Expand Down Expand Up @@ -28,31 +29,29 @@ def __init__(
schema=JsonSchema(GraphEmbeddings)
)

async def start(self, client):
async def start(self):

self.task = asyncio.create_task(
self.subscriber.run(client)
)
self.subscriber.start()

async def async_thread(self, ws, running):

id = str(uuid.uuid4())

q = await self.subscriber.subscribe_all(id)
q = self.subscriber.subscribe_all(id)

while running.get():
try:
resp = await asyncio.wait_for(q.get(), 0.5)
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))

except TimeoutError:
except queue.Empty:
continue

except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break

await self.subscriber.unsubscribe_all(id)
self.subscriber.unsubscribe_all(id)

running.stop()

50 changes: 32 additions & 18 deletions trustgraph-flow/trustgraph/api/gateway/publisher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@

import asyncio
import queue
import time
import pulsar
import threading

class Publisher:

Expand All @@ -8,32 +11,43 @@ def __init__(self, pulsar_host, topic, schema=None, max_size=10,
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
self.q = queue.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled

async def run(self, client):
def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()

def run(self):

while True:

try:
async with client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer:
while True:
id, item = await self.q.get()

if id:
await producer.send(item, { "id": id })
else:
await producer.send(item)

client = pulsar.Client(
self.pulsar_host,
)

producer = client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
)

while True:

id, item = self.q.get()

if id:
producer.send(item, { "id": id })
else:
producer.send(item)

except Exception as e:
print("Exception:", e, flush=True)

# If the handler drops out, sleep before retrying
await asyncio.sleep(2)
time.sleep(2)

async def send(self, id, msg):
await self.q.put((id, msg))
def send(self, id, msg):
self.q.put((id, msg))
39 changes: 9 additions & 30 deletions trustgraph-flow/trustgraph/api/gateway/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import logging
import os
import base64
import aiopulsar

import pulsar
from pulsar.schema import JsonSchema
Expand Down Expand Up @@ -167,7 +166,8 @@ async def load_document(self, request):
# content is valid base64
doc = base64.b64decode(data["data"])

resp = await self.document_out.send(
resp = await asyncio.to_thread(
self.document_out.send,
None,
Document(
metadata=Metadata(
Expand Down Expand Up @@ -212,7 +212,8 @@ async def load_text(self, request):
# Text is base64 encoded
text = base64.b64decode(data["text"]).decode(charset)

resp = await self.text_out.send(
resp = asyncio.to_thread(
self.text_out.send,
None,
TextDocument(
metadata=Metadata(
Expand All @@ -238,35 +239,13 @@ async def load_text(self, request):
{ "error": str(e) }
)

async def run_endpoints(self):

async with aiopulsar.connect(self.pulsar_host) as client:

for ep in self.endpoints:
await ep.start(client)

self.doc_ingest_pub_task = asyncio.create_task(
self.document_out.run(client)
)

self.text_ingest_pub_task = asyncio.create_task(
self.text_out.run(client)
)

print("Endpoints are running...")

# They never exit
for ep in self.endpoints:
await ep.join()

await self.doc_ingest_pub_task
await self.text_ingest_pub_task

print("Endpoints are stopped.")

async def app_factory(self):

self.endpoint_task = asyncio.create_task(self.run_endpoints())
for ep in self.endpoints:
await ep.start()

self.document_out.start()
self.text_out.start()

return self.app

Expand Down
6 changes: 0 additions & 6 deletions trustgraph-flow/trustgraph/api/gateway/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ async def handle(self, request):
async def start(self):
pass

async def join(self):

# Nothing to wait for
while True:
await asyncio.sleep(100)

def add_routes(self, app):

app.add_routes([
Expand Down
Loading
Loading