Skip to content

Commit

Permalink
fix: tool choice run same tool will error (#3502)
Browse files Browse the repository at this point in the history
  • Loading branch information
c121914yu authored Dec 31, 2024
1 parent b2fdefd commit b75e807
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 60 deletions.
143 changes: 86 additions & 57 deletions packages/service/core/workflow/dispatch/agent/runTool/toolChoice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ import { getNanoid, sliceStrStartEnd } from '@fastgpt/global/common/string/tools
import { toolValueTypeList } from '@fastgpt/global/core/workflow/constants';
import { WorkflowInteractiveResponseType } from '@fastgpt/global/core/workflow/template/system/interactive/type';
import { ChatItemValueTypeEnum } from '@fastgpt/global/core/chat/constants';
import { getErrText } from '@fastgpt/global/common/error/utils';

type ToolRunResponseType = {
toolRunResponse: DispatchFlowResponse;
toolRunResponse?: DispatchFlowResponse;
toolMsgParams: ChatCompletionToolMessageParam;
}[];

Expand Down Expand Up @@ -344,59 +345,87 @@ export const runToolWithToolChoice = async (
return Promise.reject(getEmptyResponseTip());
}

// Run the selected tool by LLM.
const toolsRunResponse = (
await Promise.all(
toolCalls.map(async (tool) => {
const toolNode = toolNodes.find((item) => item.nodeId === tool.function?.name);
/* Run the selected tool by LLM.
Since only reference parameters are passed, if the same tool is run in parallel, it will get the same run parameters
*/
const toolsRunResponse: ToolRunResponseType = [];
for await (const tool of toolCalls) {
try {
const toolNode = toolNodes.find((item) => item.nodeId === tool.function?.name);

if (!toolNode) continue;

const startParams = (() => {
try {
return json5.parse(tool.function.arguments);
} catch (error) {
return {};
}
})();

if (!toolNode) return;
initToolNodes(runtimeNodes, [toolNode.nodeId], startParams);
const toolRunResponse = await dispatchWorkFlow({
...workflowProps,
isToolCall: true
});

const startParams = (() => {
try {
return json5.parse(tool.function.arguments);
} catch (error) {
return {};
}
})();
const stringToolResponse = formatToolResponse(toolRunResponse.toolResponses);

initToolNodes(runtimeNodes, [toolNode.nodeId], startParams);
const toolRunResponse = await dispatchWorkFlow({
...workflowProps,
isToolCall: true
});
const toolMsgParams: ChatCompletionToolMessageParam = {
tool_call_id: tool.id,
role: ChatCompletionRequestMessageRoleEnum.Tool,
name: tool.function.name,
content: stringToolResponse
};

const stringToolResponse = formatToolResponse(toolRunResponse.toolResponses);
workflowStreamResponse?.({
event: SseResponseEventEnum.toolResponse,
data: {
tool: {
id: tool.id,
toolName: '',
toolAvatar: '',
params: '',
response: sliceStrStartEnd(stringToolResponse, 5000, 5000)
}
}
});

const toolMsgParams: ChatCompletionToolMessageParam = {
toolsRunResponse.push({
toolRunResponse,
toolMsgParams
});
} catch (error) {
const err = getErrText(error);
workflowStreamResponse?.({
event: SseResponseEventEnum.toolResponse,
data: {
tool: {
id: tool.id,
toolName: '',
toolAvatar: '',
params: '',
response: sliceStrStartEnd(err, 5000, 5000)
}
}
});

toolsRunResponse.push({
toolRunResponse: undefined,
toolMsgParams: {
tool_call_id: tool.id,
role: ChatCompletionRequestMessageRoleEnum.Tool,
name: tool.function.name,
content: stringToolResponse
};

workflowStreamResponse?.({
event: SseResponseEventEnum.toolResponse,
data: {
tool: {
id: tool.id,
toolName: '',
toolAvatar: '',
params: '',
response: sliceStrStartEnd(stringToolResponse, 5000, 5000)
}
}
});

return {
toolRunResponse,
toolMsgParams
};
})
)
).filter(Boolean) as ToolRunResponseType;
content: sliceStrStartEnd(err, 5000, 5000)
}
});
}
}

const flatToolsResponseData = toolsRunResponse.map((item) => item.toolRunResponse).flat();
const flatToolsResponseData = toolsRunResponse
.map((item) => item.toolRunResponse)
.flat()
.filter(Boolean) as DispatchFlowResponse[];
// concat tool responses
const dispatchFlowResponse = response
? response.dispatchFlowResponse.concat(flatToolsResponseData)
Expand Down Expand Up @@ -434,22 +463,22 @@ export const runToolWithToolChoice = async (
const outputTokens = await countGptMessagesTokens(assistantToolMsgParams);

/*
...
user
assistant: tool data
tool: tool response
*/
...
user
assistant: tool data
tool: tool response
*/
const completeMessages = [
...concatToolMessages,
...toolsRunResponse.map((item) => item?.toolMsgParams)
];

/*
Get tool node assistant response
history assistant
current tool assistant
tool child assistant
*/
Get tool node assistant response
history assistant
current tool assistant
tool child assistant
*/
const toolNodeAssistant = GPTMessages2Chats([
...assistantToolMsgParams,
...toolsRunResponse.map((item) => item?.toolMsgParams)
Expand Down Expand Up @@ -478,12 +507,12 @@ export const runToolWithToolChoice = async (
);
// Check interactive response(Only 1 interaction is reserved)
const workflowInteractiveResponseItem = toolsRunResponse.find(
(item) => item.toolRunResponse.workflowInteractiveResponse
(item) => item.toolRunResponse?.workflowInteractiveResponse
);
if (hasStopSignal || workflowInteractiveResponseItem) {
// Get interactive tool data
const workflowInteractiveResponse =
workflowInteractiveResponseItem?.toolRunResponse.workflowInteractiveResponse;
workflowInteractiveResponseItem?.toolRunResponse?.workflowInteractiveResponse;

// Flashback traverses completeMessages, intercepting messages that know the first user
const firstUserIndex = completeMessages.findLastIndex((item) => item.role === 'user');
Expand Down
7 changes: 4 additions & 3 deletions packages/service/core/workflow/dispatch/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import { dispatchLoopEnd } from './loop/runLoopEnd';
import { dispatchLoopStart } from './loop/runLoopStart';
import { dispatchFormInput } from './interactive/formInput';
import { dispatchToolParams } from './agent/runTool/toolParams';
import { getErrText } from '@fastgpt/global/common/error/utils';

const callbackMap: Record<FlowNodeTypeEnum, Function> = {
[FlowNodeTypeEnum.workflowStart]: dispatchWorkflowStart,
Expand Down Expand Up @@ -231,9 +232,7 @@ export async function dispatchWorkFlow(data: Props): Promise<DispatchFlowRespons

if (toolResponses !== undefined) {
if (Array.isArray(toolResponses) && toolResponses.length === 0) return;
if (typeof toolResponses === 'object' && Object.keys(toolResponses).length === 0) {
return;
}
if (typeof toolResponses === 'object' && Object.keys(toolResponses).length === 0) return;
toolRunResponse = toolResponses;
}

Expand Down Expand Up @@ -565,6 +564,8 @@ export async function dispatchWorkFlow(data: Props): Promise<DispatchFlowRespons
const targetEdges = runtimeEdges.filter((item) => item.source === node.nodeId);
const skipHandleIds = targetEdges.map((item) => item.sourceHandle);

toolRunResponse = getErrText(error);

// Skip all edges and return error
return {
[DispatchNodeResponseKeyEnum.nodeResponse]: {
Expand Down

0 comments on commit b75e807

Please sign in to comment.