Skip to content

Commit

Permalink
Merge pull request #307 from ag2ai/Tool_executor
Browse files Browse the repository at this point in the history
[CaptainAgent] Add executor injected with tools.
  • Loading branch information
LeoLjl authored Dec 28, 2024
2 parents 6cc0a81 + a1a1a22 commit 57f229d
Showing 1 changed file with 132 additions and 0 deletions.
132 changes: 132 additions & 0 deletions autogen/agentchat/contrib/tool_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import contextlib
import functools
import importlib.util
import inspect
import io
import os
import traceback
from hashlib import md5
from pathlib import Path
from textwrap import dedent, indent
from typing import List, Optional, Union

import pandas as pd
from sentence_transformers import SentenceTransformer, util

from autogen import AssistantAgent, UserProxyAgent
from autogen.coding import LocalCommandLineCodeExecutor
from autogen.coding.base import CodeBlock, CodeResult
from autogen.function_utils import load_basemodels_if_needed
from autogen.tools import Tool


class ToolBuilder:
Expand Down Expand Up @@ -77,6 +87,108 @@ def bind_user_proxy(self, agent: UserProxyAgent, tool_root: str):
return updated_user_proxy


class LocalExecutorWithTools:
"""
An executor that executes code blocks with injected tools. In this executor, the func within the tools can be called directly without declaring in the code block.
For example, for a tool converted from langchain, the relevant functions can be called directly.
```python
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from autogen.interop import Interoperability
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=3000)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
interop = Interoperability()
ag2_tool = interop.convert_tool(tool=langchain_tool, type="langchain")
# `ag2_tool.name` is wikipedia
local_executor = LocalExecutorWithTools(tools=[ag2_tool], work_dir='./')
code = '''
result = wikipedia(tool_input={"query":"Christmas"})
print(result)
'''
print(
local_executor.execute_code_blocks(
code_blocks=[
CodeBlock(language="python", code=code),
]
)
)
```
In this case, the `wikipedia` function can be called directly in the code block. This hides the complexity of the tool.
Args:
tools: The tools to inject into the code execution environment. Default is an empty list.
work_dir: The working directory for the code execution. Default is the current directory.
"""

def __init__(self, tools: Optional[List[Tool]] = None, work_dir: Union[Path, str] = Path(".")):
self.tools = tools if tools is not None else []
self.work_dir = work_dir
if not os.path.exists(work_dir):
os.makedirs(work_dir, exist_ok=True)

def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
"""Execute code blocks and return the result.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
Returns:
CodeResult: The result of the code execution.
"""
logs_all = ""
exit_code = 0 # Success code
code_file = None # Path to the first saved codeblock content

for idx, code_block in enumerate(code_blocks):
code = code_block.code
code_hash = md5(code.encode()).hexdigest()
filename = f"tmp_code_{code_hash}.py"
filepath = os.path.join(self.work_dir, filename)
# Save code content to file
with open(filepath, "w", encoding="utf-8") as f:
f.write(code)

if idx == 0:
code_file = filepath

# Create a new execution environment
execution_env = {}
# Inject the tools
for tool in self.tools:
execution_env[tool.name] = _wrap_function(tool.func)

# Prepare to capture stdout and stderr
stdout = io.StringIO()
stderr = io.StringIO()

# Execute the code block
try:
# Redirect stdout and stderr
with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
# Exec the code in the execution environment
exec(code, execution_env)
except Exception:
# Capture exception traceback
tb = traceback.format_exc()
stderr.write(tb)
exit_code = 1 # Non-zero exit code indicates failure

# Collect outputs
stdout_content = stdout.getvalue()
stderr_content = stderr.getvalue()
logs_all += stdout_content + stderr_content

return CodeResult(exit_code=exit_code, output=logs_all, code_file=code_file)

def restart(self):
"""Restart the code executor. Since this executor is stateless, no action is needed."""
pass


def get_full_tool_description(py_file):
"""
Retrieves the function signature for a given Python file.
Expand All @@ -100,6 +212,26 @@ def get_full_tool_description(py_file):
raise ValueError(f"Function {function_name} not found in {py_file}")


def _wrap_function(func):
"""Wrap the function to dump the return value to json.
Handles both sync and async functions.
Args:
func: the function to be wrapped.
Returns:
The wrapped function.
"""

@load_basemodels_if_needed
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
return func(*args, **kwargs)

return _wrapped_func


def find_callables(directory):
"""
Find all callable objects defined in Python files within the specified directory.
Expand Down

0 comments on commit 57f229d

Please sign in to comment.