From 0ddd5460177f5560df18a347a692a2a424993d0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Wed, 7 Aug 2024 13:17:52 +0200 Subject: [PATCH] fix: added auto-register of tools' block types, changed block args to list[str], forked append out of save tool, refactored/improved tmux terminal tool --- gptme/commands.py | 5 +- gptme/tools/__init__.py | 75 +++++------------- gptme/tools/base.py | 2 +- gptme/tools/patch.py | 7 +- gptme/tools/save.py | 82 ++++++++++++++----- gptme/tools/shell.py | 7 +- gptme/tools/terminal.py | 172 ++++++++++++++++++---------------------- 7 files changed, 170 insertions(+), 180 deletions(-) diff --git a/gptme/commands.py b/gptme/commands.py index 9f079dc5..36408556 100644 --- a/gptme/commands.py +++ b/gptme/commands.py @@ -84,10 +84,11 @@ def handle_cmd( name, *args = re.split(r"[\n\s]", cmd) full_args = cmd.split(" ", 1)[1] if " " in cmd else "" match name: + # TODO: rewrite to auto-register tools using block_types case "bash" | "sh" | "shell": - yield from execute_shell(full_args, ask=not no_confirm) + yield from execute_shell(full_args, ask=not no_confirm, args=[]) case "python" | "py": - yield from execute_python(full_args, ask=not no_confirm) + yield from execute_python(full_args, ask=not no_confirm, args=[]) case "log": log.undo(1, quiet=True) log.print(show_hidden="--hidden" in args) diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py index 391f1511..3e71e125 100644 --- a/gptme/tools/__init__.py +++ b/gptme/tools/__init__.py @@ -7,12 +7,10 @@ from .base import ToolSpec from .browser import has_browser_tool from .browser import tool as browser_tool -from .patch import execute_patch from .patch import tool as patch_tool from .python import execute_python, register_function from .python import tool as python_tool -from .save import execute_save -from .save import tool as save_tool +from .save import execute_save, tool_append, tool_save from .shell import execute_shell from .shell import tool as shell_tool from .subagent import tool as subagent_tool @@ -35,7 +33,8 @@ all_tools: list[ToolSpec] = [ - save_tool, + tool_save, + tool_append, patch_tool, python_tool, shell_tool, @@ -48,7 +47,7 @@ @dataclass class ToolUse: tool: str - args: dict[str, str] + args: list[str] content: str def execute(self, ask: bool) -> Generator[Message, None, None]: @@ -88,7 +87,6 @@ def execute_msg(msg: Message, ask: bool) -> Generator[Message, None, None]: # get all markdown code blocks for codeblock in get_codeblocks(msg.content): try: - # yield from execute_codeblock(codeblock, ask) if is_supported_codeblock(codeblock): yield from codeblock_to_tooluse(codeblock).execute(ask) else: @@ -110,22 +108,10 @@ def codeblock_to_tooluse(codeblock: str) -> ToolUse: """Parses a codeblock into a ToolUse. Codeblock must be a supported type.""" lang_or_fn = codeblock.splitlines()[0].strip() codeblock_content = codeblock[len(lang_or_fn) :] - - # the first word is the command, the rest are arguments - # if the first word contains a dot or slash, it is a filename - cmd = lang_or_fn.split(" ")[0] - is_filename = "." in cmd or "/" in cmd - if tool := get_tool_for_codeblock(lang_or_fn): - return ToolUse(tool.name, {}, codeblock_content) - elif lang_or_fn.startswith("patch "): - fn = lang_or_fn[len("patch ") :] - return ToolUse("patch", {"file": fn}, codeblock_content) - elif lang_or_fn.startswith("append "): - fn = lang_or_fn[len("append ") :] - return ToolUse("save", {"file": fn, "append": "true"}, codeblock_content) - elif is_filename: - return ToolUse("save", {"file": lang_or_fn}, codeblock_content) + # NOTE: special case + args = lang_or_fn.split(" ")[1:] if tool.name != "save" else [lang_or_fn] + return ToolUse(tool.name, args, codeblock_content) else: assert not is_supported_codeblock(codeblock) raise ValueError( @@ -136,31 +122,12 @@ def codeblock_to_tooluse(codeblock: str) -> ToolUse: def execute_codeblock(codeblock: str, ask: bool) -> Generator[Message, None, None]: """Executes a codeblock and returns the output.""" lang_or_fn = codeblock.splitlines()[0].strip() - codeblock_content = codeblock[len(lang_or_fn) :] - - # the first word is the command, the rest are arguments - # if the first word contains a dot or slash, it is a filename - cmd = lang_or_fn.split(" ")[0] - is_filename = "." in cmd or "/" in cmd - if tool := get_tool_for_codeblock(lang_or_fn): - assert tool.execute - yield from tool.execute(codeblock_content, ask=ask, args={}) - elif lang_or_fn.startswith("patch "): - fn = lang_or_fn[len("patch ") :] - yield from execute_patch(f"```{codeblock}```", ask, {"file": fn}) - elif lang_or_fn.startswith("append "): - fn = lang_or_fn[len("append ") :] - yield from execute_save( - codeblock_content, ask, args={"file": fn, "append": "true"} - ) - elif is_filename: - yield from execute_save(codeblock_content, ask, args={"file": lang_or_fn}) - else: - assert not is_supported_codeblock(codeblock) - logger.debug( - f"Unknown codeblock type '{lang_or_fn}', neither supported language or filename." - ) + if tool.execute: + args = lang_or_fn.split(" ")[1:] + yield from tool.execute(codeblock, ask, args) + assert not is_supported_codeblock(codeblock) + logger.debug("Unknown codeblock, neither supported language or filename.") # TODO: use this instead of passing around codeblocks as strings (with or without ```) @@ -218,23 +185,20 @@ def is_supported_codeblock(codeblock: str) -> bool: def get_tool_for_codeblock(lang_or_fn: str) -> ToolSpec | None: + block_type = lang_or_fn.split(" ")[0] for tool in loaded_tools: - if lang_or_fn in tool.block_types: + if block_type in tool.block_types: return tool + is_filename = "." in lang_or_fn or "/" in lang_or_fn + if is_filename: + # NOTE: special case + return tool_save return None def is_supported_codeblock_tool(lang_or_fn: str) -> bool: - is_filename = "." in lang_or_fn or "/" in lang_or_fn - if get_tool_for_codeblock(lang_or_fn): return True - elif lang_or_fn.startswith("patch "): - return True - elif lang_or_fn.startswith("append "): - return True - elif is_filename: - return True else: return False @@ -262,7 +226,8 @@ def get_tooluse_xml(content: str) -> Generator[ToolUse, None, None]: root = ElementTree.fromstring(content) for tooluse in root.findall("tool-use"): for child in tooluse: - yield ToolUse(tooluse.tag, child.attrib, child.text or "") + # TODO: this child.attrib.values() thing wont really work + yield ToolUse(tooluse.tag, list(child.attrib.values()), child.text or "") def get_tool(tool_name: str) -> ToolSpec: diff --git a/gptme/tools/base.py b/gptme/tools/base.py index 99b3513b..930a4524 100644 --- a/gptme/tools/base.py +++ b/gptme/tools/base.py @@ -9,7 +9,7 @@ class ExecuteFunc(Protocol): def __call__( - self, code: str, ask: bool, args: None | dict[str, str] = None + self, code: str, ask: bool, args: list[str] ) -> Generator[Message, None, None]: ... diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py index 6212d07a..8df27f83 100644 --- a/gptme/tools/patch.py +++ b/gptme/tools/patch.py @@ -108,12 +108,12 @@ def apply_file(codeblock, filename): def execute_patch( - codeblock: str, ask: bool, args: dict[str, str] + code: str, ask: bool, args: list[str] ) -> Generator[Message, None, None]: """ Applies the patch. """ - fn = args.get("file") + fn = " ".join(args) assert fn, "No filename provided" if ask: confirm = ask_execute("Apply patch?") @@ -122,7 +122,7 @@ def execute_patch( return try: - apply_file(codeblock, fn) + apply_file(code, fn) yield Message("system", "Patch applied") except (ValueError, FileNotFoundError) as e: yield Message("system", f"Patch failed: {e.args[0]}") @@ -134,4 +134,5 @@ def execute_patch( instructions=instructions, examples=examples, execute=execute_patch, + block_types=["patch"], ) diff --git a/gptme/tools/save.py b/gptme/tools/save.py index 3ff6e95d..604b83e7 100644 --- a/gptme/tools/save.py +++ b/gptme/tools/save.py @@ -21,8 +21,7 @@ from .base import ToolSpec instructions = """ -When you send a message containing Python code (and is not a file block), it will be executed in a stateful environment. -Python will respond with the output of the execution. +To save code to a file, use a code block with the filepath as the language. """.strip() examples = """ @@ -35,40 +34,28 @@ def execute_save( - code: str, ask: bool, args: dict[str, str] + code: str, ask: bool, args: list[str] ) -> Generator[Message, None, None]: """Save the code to a file.""" - fn = args.get("file") + fn = " ".join(args) assert fn, "No filename provided" - append = args.get("append", False) - action = "save" if not append else "append" # strip leading newlines code = code.lstrip("\n") if ask: - confirm = ask_execute(f"{action.capitalize()} to {fn}?") + confirm = ask_execute(f"Save to {fn}?") print() else: confirm = True - print(f"Skipping {action} confirmation.") + print("Skipping confirmation.") if ask and not confirm: # early return - yield Message("system", f"{action.capitalize()} cancelled.") + yield Message("system", "Save cancelled.") return path = Path(fn).expanduser() - if append: - if not path.exists(): - yield Message("system", f"File {fn} doesn't exist, can't append to it.") - return - - with open(path, "a") as f: - f.write(code) - yield Message("system", f"Appended to {fn}") - return - # if the file exists, ask to overwrite if path.exists(): if ask: @@ -103,10 +90,65 @@ def execute_save( yield Message("system", f"Saved to {fn}") -tool = ToolSpec( +def execute_append( + code: str, ask: bool, args: list[str] +) -> Generator[Message, None, None]: + """Append the code to a file.""" + fn = " ".join(args) + assert fn, "No filename provided" + # strip leading newlines + code = code.lstrip("\n") + + if ask: + confirm = ask_execute(f"Append to {fn}?") + print() + else: + confirm = True + print("Skipping append confirmation.") + + if ask and not confirm: + # early return + yield Message("system", "Append cancelled.") + return + + path = Path(fn).expanduser() + + if not path.exists(): + yield Message("system", f"File {fn} doesn't exist, can't append to it.") + return + + with open(path, "a") as f: + f.write(code) + yield Message("system", f"Appended to {fn}") + + +tool_save = ToolSpec( name="save", desc="Allows saving code to a file.", instructions=instructions, examples=examples, execute=execute_save, + block_types=["save"], +) + +instructions_append = """ +To append code to a file, use a code block with the language: append +""".strip() + +examples_append = """ +> User: append a print "Hello world" to hello.py +> Assistant: +```append hello.py +print("Hello world") +``` +> System: Appended to `hello.py`. +""".strip() + +tool_append = ToolSpec( + name="append", + desc="Allows appending code to a file.", + instructions=instructions_append, + examples=examples_append, + execute=execute_append, + block_types=["append"], ) diff --git a/gptme/tools/shell.py b/gptme/tools/shell.py index cf3c186d..c3ab9106 100644 --- a/gptme/tools/shell.py +++ b/gptme/tools/shell.py @@ -219,11 +219,14 @@ def set_shell(shell: ShellSession) -> None: _shell = shell -def execute_shell(cmd: str, ask=True, _=None) -> Generator[Message, None, None]: +def execute_shell( + code: str, ask: bool, args: list[str] +) -> Generator[Message, None, None]: """Executes a shell command and returns the output.""" shell = get_shell() + assert not args - cmd = cmd.strip() + cmd = code.strip() if cmd.startswith("$ "): cmd = cmd[len("$ ") :] diff --git a/gptme/tools/terminal.py b/gptme/tools/terminal.py index 1a98e1e4..68b84906 100644 --- a/gptme/tools/terminal.py +++ b/gptme/tools/terminal.py @@ -45,7 +45,7 @@ import logging import subprocess from time import sleep -from typing import Generator, List +from collections.abc import Generator from ..message import Message from ..util import ask_execute, print_preview @@ -53,91 +53,53 @@ logger = logging.getLogger(__name__) - -class TmuxSession: - """ - session: gpt_0 - window: gpt_0:0 - pane: gpt_0:0.0 - """ - - def __init__(self): - output = subprocess.run( - ["tmux", "list-sessions"], - capture_output=True, - text=True, - ) - assert output.returncode == 0 - self.sessions = [ - session.split(":")[0] for session in output.stdout.split("\n") if session - ] - - def new_session(self, command: str) -> tuple[str, str]: - _max_session_id = 0 - for session in self.sessions: - if session.startswith("gptme_"): - _max_session_id = max(_max_session_id, int(session.split("_")[1])) - session_id = f"gptme_{_max_session_id + 1}" - cmd = ["tmux", "new-session", "-d", "-s", session_id, command] - print(" ".join(cmd)) - result = subprocess.run( - " ".join(cmd), - check=True, - capture_output=True, - text=True, - shell=True, - ) - assert result.returncode == 0 - print(result.stdout, result.stderr) - - # sleep 1s and capture output - sleep(1) - output = self.capture_pane(f"{session_id}") - - self.sessions.append(session_id) - return session_id, output - - def send_keys(self, pane_id: str, keys: str) -> tuple[str, str | None]: - result = subprocess.run( - ["tmux", "send-keys", "-t", pane_id, *keys.split(" ")], - capture_output=True, - text=True, - ) - if result.returncode != 0: - return "", f"Failed to send keys to tmux pane `{pane_id}`: {result.stderr}" - sleep(1) - output = self.capture_pane(pane_id) - return output, None - - def capture_pane(self, pane_id: str) -> str: - result = subprocess.run( - ["tmux", "capture-pane", "-p", "-t", pane_id], - capture_output=True, - text=True, - ) - return result.stdout - - def kill_session(self, session_id: str): - subprocess.run(["tmux", "kill-session", "-t", f"gptme_{session_id}"]) - self.sessions.remove(session_id) - - def list_sessions(self) -> List[str]: - return self.sessions +""" +session: gpt_0 +window: gpt_0:0 +pane: gpt_0:0.0 +""" -_tmux_session = None +def get_sessions() -> list[str]: + output = subprocess.run( + ["tmux", "list-sessions"], + capture_output=True, + text=True, + ) + assert output.returncode == 0 + return [session.split(":")[0] for session in output.stdout.split("\n") if session] -def get_tmux_session() -> TmuxSession: - global _tmux_session - if _tmux_session is None: - _tmux_session = TmuxSession() - return _tmux_session +def _capture_pane(pane_id: str) -> str: + result = subprocess.run( + ["tmux", "capture-pane", "-p", "-t", pane_id], + capture_output=True, + text=True, + ) + return result.stdout def new_session(command: str) -> Message: - tmux = get_tmux_session() - session_id, output = tmux.new_session(command) + _max_session_id = 0 + for session in get_sessions(): + if session.startswith("gptme_"): + _max_session_id = max(_max_session_id, int(session.split("_")[1])) + session_id = f"gptme_{_max_session_id + 1}" + cmd = ["tmux", "new-session", "-d", "-s", session_id, command] + print(" ".join(cmd)) + result = subprocess.run( + " ".join(cmd), + check=True, + capture_output=True, + text=True, + shell=True, + ) + assert result.returncode == 0 + print(result.stdout, result.stderr) + + # sleep 1s and capture output + sleep(1) + output = _capture_pane(f"{session_id}") return Message( "system", f"Created new tmux session with ID {session_id} and started '{command}'.\nOutput:\n```\n{output}\n```", @@ -145,18 +107,24 @@ def new_session(command: str) -> Message: def send_keys(pane_id: str, keys: str) -> Message: - tmux = get_tmux_session() - output, error = tmux.send_keys(pane_id, keys) - if error: - return Message("system", error) + result = subprocess.run( + ["tmux", "send-keys", "-t", pane_id, *keys.split(" ")], + capture_output=True, + text=True, + ) + if result.returncode != 0: + return Message( + "system", f"Failed to send keys to tmux pane `{pane_id}`: {result.stderr}" + ) + sleep(1) + output = _capture_pane(pane_id) return Message( "system", f"Sent '{keys}' to pane `{pane_id}`\nOutput:\n```\n{output}\n```" ) def inspect_pane(pane_id: str) -> Message: - tmux = get_tmux_session() - content = tmux.capture_pane(pane_id) + content = _capture_pane(pane_id) return Message( "system", f"""Pane content: @@ -167,20 +135,31 @@ def inspect_pane(pane_id: str) -> Message: def kill_session(session_id: str) -> Message: - tmux = get_tmux_session() - tmux.kill_session(session_id) + result = subprocess.run( + ["tmux", "kill-session", "-t", f"gptme_{session_id}"], + check=True, + capture_output=True, + text=True, + ) + if result.returncode != 0: + return Message( + "system", + f"Failed to kill tmux session with ID {session_id}: {result.stderr}", + ) return Message("system", f"Killed tmux session with ID {session_id}") def list_sessions() -> Message: - tmux = get_tmux_session() - sessions = tmux.list_sessions() + sessions = get_sessions() return Message("system", f"Active tmux sessions: {sessions}") -def execute_terminal(cmd: str, ask=True, _=None) -> Generator[Message, None, None]: +def execute_terminal( + code: str, ask: bool, args: list[str] +) -> Generator[Message, None, None]: """Executes a terminal command and returns the output.""" - cmd = cmd.strip() + assert not args + cmd = code.strip() if ask: print_preview(f"Terminal command: {cmd}", "sh") @@ -193,17 +172,17 @@ def execute_terminal(cmd: str, ask=True, _=None) -> Generator[Message, None, Non if len(parts) == 1: yield Message("system", "Invalid command. Please provide arguments.") - command, args = parts[0], parts[1] + command, _args = parts[0], parts[1] if command == "new_session": - yield new_session(args) + yield new_session(_args) elif command == "send_keys": - pane_id, keys = args.split(maxsplit=1) + pane_id, keys = _args.split(maxsplit=1) yield send_keys(pane_id, keys) elif command == "inspect_pane": - yield inspect_pane(args) + yield inspect_pane(_args) elif command == "kill_session": - yield kill_session(args) + yield kill_session(_args) elif command == "list_sessions": yield list_sessions() else: @@ -292,7 +271,6 @@ def execute_terminal(cmd: str, ask=True, _=None) -> Generator[Message, None, Non desc="Executes terminal commands in a tmux session for interactive applications.", instructions=instructions, examples=examples, - init=get_tmux_session, execute=execute_terminal, block_types=["terminal"], )