Skip to content

Commit

Permalink
Feature/gateway auth (#186)
Browse files Browse the repository at this point in the history
* Added auth module, just a simple token at this stage
* Pass auth token GATEWAY_SECRET through
* Auth token not mandatory, can be provided in env var
  • Loading branch information
cybermaggedon authored Dec 2, 2024
1 parent 6d200c7 commit 1b9c6be
Show file tree
Hide file tree
Showing 18 changed files with 126 additions and 33 deletions.
5 changes: 5 additions & 0 deletions templates/components/trustgraph.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ local prompt = import "prompt-template.jsonnet";

create:: function(engine)

local envSecrets = engine.envSecrets("gateway-secret")
.with_env_var("GATEWAY_SECRET", "gateway-secret");

local port = $["api-gateway-port"];

local container =
Expand All @@ -29,6 +32,7 @@ local prompt = import "prompt-template.jsonnet";
"--port",
std.toString(port),
])
.with_env_var_secrets(envSecrets)
.with_limits("0.5", "256M")
.with_reservations("0.1", "256M")
.with_port(8000, 8000, "metrics")
Expand All @@ -44,6 +48,7 @@ local prompt = import "prompt-template.jsonnet";
.with_port(port, port, "api");

engine.resources([
envSecrets,
containerSet,
service,
])
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . endpoint import MultiResponseServiceEndpoint

class AgentEndpoint(MultiResponseServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(AgentEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=AgentResponse,
endpoint_path="/api/v1/agent",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
22 changes: 22 additions & 0 deletions trustgraph-flow/trustgraph/api/gateway/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

class Authenticator:

def __init__(self, token=None, allow_all=False):

if not allow_all and token is None:
raise RuntimeError("Need a token")

if not allow_all and token == "":
raise RuntimeError("Need a token")

self.token = token
self.allow_all = allow_all

def permitted(self, token, roles):

if self.allow_all: return True

if self.token != token: return False

return True

3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/dbpedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . endpoint import ServiceEndpoint

class DbpediaEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(DbpediaEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=LookupResponse,
endpoint_path="/api/v1/dbpedia",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . endpoint import ServiceEndpoint

class EmbeddingsEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(EmbeddingsEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=EmbeddingsResponse,
endpoint_path="/api/v1/embeddings",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/encyclopedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . endpoint import ServiceEndpoint

class EncyclopediaEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(EncyclopediaEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=LookupResponse,
endpoint_path="/api/v1/encyclopedia",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
24 changes: 16 additions & 8 deletions trustgraph-flow/trustgraph/api/gateway/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
request_queue, request_schema,
response_queue, response_schema,
endpoint_path,
auth,
subscription="api-gateway", consumer_name="api-gateway",
timeout=600,
):
Expand All @@ -36,6 +37,9 @@ def __init__(

self.path = endpoint_path
self.timeout = timeout
self.auth = auth

self.operation = "service"

async def start(self):

Expand All @@ -58,14 +62,24 @@ async def handle(self, request):

id = str(uuid.uuid4())

try:
ht = request.headers["Authorization"]
tokens = ht.split(" ", 2)
if tokens[0] != "Bearer":
return web.HTTPUnauthorized()
token = tokens[1]
except:
token = ""

if not self.auth.permitted(token, self.operation):
return web.HTTPUnauthorized()

try:

data = await request.json()

q = await self.sub.subscribe(id)

print(data)

await self.pub.send(
id,
self.to_request(data),
Expand All @@ -76,8 +90,6 @@ async def handle(self, request):
except:
raise RuntimeError("Timeout waiting for response")

print(resp)

if resp.error:
return web.json_response(
{ "error": resp.error.message }
Expand Down Expand Up @@ -110,8 +122,6 @@ async def handle(self, request):

q = await self.sub.subscribe(id)

print(data)

await self.pub.send(
id,
self.to_request(data),
Expand All @@ -126,8 +136,6 @@ async def handle(self, request):
except:
raise RuntimeError("Timeout waiting for response")

print(resp)

if resp.error:
return web.json_response(
{ "error": resp.error.message }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

class GraphEmbeddingsLoadEndpoint(SocketEndpoint):

def __init__(self, pulsar_host, path="/api/v1/load/graph-embeddings"):
def __init__(
self, pulsar_host, auth, path="/api/v1/load/graph-embeddings",
):

super(GraphEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path
endpoint_path=path, auth=auth,
)

self.pulsar_host=pulsar_host
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

class GraphEmbeddingsStreamEndpoint(SocketEndpoint):

def __init__(self, pulsar_host, path="/api/v1/stream/graph-embeddings"):
def __init__(
self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings"
):

super(GraphEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path
endpoint_path=path, auth=auth,
)

self.pulsar_host=pulsar_host
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/graph_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . endpoint import ServiceEndpoint

class GraphRagEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(GraphRagEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=GraphRagResponse,
endpoint_path="/api/v1/graph-rag",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/internet_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . endpoint import ServiceEndpoint

class InternetSearchEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(InternetSearchEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=LookupResponse,
endpoint_path="/api/v1/internet-search",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
3 changes: 2 additions & 1 deletion trustgraph-flow/trustgraph/api/gateway/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from . endpoint import ServiceEndpoint

class PromptEndpoint(ServiceEndpoint):
def __init__(self, pulsar_host, timeout):
def __init__(self, pulsar_host, timeout, auth):

super(PromptEndpoint, self).__init__(
pulsar_host=pulsar_host,
Expand All @@ -18,6 +18,7 @@ def __init__(self, pulsar_host, timeout):
response_schema=PromptResponse,
endpoint_path="/api/v1/prompt",
timeout=timeout,
auth=auth,
)

def to_request(self, body):
Expand Down
37 changes: 33 additions & 4 deletions trustgraph-flow/trustgraph/api/gateway/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
from . auth import Authenticator

logger = logging.getLogger("api")
logger.setLevel(logging.INFO)

default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600
default_port = 8088
default_api_token = os.getenv("GATEWAY_SECRET", "")

class Api:

Expand All @@ -66,45 +68,66 @@ def __init__(self, **config):
self.timeout = int(config.get("timeout", default_timeout))
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)

api_token = config.get("api_token", default_api_token)

# Token not set, or token equal empty string means no auth
if api_token:
self.auth = Authenticator(token=api_token)
else:
self.auth = Authenticator(allow_all=True)

self.endpoints = [
TextCompletionEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
PromptEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
GraphRagEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
TriplesQueryEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
EmbeddingsEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
AgentEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
EncyclopediaEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
DbpediaEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
InternetSearchEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
TriplesStreamEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
GraphEmbeddingsStreamEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
TriplesLoadEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
GraphEmbeddingsLoadEndpoint(
pulsar_host=self.pulsar_host
pulsar_host=self.pulsar_host,
auth = self.auth,
),
]

Expand Down Expand Up @@ -254,6 +277,12 @@ def run():
help=f'API request timeout in seconds (default: {default_timeout})',
)

parser.add_argument(
'--api-token',
default=default_api_token,
help=f'Secret API token (default: no auth)',
)

parser.add_argument(
'-l', '--log-level',
type=LogLevel,
Expand Down
Loading

0 comments on commit 1b9c6be

Please sign in to comment.