From fad9672eeac0dbdc150a8f358be7420bf6b89a5f Mon Sep 17 00:00:00 2001 From: Di Wang Date: Sat, 17 Feb 2024 17:44:10 +1100 Subject: [PATCH] Add slack toolkit Signed-off-by: Di Wang --- app/chains/supervisor.py | 9 ++++++--- app/graph.py | 18 ++++++++++++------ app/server.py | 4 ++-- app/tools/slack_toolkit.py | 3 +++ 4 files changed, 23 insertions(+), 11 deletions(-) create mode 100644 app/tools/slack_toolkit.py diff --git a/app/chains/supervisor.py b/app/chains/supervisor.py index 2e08f84..a649e38 100644 --- a/app/chains/supervisor.py +++ b/app/chains/supervisor.py @@ -1,3 +1,6 @@ +import json +from typing import Mapping + from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder @@ -14,8 +17,8 @@ ) -def build_supervisor_chain(members): - options = ["FINISH"] + members +def build_supervisor_chain(members: Mapping[str, str]): + options = ["FINISH"] + list(members.keys()) function_def = { "name": "route", "description": "Select the next role.", @@ -43,7 +46,7 @@ def build_supervisor_chain(members): " Or should we FINISH? Select one of: {options}", ), ] - ).partial(options=str(options), members=", ".join(members)) + ).partial(options=str(options), members=json.dumps(members)) return ( prompt diff --git a/app/graph.py b/app/graph.py index 0f581ee..a23f0cf 100644 --- a/app/graph.py +++ b/app/graph.py @@ -1,5 +1,6 @@ import functools import operator +from collections.abc import Mapping from typing import Annotated, Sequence, TypedDict from langchain.agents import create_openai_tools_agent, AgentExecutor @@ -16,6 +17,7 @@ from app.tools.datetime_provider import datetime_provider from app.tools.random_number import random_number from app.tools.random_select import random_select +from app.tools.slack_toolkit import slack_toolkit from app.tools.webrca_create import webrca_create from app.tools.duckduckgo_search import duckduckgo_search from app.tools.slack_searcher import slack_searcher @@ -68,6 +70,10 @@ class AgentState(TypedDict): "tools": [duckduckgo_search], "system_prompt": "You are a search engine for generic questions.", }, + "SlackToolkit": { + "tools": slack_toolkit.get_tools(), + "system_prompt": "You are a slack toolkit.", + }, "SlackSearcher": { "tools": [slack_searcher], "system_prompt": "You are a slack searcher.", @@ -75,13 +81,14 @@ class AgentState(TypedDict): "DatetimeProvider": { "tools": [datetime_provider], "system_prompt": "You are a datetime provider.", - } + }, } +SUPERVISOR_MEMBERS = {k: v["system_prompt"] for k, v in GRAPH.items()} + def build_graph() -> Pregel: - members = list(GRAPH.keys()) - supervisor_chain = build_supervisor_chain(members) + supervisor_chain = build_supervisor_chain(SUPERVISOR_MEMBERS) workflow = StateGraph(AgentState) for member, config in GRAPH.items(): @@ -89,11 +96,10 @@ def build_graph() -> Pregel: workflow.add_node(member, functools.partial(agent_node, agent=agent, name=member)) workflow.add_node(SUPERVISOR_NAME, supervisor_chain) - for member in members: + for member in GRAPH: workflow.add_edge(member, SUPERVISOR_NAME) - conditional_map = {k: k for k in members} - conditional_map["FINISH"] = END + conditional_map = {k: k for k in GRAPH} | {"FINISH": END} workflow.add_conditional_edges(SUPERVISOR_NAME, lambda x: x["next"], conditional_map) workflow.set_entry_point(SUPERVISOR_NAME) diff --git a/app/server.py b/app/server.py index 329abfe..a38c7af 100644 --- a/app/server.py +++ b/app/server.py @@ -8,7 +8,7 @@ from app.chains.supervisor import build_supervisor_chain from app.agents.webrca_create import webrca_create_agent_executor from app.dependencies.ollama_chat_model import ollama_chat_model -from app.graph import graph +from app.graph import graph, SUPERVISOR_MEMBERS from app.routers import slack app = FastAPI() @@ -51,7 +51,7 @@ async def redirect_root_to_docs(): add_routes( app, - build_supervisor_chain(["SlackSummarizer", "SlackSearcher", "WebRCA"]), + build_supervisor_chain(SUPERVISOR_MEMBERS), path="/supervisor", ) diff --git a/app/tools/slack_toolkit.py b/app/tools/slack_toolkit.py new file mode 100644 index 0000000..8454c31 --- /dev/null +++ b/app/tools/slack_toolkit.py @@ -0,0 +1,3 @@ +from langchain_community.agent_toolkits import SlackToolkit + +slack_toolkit = SlackToolkit()