-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.ts
135 lines (102 loc) · 4.09 KB
/
agent.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import { OpenAIEmbeddings } from "@langchain/openai";
import { ChatAnthropic } from "@langchain/anthropic";
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { StateGraph } from "@langchain/langgraph";
import { Annotation } from "@langchain/langgraph";
import { tool } from "@langchain/core/tools";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { MongoDBSaver } from "@langchain/langgraph-checkpoint-mongodb";
import { MongoDBAtlasVectorSearch } from "@langchain/mongodb";
import { MongoClient } from "mongodb";
import { z } from "zod";
import "dotenv/config";
/**
 * Runs one turn of the HR chatbot agent for a conversation thread.
 *
 * Builds a two-node LangGraph workflow ("agent" <-> "tools") around a Claude model
 * that may call an `employee_lookup` vector-search tool over the
 * `hr_database.employees` collection, then invokes it with `query` as a new human
 * message. Graph state is checkpointed to the same MongoDB database via
 * `MongoDBSaver`, keyed by `thread_id`.
 *
 * @param client - Connected MongoDB client; used for both the vector search and the checkpointer.
 * @param query - The user's message for this turn.
 * @param thread_id - Conversation identifier passed to the checkpointer as `configurable.thread_id`.
 * @returns The `content` of the last message in the final graph state, i.e. the model's reply
 *   (a plain string or, presumably, structured content blocks depending on the model output).
 * @throws Whatever `app.invoke` throws — presumably a recursion-limit error if the
 *   agent/tool loop exceeds 15 graph steps — or if the run produced no messages.
 */
export async function callAgent(client: MongoClient, query: string, thread_id: string) {
  const dbName = "hr_database";
  const db = client.db(dbName);
  const collection = db.collection("employees");

  // Conversation state: messages returned by each node are appended to the history.
  const GraphState = Annotation.Root({
    messages: Annotation<BaseMessage[]>({
      reducer: (x, y) => x.concat(y),
    }),
  });

  // Built once per call instead of on every tool invocation (same eager construction as the model).
  const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
    collection,
    indexName: "vector_index",
    textKey: "embedding_text",
    embeddingKey: "embedding",
  });

  const employeeLookupTool = tool(
    async ({ query: searchQuery, n = 10 }) => {
      console.log("Employee lookup tool called");
      // Returns [document, score] pairs serialized as JSON for the model to read.
      const result = await vectorStore.similaritySearchWithScore(searchQuery, n);
      return JSON.stringify(result);
    },
    {
      name: "employee_lookup",
      description: "Gathers employee details from the HR database",
      schema: z.object({
        query: z.string().describe("The search query"),
        // The model supplies this value; reject 0, negative, or fractional result counts.
        n: z
          .number()
          .int()
          .min(1)
          .optional()
          .default(10)
          .describe("Number of results to return"),
      }),
    }
  );

  const tools = [employeeLookupTool];
  const toolNode = new ToolNode<typeof GraphState.State>(tools);
  const model = new ChatAnthropic({
    model: "claude-3-5-sonnet-20240620",
    temperature: 0,
  }).bindTools(tools);

  // Routes to the tool node while the model's latest message requests tool calls; otherwise ends the run.
  function shouldContinue(state: typeof GraphState.State): "tools" | "__end__" {
    const latest = state.messages[state.messages.length - 1] as AIMessage | undefined;
    return latest?.tool_calls?.length ? "tools" : "__end__";
  }

  // The template and tool list do not change between turns, so build them once.
  // NOTE(review): this system prompt comes from a multi-agent template ("collaborating with
  // other assistants", "prefix your response with FINAL ANSWER"). In this single-agent graph,
  // a "FINAL ANSWER" prefix, if the model emits one, is returned to the caller verbatim.
  const prompt = ChatPromptTemplate.fromMessages([
    [
      "system",
      `You are a helpful AI assistant, collaborating with other assistants. Use the provided tools to progress towards answering the question. If you are unable to fully answer, that's OK, another assistant with different tools will help where you left off. Execute what you can to make progress. If you or any of the other assistants have the final answer or deliverable, prefix your response with FINAL ANSWER so the team knows to stop. You have access to the following tools: {tool_names}.\n{system_message}\nCurrent time: {time}.`,
    ],
    new MessagesPlaceholder("messages"),
  ]);
  const toolNames = tools.map((t) => t.name).join(", ");

  // Agent node: wrap the conversation so far in the system prompt and ask the model for the next step.
  async function callModel(state: typeof GraphState.State) {
    const formattedPrompt = await prompt.formatMessages({
      system_message: "You are helpful HR Chatbot Agent.",
      time: new Date().toISOString(),
      tool_names: toolNames,
      messages: state.messages,
    });
    const result = await model.invoke(formattedPrompt);
    return { messages: [result] };
  }

  const workflow = new StateGraph(GraphState)
    .addNode("agent", callModel)
    .addNode("tools", toolNode)
    .addEdge("__start__", "agent")
    .addConditionalEdges("agent", shouldContinue)
    .addEdge("tools", "agent");

  // Checkpoints live in the same database and are keyed by thread_id via `configurable`.
  const checkpointer = new MongoDBSaver({ client, dbName });
  const app = workflow.compile({ checkpointer });

  const finalState = await app.invoke(
    { messages: [new HumanMessage(query)] },
    // recursionLimit caps the number of graph steps so an agent/tool loop cannot run forever.
    { recursionLimit: 15, configurable: { thread_id } }
  );

  const reply = finalState.messages[finalState.messages.length - 1];
  if (!reply) {
    throw new Error("Agent run finished without producing any messages");
  }
  console.log(reply.content);
  return reply.content;
}