From 1b9c6be4fc3175e90c11719d820ddc3b146cd33c Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 2 Dec 2024 19:57:21 +0000 Subject: [PATCH] Feature/gateway auth (#186) * Added auth module, just a simple token at this stage * Pass auth token GATEWAY_SECRET through * Auth token not mandatory, can be provided in env var --- templates/components/trustgraph.jsonnet | 5 +++ .../trustgraph/api/gateway/agent.py | 3 +- .../trustgraph/api/gateway/auth.py | 22 +++++++++++ .../trustgraph/api/gateway/dbpedia.py | 3 +- .../trustgraph/api/gateway/embeddings.py | 3 +- .../trustgraph/api/gateway/encyclopedia.py | 3 +- .../trustgraph/api/gateway/endpoint.py | 24 ++++++++---- .../api/gateway/graph_embeddings_load.py | 6 ++- .../api/gateway/graph_embeddings_stream.py | 6 ++- .../trustgraph/api/gateway/graph_rag.py | 3 +- .../trustgraph/api/gateway/internet_search.py | 3 +- .../trustgraph/api/gateway/prompt.py | 3 +- .../trustgraph/api/gateway/service.py | 37 +++++++++++++++++-- .../trustgraph/api/gateway/socket.py | 24 ++++++++++-- .../trustgraph/api/gateway/text_completion.py | 3 +- .../trustgraph/api/gateway/triples_load.py | 4 +- .../trustgraph/api/gateway/triples_query.py | 3 +- .../trustgraph/api/gateway/triples_stream.py | 4 +- 18 files changed, 126 insertions(+), 33 deletions(-) create mode 100644 trustgraph-flow/trustgraph/api/gateway/auth.py diff --git a/templates/components/trustgraph.jsonnet b/templates/components/trustgraph.jsonnet index 6c60921c..31ae420e 100644 --- a/templates/components/trustgraph.jsonnet +++ b/templates/components/trustgraph.jsonnet @@ -15,6 +15,9 @@ local prompt = import "prompt-template.jsonnet"; create:: function(engine) + local envSecrets = engine.envSecrets("gateway-secret") + .with_env_var("GATEWAY_SECRET", "gateway-secret"); + local port = $["api-gateway-port"]; local container = @@ -29,6 +32,7 @@ local prompt = import "prompt-template.jsonnet"; "--port", std.toString(port), ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "256M") .with_reservations("0.1", "256M") .with_port(8000, 8000, "metrics") @@ -44,6 +48,7 @@ local prompt = import "prompt-template.jsonnet"; .with_port(port, port, "api"); engine.resources([ + envSecrets, containerSet, service, ]) diff --git a/trustgraph-flow/trustgraph/api/gateway/agent.py b/trustgraph-flow/trustgraph/api/gateway/agent.py index 28a1e185..40586133 100644 --- a/trustgraph-flow/trustgraph/api/gateway/agent.py +++ b/trustgraph-flow/trustgraph/api/gateway/agent.py @@ -6,7 +6,7 @@ from . endpoint import MultiResponseServiceEndpoint class AgentEndpoint(MultiResponseServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(AgentEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=AgentResponse, endpoint_path="/api/v1/agent", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/auth.py b/trustgraph-flow/trustgraph/api/gateway/auth.py new file mode 100644 index 00000000..a693ca32 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/auth.py @@ -0,0 +1,22 @@ + +class Authenticator: + + def __init__(self, token=None, allow_all=False): + + if not allow_all and token is None: + raise RuntimeError("Need a token") + + if not allow_all and token == "": + raise RuntimeError("Need a token") + + self.token = token + self.allow_all = allow_all + + def permitted(self, token, roles): + + if self.allow_all: return True + + if self.token != token: return False + + return True + diff --git a/trustgraph-flow/trustgraph/api/gateway/dbpedia.py b/trustgraph-flow/trustgraph/api/gateway/dbpedia.py index 0ccb3d6b..4fa7336b 100644 --- a/trustgraph-flow/trustgraph/api/gateway/dbpedia.py +++ b/trustgraph-flow/trustgraph/api/gateway/dbpedia.py @@ -6,7 +6,7 @@ from . endpoint import ServiceEndpoint class DbpediaEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(DbpediaEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=LookupResponse, endpoint_path="/api/v1/dbpedia", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/embeddings.py b/trustgraph-flow/trustgraph/api/gateway/embeddings.py index b5fcc0a4..7c4b578d 100644 --- a/trustgraph-flow/trustgraph/api/gateway/embeddings.py +++ b/trustgraph-flow/trustgraph/api/gateway/embeddings.py @@ -6,7 +6,7 @@ from . endpoint import ServiceEndpoint class EmbeddingsEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(EmbeddingsEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=EmbeddingsResponse, endpoint_path="/api/v1/embeddings", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py b/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py index e379d7d4..c6041cb2 100644 --- a/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py +++ b/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py @@ -6,7 +6,7 @@ from . endpoint import ServiceEndpoint class EncyclopediaEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(EncyclopediaEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=LookupResponse, endpoint_path="/api/v1/encyclopedia", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/endpoint.py b/trustgraph-flow/trustgraph/api/gateway/endpoint.py index 075e4a0e..af7a5070 100644 --- a/trustgraph-flow/trustgraph/api/gateway/endpoint.py +++ b/trustgraph-flow/trustgraph/api/gateway/endpoint.py @@ -19,6 +19,7 @@ def __init__( request_queue, request_schema, response_queue, response_schema, endpoint_path, + auth, subscription="api-gateway", consumer_name="api-gateway", timeout=600, ): @@ -36,6 +37,9 @@ def __init__( self.path = endpoint_path self.timeout = timeout + self.auth = auth + + self.operation = "service" async def start(self): @@ -58,14 +62,24 @@ async def handle(self, request): id = str(uuid.uuid4()) + try: + ht = request.headers["Authorization"] + tokens = ht.split(" ", 2) + if tokens[0] != "Bearer": + return web.HTTPUnauthorized() + token = tokens[1] + except: + token = "" + + if not self.auth.permitted(token, self.operation): + return web.HTTPUnauthorized() + try: data = await request.json() q = await self.sub.subscribe(id) - print(data) - await self.pub.send( id, self.to_request(data), @@ -76,8 +90,6 @@ async def handle(self, request): except: raise RuntimeError("Timeout waiting for response") - print(resp) - if resp.error: return web.json_response( { "error": resp.error.message } @@ -110,8 +122,6 @@ async def handle(self, request): q = await self.sub.subscribe(id) - print(data) - await self.pub.send( id, self.to_request(data), @@ -126,8 +136,6 @@ async def handle(self, request): except: raise RuntimeError("Timeout waiting for response") - print(resp) - if resp.error: return web.json_response( { "error": resp.error.message } diff --git a/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py index 3cc3f533..15efdf5b 100644 --- a/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py +++ b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py @@ -14,10 +14,12 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint): - def __init__(self, pulsar_host, path="/api/v1/load/graph-embeddings"): + def __init__( + self, pulsar_host, auth, path="/api/v1/load/graph-embeddings", + ): super(GraphEmbeddingsLoadEndpoint, self).__init__( - endpoint_path=path + endpoint_path=path, auth=auth, ) self.pulsar_host=pulsar_host diff --git a/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py index 978684cf..7f3e5e18 100644 --- a/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py +++ b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py @@ -12,10 +12,12 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint): - def __init__(self, pulsar_host, path="/api/v1/stream/graph-embeddings"): + def __init__( + self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings" + ): super(GraphEmbeddingsStreamEndpoint, self).__init__( - endpoint_path=path + endpoint_path=path, auth=auth, ) self.pulsar_host=pulsar_host diff --git a/trustgraph-flow/trustgraph/api/gateway/graph_rag.py b/trustgraph-flow/trustgraph/api/gateway/graph_rag.py index 1381dc23..d33090ca 100644 --- a/trustgraph-flow/trustgraph/api/gateway/graph_rag.py +++ b/trustgraph-flow/trustgraph/api/gateway/graph_rag.py @@ -6,7 +6,7 @@ from . endpoint import ServiceEndpoint class GraphRagEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(GraphRagEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=GraphRagResponse, endpoint_path="/api/v1/graph-rag", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/internet_search.py b/trustgraph-flow/trustgraph/api/gateway/internet_search.py index c84ed82a..f55a4a3e 100644 --- a/trustgraph-flow/trustgraph/api/gateway/internet_search.py +++ b/trustgraph-flow/trustgraph/api/gateway/internet_search.py @@ -6,7 +6,7 @@ from . endpoint import ServiceEndpoint class InternetSearchEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(InternetSearchEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=LookupResponse, endpoint_path="/api/v1/internet-search", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/prompt.py b/trustgraph-flow/trustgraph/api/gateway/prompt.py index e02effb9..d19005bc 100644 --- a/trustgraph-flow/trustgraph/api/gateway/prompt.py +++ b/trustgraph-flow/trustgraph/api/gateway/prompt.py @@ -8,7 +8,7 @@ from . endpoint import ServiceEndpoint class PromptEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(PromptEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -18,6 +18,7 @@ def __init__(self, pulsar_host, timeout): response_schema=PromptResponse, endpoint_path="/api/v1/prompt", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/service.py b/trustgraph-flow/trustgraph/api/gateway/service.py index dcdd9779..a25dd9dc 100755 --- a/trustgraph-flow/trustgraph/api/gateway/service.py +++ b/trustgraph-flow/trustgraph/api/gateway/service.py @@ -45,6 +45,7 @@ from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint from . triples_load import TriplesLoadEndpoint from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint +from . auth import Authenticator logger = logging.getLogger("api") logger.setLevel(logging.INFO) @@ -52,6 +53,7 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") default_timeout = 600 default_port = 8088 +default_api_token = os.getenv("GATEWAY_SECRET", "") class Api: @@ -66,45 +68,66 @@ def __init__(self, **config): self.timeout = int(config.get("timeout", default_timeout)) self.pulsar_host = config.get("pulsar_host", default_pulsar_host) + api_token = config.get("api_token", default_api_token) + + # Token not set, or token equal empty string means no auth + if api_token: + self.auth = Authenticator(token=api_token) + else: + self.auth = Authenticator(allow_all=True) + self.endpoints = [ TextCompletionEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), PromptEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), GraphRagEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), TriplesQueryEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), EmbeddingsEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), AgentEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), EncyclopediaEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), DbpediaEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), InternetSearchEndpoint( pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, ), TriplesStreamEndpoint( - pulsar_host=self.pulsar_host + pulsar_host=self.pulsar_host, + auth = self.auth, ), GraphEmbeddingsStreamEndpoint( - pulsar_host=self.pulsar_host + pulsar_host=self.pulsar_host, + auth = self.auth, ), TriplesLoadEndpoint( - pulsar_host=self.pulsar_host + pulsar_host=self.pulsar_host, + auth = self.auth, ), GraphEmbeddingsLoadEndpoint( - pulsar_host=self.pulsar_host + pulsar_host=self.pulsar_host, + auth = self.auth, ), ] @@ -254,6 +277,12 @@ def run(): help=f'API request timeout in seconds (default: {default_timeout})', ) + parser.add_argument( + '--api-token', + default=default_api_token, + help=f'Secret API token (default: no auth)', + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, diff --git a/trustgraph-flow/trustgraph/api/gateway/socket.py b/trustgraph-flow/trustgraph/api/gateway/socket.py index 235bfd21..869792b7 100644 --- a/trustgraph-flow/trustgraph/api/gateway/socket.py +++ b/trustgraph-flow/trustgraph/api/gateway/socket.py @@ -11,11 +11,12 @@ class SocketEndpoint: def __init__( - self, - endpoint_path="/api/v1/socket", + self, endpoint_path, auth, ): self.path = endpoint_path + self.auth = auth + self.operation = "socket" async def listener(self, ws, running): @@ -43,18 +44,33 @@ async def async_thread(self, ws, running): async def handle(self, request): + try: + token = request.query['token'] + except: + token = "" + + if not self.auth.permitted(token, self.operation): + return web.HTTPUnauthorized() + running = Running() ws = web.WebSocketResponse() await ws.prepare(request) task = asyncio.create_task(self.async_thread(ws, running)) - await self.listener(ws, running) + try: - await task + await self.listener(ws, running) + + except Exception as e: + print(e, flush=True) running.stop() + await ws.close() + + await task + return ws async def start(self): diff --git a/trustgraph-flow/trustgraph/api/gateway/text_completion.py b/trustgraph-flow/trustgraph/api/gateway/text_completion.py index 04dbc9c8..d9f69b7e 100644 --- a/trustgraph-flow/trustgraph/api/gateway/text_completion.py +++ b/trustgraph-flow/trustgraph/api/gateway/text_completion.py @@ -6,7 +6,7 @@ from . endpoint import ServiceEndpoint class TextCompletionEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(TextCompletionEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -16,6 +16,7 @@ def __init__(self, pulsar_host, timeout): response_schema=TextCompletionResponse, endpoint_path="/api/v1/text-completion", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/triples_load.py b/trustgraph-flow/trustgraph/api/gateway/triples_load.py index d835a363..7f4561b1 100644 --- a/trustgraph-flow/trustgraph/api/gateway/triples_load.py +++ b/trustgraph-flow/trustgraph/api/gateway/triples_load.py @@ -14,10 +14,10 @@ class TriplesLoadEndpoint(SocketEndpoint): - def __init__(self, pulsar_host, path="/api/v1/load/triples"): + def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"): super(TriplesLoadEndpoint, self).__init__( - endpoint_path=path + endpoint_path=path, auth=auth, ) self.pulsar_host=pulsar_host diff --git a/trustgraph-flow/trustgraph/api/gateway/triples_query.py b/trustgraph-flow/trustgraph/api/gateway/triples_query.py index 8b4192d8..9c5939c8 100644 --- a/trustgraph-flow/trustgraph/api/gateway/triples_query.py +++ b/trustgraph-flow/trustgraph/api/gateway/triples_query.py @@ -7,7 +7,7 @@ from . serialize import to_value, serialize_subgraph class TriplesQueryEndpoint(ServiceEndpoint): - def __init__(self, pulsar_host, timeout): + def __init__(self, pulsar_host, timeout, auth): super(TriplesQueryEndpoint, self).__init__( pulsar_host=pulsar_host, @@ -17,6 +17,7 @@ def __init__(self, pulsar_host, timeout): response_schema=TriplesQueryResponse, endpoint_path="/api/v1/triples-query", timeout=timeout, + auth=auth, ) def to_request(self, body): diff --git a/trustgraph-flow/trustgraph/api/gateway/triples_stream.py b/trustgraph-flow/trustgraph/api/gateway/triples_stream.py index e8b538a4..6ecd2bdb 100644 --- a/trustgraph-flow/trustgraph/api/gateway/triples_stream.py +++ b/trustgraph-flow/trustgraph/api/gateway/triples_stream.py @@ -12,10 +12,10 @@ class TriplesStreamEndpoint(SocketEndpoint): - def __init__(self, pulsar_host, path="/api/v1/stream/triples"): + def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"): super(TriplesStreamEndpoint, self).__init__( - endpoint_path=path + endpoint_path=path, auth=auth, ) self.pulsar_host=pulsar_host