From 4d9fb9ced45107272f017b59561f06e27072f9f2 Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Wed, 4 Dec 2024 19:52:25 +0100 Subject: [PATCH 1/2] Add the `add_tool()`, `remove_tool()` and `remove_all_tools()` methods for `AssistantAgent` --- .../agents/_assistant_agent.py | 207 +++++++++++++++--- .../tests/test_assistant_agent.py | 75 +++++++ 2 files changed, 255 insertions(+), 27 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index ef9ecb2a00c..eadf69a14c8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -14,6 +14,7 @@ ) from autogen_core import CancellationToken, FunctionCall +from autogen_core.components.models import LLMMessage from autogen_core.model_context import ( ChatCompletionContext, UnboundedChatCompletionContext, @@ -246,24 +247,11 @@ def __init__( else: self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] - if tools is not None: - if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling.") - for tool in tools: - if isinstance(tool, Tool): - self._tools.append(tool) - elif callable(tool): - if hasattr(tool, "__doc__") and tool.__doc__ is not None: - description = tool.__doc__ - else: - description = "" - self._tools.append(FunctionTool(tool, description=description)) - else: - raise ValueError(f"Unsupported tool type: {type(tool)}") - # Check if tool names are unique. - tool_names = [tool.name for tool in self._tools] - if len(tool_names) != len(set(tool_names)): - raise ValueError(f"Tool names must be unique: {tool_names}") + self._model_context: List[LLMMessage] = [] + self._reflect_on_tool_use = reflect_on_tool_use + self._tool_call_summary_format = tool_call_summary_format + self._is_running = False + # Handoff tools. self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, HandoffBase] = {} @@ -273,26 +261,191 @@ def __init__( for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) + if handoff.name in self._handoffs: + raise ValueError(f"Handoff name {handoff.name} already exists.") if isinstance(handoff, HandoffBase): self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: raise ValueError(f"Unsupported handoff type: {type(handoff)}") - # Check if handoff tool names are unique. - handoff_tool_names = [tool.name for tool in self._handoff_tools] - if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") - # Check if handoff tool names not in tool names. - if any(name in tool_names for name in handoff_tool_names): - raise ValueError( - f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" - ) + if tools is not None: + for tool in tools: + self.add_tool(tool) + if not model_context: self._model_context = UnboundedChatCompletionContext() self._reflect_on_tool_use = reflect_on_tool_use self._tool_call_summary_format = tool_call_summary_format self._is_running = False + def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None: + """ + Adds a new tool to the assistant agent. + + The tool can be either an instance of the `Tool` class, or a callable function. If the tool is a callable + function, a `FunctionTool` instance will be created with the function and its docstring as the description. + + The tool name must be unique among all the tools and handoffs added to the agent. If the model does not support + function calling, an error will be raised. + + Args: + tool (Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]): The tool to add. + + Raises: + ValueError: If the tool name is not unique. + ValueError: If the tool name is already used by a handoff. + ValueError: If the tool has an unsupported type. + ValueError: If the model does not support function calling. + + Examples: + .. code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + agent.add_tool(get_current_time) + + await Console( + agent.on_messages_stream( + [TextMessage(content="What is the current time?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + """ + new_tool = None + if self._model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") + if isinstance(tool, Tool): + new_tool = tool + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + new_tool = FunctionTool(tool, description=description) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + # Check if tool names are unique. + if any(tool.name == new_tool.name for tool in self._tools): + raise ValueError(f"Tool names must be unique: {new_tool.name}") + # Check if handoff tool names not in tool names. + handoff_tool_names = [handoff.name for handoff in self._handoffs.values()] + if new_tool.name in handoff_tool_names: + raise ValueError( + f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; " + f"tool names: {new_tool.name}" + ) + self._tools.append(new_tool) + + def remove_all_tools(self) -> None: + """ + Remove all tools. + + Examples: + .. code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + agent.add_tool(get_current_time) + agent.remove_all_tools() + + await Console( + agent.on_messages_stream( + [TextMessage(content="What is the current time?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + + """ + self._tools.clear() + + def remove_tool(self, tool_name: str) -> None: + """ + Remove a tool by name. + + Args: + tool_name (str): The name of the tool to remove. + + Raises: + ValueError: If the tool name is not found. + + Examples: + .. code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + agent.add_tool(get_current_time) + agent.remove_tool("get_current_time") + + await Console( + agent.on_messages_stream( + [TextMessage(content="What is the current time?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + """ + for tool in self._tools: + if tool.name == tool_name: + self._tools.remove(tool) + return + raise ValueError(f"Tool {tool_name} not found.") + @property def produced_message_types(self) -> List[type[ChatMessage]]: """The types of messages that the assistant agent produces.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 9065d513918..bd33702a6f4 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -468,3 +468,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: else: assert message == result.messages[index] index += 1 + + +def test_tool_management(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test function to be used as a tool + def sample_tool() -> str: + return "sample result" + + # Test adding a tool + tool = FunctionTool(sample_tool, description="Sample tool") + agent.add_tool(tool) + assert len(agent._tools) == 1 + + # Test adding duplicate tool + with pytest.raises(ValueError, match="Tool names must be unique"): + agent.add_tool(tool) + + # Test tool collision with handoff + agent_with_handoff = AssistantAgent( + name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")] + ) + + conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool") + with pytest.raises(ValueError, match="Handoff names must be unique from tool names"): + agent_with_handoff.add_tool(conflicting_tool) + + # Test removing a tool + agent.remove_tool(tool.name) + assert len(agent._tools) == 0 + + # Test removing non-existent tool + with pytest.raises(ValueError, match="Tool non_existent_tool not found"): + agent.remove_tool("non_existent_tool") + + # Test removing all tools + agent.add_tool(tool) + assert len(agent._tools) == 1 + agent.remove_all_tools() + assert len(agent._tools) == 0 + + # Test idempotency of remove_all_tools + agent.remove_all_tools() + assert len(agent._tools) == 0 + + +def test_callable_tool_addition(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test adding a callable directly + def documented_tool() -> str: + """This is a documented tool""" + return "result" + + agent.add_tool(documented_tool) + assert len(agent._tools) == 1 + assert agent._tools[0].description == "This is a documented tool" + + # Test adding async callable + async def async_tool() -> str: + return "async result" + + agent.add_tool(async_tool) + assert len(agent._tools) == 2 + + +def test_invalid_tool_addition(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test adding invalid tool type + with pytest.raises(ValueError, match="Unsupported tool type"): + agent.add_tool("not a tool") From aaa70d23a5c2bbf7458c8407ed7e7af2a362bcaf Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 27 Dec 2024 10:43:55 -0800 Subject: [PATCH 2/2] Update python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py --- .../src/autogen_agentchat/agents/_assistant_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index eadf69a14c8..f4c0510c0da 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -283,7 +283,7 @@ def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any Adds a new tool to the assistant agent. The tool can be either an instance of the `Tool` class, or a callable function. If the tool is a callable - function, a `FunctionTool` instance will be created with the function and its docstring as the description. + function, a :class:`~autogen_core.tools.FunctionTool` instance will be created with the function and its docstring as the description. The tool name must be unique among all the tools and handoffs added to the agent. If the model does not support function calling, an error will be raised.