diff --git a/gptme/cli.py b/gptme/cli.py
index 09ac063c..0b520c2c 100644
--- a/gptme/cli.py
+++ b/gptme/cli.py
@@ -117,7 +117,7 @@
 @click.option(
     "--tool-format",
     "tool_format",
-    default="markdown",
+    default=None,
     help="Tool parsing method. Can be 'markdown', 'xml', 'tool'. (experimental)",
 )
 @click.option(
@@ -149,7 +149,7 @@ def main(
     name: str,
     model: str | None,
     tool_allowlist: list[str] | None,
-    tool_format: ToolFormat,
+    tool_format: ToolFormat | None,
     stream: bool,
     verbose: bool,
     no_confirm: bool,
@@ -187,8 +187,10 @@ def main(
 
     config = get_config()
 
-    tool_format = tool_format or config.get_env("TOOL_FORMAT") or "markdown"
-    set_tool_format(tool_format)
+    selected_tool_format: ToolFormat = (
+        tool_format or config.get_env("TOOL_FORMAT") or "markdown"  # type: ignore
+    )
+    set_tool_format(selected_tool_format)
 
     # early init tools to generate system prompt
     init_tools(frozenset(tool_allowlist) if tool_allowlist else None)
@@ -198,7 +200,7 @@ def main(
         get_prompt(
             prompt_system,
             interactive=interactive,
-            tool_format=tool_format,
+            tool_format=selected_tool_format,
         )
     ]
 
@@ -290,7 +292,7 @@ def inject_stdin(prompt_msgs, piped_input: str | None) -> list[Message]:
             show_hidden,
             workspace_path,
             tool_allowlist,
-            tool_format,
+            selected_tool_format,
         )
     except RuntimeError as e:
         logger.error(e)
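Moving the `--tool-format` default from `"markdown"` to `None` is what lets the `TOOL_FORMAT` config/env value take effect: the flag now only wins when explicitly passed. A minimal sketch of the resulting precedence, using `os.environ` as a stand-in for `config.get_env` and a hypothetical `resolve_tool_format` helper that is not part of this patch:

```python
import os


def resolve_tool_format(cli_value: str | None) -> str:
    # Mirrors the precedence in main(): explicit CLI flag > config/env > default.
    return cli_value or os.environ.get("TOOL_FORMAT") or "markdown"


assert resolve_tool_format("tool") == "tool"  # explicit flag always wins
assert resolve_tool_format(None) == "markdown"  # nothing set: fall back
os.environ["TOOL_FORMAT"] = "xml"
assert resolve_tool_format(None) == "xml"  # env fills in when the flag is absent
```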
diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py
index aca2ea10..7549f7e3 100644
--- a/gptme/llm/llm_anthropic.py
+++ b/gptme/llm/llm_anthropic.py
@@ -9,10 +9,11 @@
     TypedDict,
     cast,
 )
+from collections.abc import Iterable
 
 from ..constants import TEMPERATURE, TOP_P
 from ..message import Message, msgs2dicts
-from ..tools.base import Parameter, ToolSpec
+from ..tools.base import Parameter, ToolSpec, ToolUse
 
 if TYPE_CHECKING:
     # noreorder
@@ -64,9 +65,17 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
     )
     content = response.content
     logger.debug(response.usage)
-    assert content
-    assert len(content) == 1
-    return content[0].text  # type: ignore
+
+    parsed_block = []
+    for block in content:
+        if block.type == "text":
+            parsed_block.append(block.text)
+        elif block.type == "tool_use":
+            parsed_block.append(f"\n@{block.name}({block.id}): {block.input}")
+        else:
+            logger.warning("Unknown block: %s", str(block))
+
+    return "\n".join(parsed_block)
 
 
 def stream(
@@ -96,7 +105,7 @@
                 block = chunk.content_block
                 if isinstance(block, anthropic.types.ToolUseBlock):
                     tool_use = block
-                    yield f"\n@{tool_use.name}: "
+                    yield f"\n@{tool_use.name}({tool_use.id}): "
                 elif isinstance(block, anthropic.types.TextBlock):
                     if block.text:
                         logger.warning("unexpected text block: %s", block.text)
@@ -132,8 +141,78 @@
             pass
 
 
-def _handle_files(message_dicts: list[dict]) -> list[dict]:
-    return [_process_file(message_dict) for message_dict in message_dicts]
+def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
+    for message in message_dicts:
+        # Format tool result as expected by the model
+        if message["role"] == "system" and "call_id" in message:
+            modified_message = dict(message)
+            modified_message["role"] = "user"
+            modified_message["content"] = [
+                {
+                    "type": "tool_result",
+                    "content": modified_message["content"],
+                    "tool_use_id": modified_message.pop("call_id"),
+                }
+            ]
+            yield modified_message
+        # Find tool_use occurrences and format them as expected
+        elif message["role"] == "assistant":
+            modified_message = dict(message)
+            text = ""
+            content = []
+
+            # Content can be a plain string or a list of parts
+            if isinstance(message["content"], list):
+                message_parts = message["content"]
+            else:
+                message_parts = [{"type": "text", "text": message["content"]}]
+
+            for message_part in message_parts:
+                if message_part["type"] != "text":
+                    content.append(message_part)
+                    continue
+
+                # For a message part of type `text` we try to extract the tool uses.
+                # We scan the text line by line so we can stop as soon as we have a
+                # complete tool call, which makes it easier to split into parts.
+                for line in message_part["text"].split("\n"):
+                    text += line + "\n"
+
+                    tooluses = [
+                        tooluse
+                        for tooluse in ToolUse.iter_from_content(text)
+                        if tooluse.is_runnable
+                    ]
+                    if not tooluses:
+                        continue
+
+                    # At this point there should be exactly one tooluse: we reset
+                    # the buffer as soon as we encounter one, so earlier matches
+                    # cannot accumulate.
+                    assert len(tooluses) == 1
+                    tooluse = tooluses[0]
+                    before_tool = text[: tooluse.start]
+
+                    if before_tool:
+                        content.append({"type": "text", "text": before_tool})
+
+                    content.append(
+                        {
+                            "type": "tool_use",
+                            "id": tooluse.call_id or "",
+                            "name": tooluse.tool,
+                            "input": tooluse.kwargs or {},
+                        }
+                    )
+                    # The text is emptied to start over with the next lines, if any.
+                    text = ""
+
+            if content:
+                modified_message["content"] = content
+
+            yield modified_message
+        else:
+            yield message
 
 
 def _process_file(message_dict: dict) -> dict:
@@ -219,9 +298,11 @@ def _transform_system_messages(
         messages = messages.copy()
         messages.pop(0)
 
-    # for any subsequent system messages, transform them into a message
+    # Convert subsequent system messages into user messages,
+    # unless a `call_id` is present, indicating the tool_format is 'tool'.
+    # Tool responses are handled separately by _handle_tools.
     for i, message in enumerate(messages):
-        if message.role == "system":
+        if message.role == "system" and message.call_id is None:
             messages[i] = Message(
                 "user",
                 content=f"{message.content}",
@@ -251,7 +332,7 @@
     return messages, system_messages
 
 
-def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
+def _parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
     required = []
     properties = {}
 
@@ -279,7 +360,7 @@ def _spec2tool(
     return {
         "name": name,
        "description": spec.get_instructions("tool"),
-        "input_schema": parameters2dict(spec.parameters),
+        "input_schema": _parameters2dict(spec.parameters),
     }
 
@@ -315,7 +396,13 @@ def _prepare_messages_for_api(
     messages, system_messages = _transform_system_messages(messages)
 
     # Handle files and convert to dicts
-    messages_dicts = _handle_files(msgs2dicts(messages))
+    messages_dicts = (_process_file(f) for f in msgs2dicts(messages))
+
+    # Prepare tools
+    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+
+    if tools_dict is not None:
+        messages_dicts = _handle_tools(messages_dicts)
 
     # Apply cache control to optimize performance
     messages_dicts_new: list[anthropic.types.MessageParam] = []
@@ -352,7 +439,4 @@
         assert isinstance(msgp["content"], list)
         msgp["content"][-1]["cache_control"] = {"type": "ephemeral"}  # type: ignore
 
-    # Prepare tools
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
-
     return messages_dicts_new, system_messages, tools_dict
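To make the `_handle_tools` transformation above concrete: a system message carrying a `call_id` is a tool result, and gets re-emitted as a `user` message wrapping a single `tool_result` block. A small sketch (the `call_123` id is made up; `_handle_tools` is a private helper, imported here the same way the new tests import `_prepare_messages_for_api`):

```python
from gptme.llm.llm_anthropic import _handle_tools

out = list(
    _handle_tools(
        [{"role": "system", "content": "Saved to path.txt", "call_id": "call_123"}]
    )
)
assert out == [
    {
        "role": "user",
        "content": [
            {
                "type": "tool_result",
                "content": "Saved to path.txt",
                "tool_use_id": "call_123",
            }
        ],
    }
]
```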
diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py
index 7819c666..129c62d6 100644
--- a/gptme/llm/llm_openai.py
+++ b/gptme/llm/llm_openai.py
@@ -1,5 +1,7 @@
 import base64
+import json
 import logging
+from collections.abc import Iterable
 from collections.abc import Generator
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
@@ -7,8 +9,8 @@
 from ..config import Config
 from ..constants import TEMPERATURE, TOP_P
 from ..message import Message, msgs2dicts
-from ..tools.base import Parameter, ToolSpec
-from .models import Provider, get_model
+from ..tools.base import Parameter, ToolSpec, ToolUse
+from .models import ModelMeta, Provider, get_model
 
 if TYPE_CHECKING:
     # noreorder
@@ -112,16 +114,11 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
     # top_p controls diversity, temperature controls randomness
     assert openai, "LLM not initialized"
 
-    is_o1 = model.startswith("o1")
-    if is_o1:
-        messages = list(_prep_o1(messages))
-
-    messages_dicts = handle_files(msgs2dicts(messages))
-    _transform_msgs_for_special_provider(messages_dicts)
-
     from openai import NOT_GIVEN  # fmt: skip
 
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+    is_o1 = model.startswith("o1")
+
+    messages_dicts, tools_dict = _prepare_messages_for_api(messages, tools)
 
     response = openai.chat.completions.create(
         model=model,
@@ -133,9 +130,19 @@
             openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
         ),
     )
-    content = response.choices[0].message.content
-    assert content
-    return content
+    choice = response.choices[0]
+    result = []
+    if choice.finish_reason == "tool_calls":
+        for tool_call in choice.message.tool_calls or []:
+            result.append(
+                f"@{tool_call.function.name}({tool_call.id}): {tool_call.function.arguments}"
+            )
+    else:
+        if choice.message.content:
+            result.append(choice.message.content)
+
+    assert result
+    return "\n".join(result)
 
 
 def stream(
@@ -144,16 +151,11 @@ def stream(
     assert openai, "LLM not initialized"
     stop_reason = None
 
-    is_o1 = model.startswith("o1")
-    if is_o1:
-        messages = list(_prep_o1(messages))
-
-    messages_dicts = handle_files(msgs2dicts(messages))
-    _transform_msgs_for_special_provider(messages_dicts)
-
     from openai import NOT_GIVEN  # fmt: skip
 
-    tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
+    is_o1 = model.startswith("o1")
+
+    messages_dicts, tools_dict = _prepare_messages_for_api(messages, tools)
 
     for chunk_raw in openai.chat.completions.create(
         model=model,
@@ -198,20 +200,91 @@
                     func = tool_call.function
                     if isinstance(func, ChoiceDeltaToolCallFunction):
                         if func.name:
-                            yield f"\n@{func.name}: "
+                            yield f"\n@{func.name}({tool_call.id}): "
                         if func.arguments:
                             yield func.arguments
 
     logger.debug(f"Stop reason: {stop_reason}")
 
 
-def handle_files(msgs: list[dict]) -> list[dict]:
-    return [_process_file(msg) for msg in msgs]
-
-
-def _process_file(msg: dict) -> dict:
+def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
+    for message in message_dicts:
+        # Format tool result as expected by the model
+        if message["role"] == "system" and "call_id" in message:
+            modified_message = dict(message)
+            modified_message["role"] = "tool"
+            modified_message["tool_call_id"] = modified_message.pop("call_id")
+            yield modified_message
+        # Find tool_use occurrences and format them as expected
+        elif message["role"] == "assistant":
+            modified_message = dict(message)
+            text = ""
+            content = []
+            tool_calls = []
+
+            # Content can be a plain string or a list of parts
+            if isinstance(message["content"], list):
+                message_parts = message["content"]
+            else:
+                message_parts = [{"type": "text", "text": message["content"]}]
+
+            for message_part in message_parts:
+                if message_part["type"] != "text":
+                    content.append(message_part)
+                    continue
+
+                # For a message part of type `text` we try to extract the tool uses.
+                # We scan the text line by line so we can stop as soon as we have a
+                # complete tool call, which makes it easier to split into parts.
+                for line in message_part["text"].split("\n"):
+                    text += line + "\n"
+
+                    tooluses = [
+                        tooluse
+                        for tooluse in ToolUse.iter_from_content(text)
+                        if tooluse.is_runnable
+                    ]
+                    if not tooluses:
+                        continue
+
+                    # At this point there should be exactly one tooluse: we reset
+                    # the buffer as soon as we encounter one, so earlier matches
+                    # cannot accumulate.
+                    assert len(tooluses) == 1
+                    tooluse = tooluses[0]
+                    before_tool = text[: tooluse.start]
+
+                    if before_tool.replace("\n", ""):
+                        content.append({"type": "text", "text": before_tool})
+
+                    tool_calls.append(
+                        {
+                            "id": tooluse.call_id or "",
+                            "type": "function",
+                            "function": {
+                                "name": tooluse.tool,
+                                "arguments": json.dumps(tooluse.kwargs or {}),
+                            },
+                        }
+                    )
+                    # The text is emptied to start over with the next lines, if any.
+                    text = ""
+
+            if content:
+                modified_message["content"] = content
+
+            if tool_calls:
+                if not content:
+                    del modified_message["content"]
+                modified_message["tool_calls"] = tool_calls
+
+            yield modified_message
+        else:
+            yield message
+
+
+def _process_file(msg: dict, model: ModelMeta) -> dict:
     message_content = msg["content"]
-    model = get_model()
     if model.provider == "deepseek":
         # deepseek does not support files
         return msg
@@ -280,15 +353,16 @@
     return msg
 
 
-def _transform_msgs_for_special_provider(messages_dicts: list[dict]):
-    if get_provider() == "groq":
+def _transform_msgs_for_special_provider(
+    messages_dicts: Iterable[dict], model: ModelMeta
+):
+    if model.provider == "groq":
         # groq needs message.content to be a string
-        messages_dicts = [
-            {**msg, "content": msg["content"][0]["text"]} for msg in messages_dicts
-        ]
+        return [{**msg, "content": msg["content"][0]["text"]} for msg in messages_dicts]
+    return messages_dicts
 
 
-def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
+def _parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
     required = []
     properties = {}
 
@@ -305,7 +379,7 @@
     }
 
 
-def _spec2tool(spec: ToolSpec) -> "ChatCompletionToolParam":
+def _spec2tool(spec: ToolSpec, model: ModelMeta) -> "ChatCompletionToolParam":
     name = spec.name
     if spec.block_types:
         name = spec.block_types[0]
@@ -319,16 +393,38 @@
         )
         description = description[:1024]
 
-    provider = get_provider()
-    if provider in ["openai", "azure", "openrouter", "local"]:
+    if model.provider in ["openai", "azure", "openrouter", "local"]:
         return {
             "type": "function",
             "function": {
                 "name": name,
                 "description": description,
-                "parameters": parameters2dict(spec.parameters),
+                "parameters": _parameters2dict(spec.parameters),
                 # "strict": False,  # not supported by OpenRouter
             },
         }
     else:
         raise ValueError("Provider doesn't support tools API")
+
+
+def _prepare_messages_for_api(
+    messages: list[Message], tools: list[ToolSpec] | None
+) -> tuple[Iterable[dict], Iterable["ChatCompletionToolParam"] | None]:
+    model = get_model()
+
+    is_o1 = model.model.startswith("o1")
+    if is_o1:
+        messages = list(_prep_o1(messages))
+
+    messages_dicts: Iterable[dict] = (
+        _process_file(msg, model) for msg in msgs2dicts(messages)
+    )
+
+    tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None
+
+    if tools_dict is not None:
+        messages_dicts = _handle_tools(messages_dicts)
+
+    messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model)
+
+    return list(messages_dicts), tools_dict
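The OpenAI-side `_handle_tools` does the same job with the shapes the Chat Completions API expects: the tool result becomes a `tool` role message keyed by `tool_call_id` rather than a nested `tool_result` block. Same caveats as the Anthropic sketch above (made-up `call_123` id, private helper):

```python
from gptme.llm.llm_openai import _handle_tools

out = list(
    _handle_tools(
        [{"role": "system", "content": "Saved to path.txt", "call_id": "call_123"}]
    )
)
assert out == [
    {"role": "tool", "content": "Saved to path.txt", "tool_call_id": "call_123"}
]
```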
diff --git a/gptme/message.py b/gptme/message.py
index e0f920ec..66135eb6 100644
--- a/gptme/message.py
+++ b/gptme/message.py
@@ -45,6 +45,7 @@ class Message:
     quiet: bool = False
     timestamp: datetime = field(default_factory=datetime.now)
     files: list[Path] = field(default_factory=list)
+    call_id: str | None = None
 
     def __post_init__(self):
         assert isinstance(self.timestamp, datetime)
@@ -81,6 +82,8 @@ def to_dict(self, keys=None) -> dict:
             d["pinned"] = True
         if self.hide:
             d["hide"] = True
+        if self.call_id:
+            d["call_id"] = self.call_id
         if keys:
             return {k: d[k] for k in keys if k in d}
         return d
@@ -142,6 +145,7 @@ def to_toml(self) -> str:
 {content}
 """
 timestamp = "{self.timestamp.isoformat()}"
+call_id = "{self.call_id}"
 {extra}
 '''
@@ -164,6 +168,7 @@ def from_toml(cls, toml: str) -> Self:
             hide=msg.get("hide", False),
             files=[Path(f) for f in msg.get("files", [])],
             timestamp=datetime.fromisoformat(msg["timestamp"]),
+            call_id=msg.get("call_id", None),
         )
 
     def get_codeblocks(self) -> list[Codeblock]:
@@ -297,7 +302,7 @@ def toml_to_msgs(toml: str) -> list[Message]:
 
 def msgs2dicts(msgs: list[Message]) -> list[dict]:
     """Convert a list of Message objects to a list of dicts ready to pass to an LLM."""
-    return [msg.to_dict(keys=["role", "content", "files"]) for msg in msgs]
+    return [msg.to_dict(keys=["role", "content", "files", "call_id"]) for msg in msgs]
 
 
 # Global cache mapping hashes to token counts
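Since `to_dict` only emits `call_id` when it is set, ordinary messages keep their old wire shape and only tool results gain the extra key. A quick sketch with illustrative values:

```python
from gptme.message import Message, msgs2dicts

msgs = [
    Message(role="user", content="hi"),
    Message(role="system", content="Saved to x.txt", call_id="call_123"),
]
dicts = msgs2dicts(msgs)
assert "call_id" not in dicts[0]  # unchanged shape for ordinary messages
assert dicts[1]["call_id"] == "call_123"  # tool results carry the id through
```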
diff --git a/gptme/tools/base.py b/gptme/tools/base.py
index dac3e153..92d18a41 100644
--- a/gptme/tools/base.py
+++ b/gptme/tools/base.py
@@ -35,7 +35,7 @@
 exclusive_mode = False
 
 # Match tool name and start of JSON
-toolcall_re = re.compile(r"^@(\w+):\s*({.*)", re.M | re.S)
+toolcall_re = re.compile(r"^@(\w+)\((\w+)\):\s*({.*)", re.M | re.S)
 
 
 def find_json_end(s: str, start: int) -> int | None:
@@ -67,7 +67,7 @@ def find_json_end(s: str, start: int) -> int | None:
 
 def extract_json(content: str, match: re.Match) -> str | None:
     """Extract complete JSON object starting from a regex match"""
-    json_start = match.start(2)  # start of the JSON content
+    json_start = match.start(3)  # start of the JSON content
     json_end = find_json_end(content, json_start)
     if json_end is None:
         return None
@@ -272,6 +272,7 @@ class ToolUse:
     args: list[str] | None
     content: str | None
     kwargs: dict[str, str] | None = None
+    call_id: str | None = None
     start: int | None = None
 
     def execute(self, confirm: ConfirmFunc) -> Generator[Message, None, None]:
@@ -289,9 +290,10 @@ def execute(self, confirm: ConfirmFunc) -> Generator[Message, None, None]:
                 confirm,
             )
             if isinstance(ex, Generator):
-                yield from ex
+                for msg in ex:
+                    yield msg.replace(call_id=self.call_id)
             else:
-                yield ex
+                yield ex.replace(call_id=self.call_id)
         except Exception as e:
             # if we are testing, raise the exception
             logger.exception(e)
@@ -358,14 +360,21 @@ def iter_from_content(cls, content: str) -> Generator["ToolUse", None, None]:
         # check if its a toolcall and extract valid JSON
         if match := toolcall_re.search(content):
             tool_name = match.group(1)
+            call_id = match.group(2)
             if (json_str := extract_json(content, match)) is not None:
                 try:
                     kwargs = json_repair.loads(json_str)
                     if not isinstance(kwargs, dict):
                         logger.debug(f"JSON repair result is not a dict: {kwargs}")
                         return
+                    start_pos = content.find(f"@{tool_name}(")
                     yield ToolUse(
-                        tool_name, None, None, kwargs=cast(dict[str, str], kwargs)
+                        tool_name,
+                        None,
+                        None,
+                        kwargs=cast(dict[str, str], kwargs),
+                        call_id=call_id,
+                        start=start_pos,
                     )
                 except json.JSONDecodeError:
                     logger.debug(f"Failed to parse JSON: {json_str}")
diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py
index 8187220d..6b4f6764 100644
--- a/gptme/tools/patch.py
+++ b/gptme/tools/patch.py
@@ -237,7 +237,7 @@ def execute_patch(
     confirm: ConfirmFunc = lambda _: True,
 ) -> Generator[Message, None, None]:
     """Applies the patch."""
-    if code is not None and kwargs is not None:
+    if code is None and kwargs is not None:
         code = kwargs.get("patch", code)
 
     if not code:
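The regex change in `tools/base.py` above is the crux of the new wire format: tool calls are now written as `@tool(call_id): {json}` instead of `@tool: {json}`, and group 3 (rather than 2) marks the start of the JSON payload for `extract_json`. A standalone check of the pattern, with a made-up `call_123` id:

```python
import re

# Same pattern as the new toolcall_re: tool name, call id, start of JSON.
toolcall_re = re.compile(r"^@(\w+)\((\w+)\):\s*({.*)", re.M | re.S)

m = toolcall_re.search('@save(call_123): {"path": "a.txt", "content": "hi"}')
assert m is not None
assert m.group(1) == "save"  # tool name
assert m.group(2) == "call_123"  # call id, now captured separately
assert m.group(3).startswith('{"path"')  # JSON payload starts at group 3
```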
"object", + "properties": { + "path": {"type": "string", "description": "The path of the file"}, + "content": {"type": "string", "description": "The content to save"}, + }, + "required": ["path", "content"], + "additionalProperties": False, + }, + } + ] + + assert list(messages_dicts) == [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Project prompt\n\nFirst user prompt", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "\nSomething\n\n"}, + { + "type": "tool_use", + "id": "tool_call_id", + "name": "save", + "input": {"path": "path.txt", "content": "file_content"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "content": [{"type": "text", "text": "Saved to toto.txt"}], + "tool_use_id": "tool_call_id", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ] diff --git a/tests/test_llm_openai.py b/tests/test_llm_openai.py new file mode 100644 index 00000000..0d0b0810 --- /dev/null +++ b/tests/test_llm_openai.py @@ -0,0 +1,189 @@ +from gptme.llm.llm_openai import _prepare_messages_for_api +from gptme.llm.models import set_default_model +from gptme.message import Message +from gptme.tools import get_tool, init_tools + + +def test_message_conversion(): + messages = [ + Message(role="system", content="Initial Message", pinned=True, hide=True), + Message(role="system", content="Project prompt", hide=True), + Message(role="user", content="First user prompt"), + ] + + set_default_model("openai/gpt-o4") + + messages_dict, tools_dict = _prepare_messages_for_api(messages, None) + + assert tools_dict is None + assert messages_dict == [ + {"role": "system", "content": [{"type": "text", "text": "Initial Message"}]}, + {"role": "system", "content": [{"type": "text", "text": "Project prompt"}]}, + {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]}, + ] + + +def test_message_conversion_o1(): + messages = [ + Message(role="system", content="Initial Message", pinned=True, hide=True), + Message(role="system", content="Project prompt", hide=True), + Message(role="user", content="First user prompt"), + ] + + set_default_model("openai/o1-mini") + + messages_dict, _ = _prepare_messages_for_api(messages, None) + + assert messages_dict == [ + { + "role": "user", + "content": [ + {"type": "text", "text": "\nInitial Message\n"} + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "\nProject prompt\n"} + ], + }, + {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]}, + ] + + +def test_message_conversion_without_tools(): + init_tools(allowlist=frozenset(["save"])) + + messages = [ + Message(role="system", content="Initial Message", pinned=True, hide=True), + Message(role="system", content="Project prompt", hide=True), + Message(role="user", content="First user prompt"), + Message( + role="assistant", + content="\nSomething\n\n```save path.txt\nfile_content\n```", + ), + Message(role="system", content="Saved to toto.txt"), + ] + + set_default_model("openai/gpt-o4") + + messages_dicts, _ = _prepare_messages_for_api(messages, None) + + assert messages_dicts == [ + {"role": "system", "content": [{"type": "text", "text": "Initial Message"}]}, + {"role": "system", "content": [{"type": "text", "text": "Project prompt"}]}, + {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]}, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "\nSomething\n\n```save 
diff --git a/tests/test_llm_openai.py b/tests/test_llm_openai.py
new file mode 100644
index 00000000..0d0b0810
--- /dev/null
+++ b/tests/test_llm_openai.py
@@ -0,0 +1,189 @@
+from gptme.llm.llm_openai import _prepare_messages_for_api
+from gptme.llm.models import set_default_model
+from gptme.message import Message
+from gptme.tools import get_tool, init_tools
+
+
+def test_message_conversion():
+    messages = [
+        Message(role="system", content="Initial Message", pinned=True, hide=True),
+        Message(role="system", content="Project prompt", hide=True),
+        Message(role="user", content="First user prompt"),
+    ]
+
+    set_default_model("openai/gpt-4o")
+
+    messages_dict, tools_dict = _prepare_messages_for_api(messages, None)
+
+    assert tools_dict is None
+    assert messages_dict == [
+        {"role": "system", "content": [{"type": "text", "text": "Initial Message"}]},
+        {"role": "system", "content": [{"type": "text", "text": "Project prompt"}]},
+        {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]},
+    ]
+
+
+def test_message_conversion_o1():
+    messages = [
+        Message(role="system", content="Initial Message", pinned=True, hide=True),
+        Message(role="system", content="Project prompt", hide=True),
+        Message(role="user", content="First user prompt"),
+    ]
+
+    set_default_model("openai/o1-mini")
+
+    messages_dict, _ = _prepare_messages_for_api(messages, None)
+
+    assert messages_dict == [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "<system>\nInitial Message\n</system>"}
+            ],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "<system>\nProject prompt\n</system>"}
+            ],
+        },
+        {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]},
+    ]
+
+
+def test_message_conversion_without_tools():
+    init_tools(allowlist=frozenset(["save"]))
+
+    messages = [
+        Message(role="system", content="Initial Message", pinned=True, hide=True),
+        Message(role="system", content="Project prompt", hide=True),
+        Message(role="user", content="First user prompt"),
+        Message(
+            role="assistant",
+            content="\nSomething\n\n```save path.txt\nfile_content\n```",
+        ),
+        Message(role="system", content="Saved to toto.txt"),
+    ]
+
+    set_default_model("openai/gpt-4o")
+
+    messages_dicts, _ = _prepare_messages_for_api(messages, None)
+
+    assert messages_dicts == [
+        {"role": "system", "content": [{"type": "text", "text": "Initial Message"}]},
+        {"role": "system", "content": [{"type": "text", "text": "Project prompt"}]},
+        {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]},
+        {
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "\nSomething\n\n```save path.txt\nfile_content\n```",
+                }
+            ],
+        },
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "Saved to toto.txt"}],
+        },
+    ]
+
+
+def test_message_conversion_with_tools():
+    init_tools(allowlist=frozenset(["save"]))
+
+    messages = [
+        Message(role="user", content="First user prompt"),
+        Message(
+            role="assistant",
+            content='\nSomething\n\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
+        ),
+        Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
+        Message(role="user", content="Second user prompt"),
+        Message(
+            role="assistant",
+            content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
+        ),
+        Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
+    ]
+
+    set_default_model("openai/gpt-4o")
+
+    tool_save = get_tool("save")
+
+    assert tool_save
+
+    messages_dicts, tools_dict = _prepare_messages_for_api(messages, [tool_save])
+
+    assert tools_dict == [
+        {
+            "type": "function",
+            "function": {
+                "name": "save",
+                "description": "Create or overwrite a file with the given content.\n\n"
+                "The path can be relative to the current directory, or absolute.\n"
+                "If the current directory changes, the path will be relative to the "
+                "new directory.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "path": {
+                            "type": "string",
+                            "description": "The path of the file",
+                        },
+                        "content": {
+                            "type": "string",
+                            "description": "The content to save",
+                        },
+                    },
+                    "required": ["path", "content"],
+                    "additionalProperties": False,
+                },
+            },
+        }
+    ]
+
+    assert messages_dicts == [
+        {"role": "user", "content": [{"type": "text", "text": "First user prompt"}]},
+        {
+            "role": "assistant",
+            "content": [{"type": "text", "text": "\nSomething\n\n"}],
+            "tool_calls": [
+                {
+                    "id": "tool_call_id",
+                    "type": "function",
+                    "function": {
+                        "name": "save",
+                        "arguments": '{"path": "path.txt", "content": "file_content"}',
+                    },
+                }
+            ],
+        },
+        {
+            "role": "tool",
+            "content": [{"type": "text", "text": "Saved to toto.txt"}],
+            "tool_call_id": "tool_call_id",
+        },
+        {"role": "user", "content": [{"type": "text", "text": "Second user prompt"}]},
+        {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "id": "tool_call_id",
+                    "type": "function",
+                    "function": {
+                        "name": "save",
+                        "arguments": '{"path": "path.txt", "content": "file_content"}',
+                    },
+                }
+            ],
+        },
+        {
+            "role": "tool",
+            "content": [{"type": "text", "text": "Saved to toto.txt"}],
+            "tool_call_id": "tool_call_id",
+        },
+    ]
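The parametrized cases in the next file rely on `json_repair` tolerating slightly malformed payloads, e.g. a missing comma between keys. A minimal sanity check of that behavior:

```python
import json_repair

# Mirrors the "json_repair can fix this" case in the parametrized tests below.
assert json_repair.loads('{"missing": "comma" "key": "value"}') == {
    "missing": "comma",
    "key": "value",
}
```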
"value"\n}', "tool", '{\n "param": "value with\nnewline",\n "another": "value"\n}', ), ( - '@tool: {"param": {"nested": "value"}}', + '@tool(tool_uid): {"param": {"nested": "value"}}', "tool", '{"param": {"nested": "value"}}', ), ( - '@tool: {"param": {"deeply": {"nested": "value"}}}', + '@tool(tool_uid): {"param": {"deeply": {"nested": "value"}}}', "tool", '{"param": {"deeply": {"nested": "value"}}}', ), ( - '@tool: {"text": "a string with } brace"}', + '@tool(tool_uid): {"text": "a string with } brace"}', "tool", '{"text": "a string with } brace"}', ), ( - '@tool: {"text": "a string with \\"quote\\" and } brace"}', + '@tool(tool_uid): {"text": "a string with \\"quote\\" and } brace"}', "tool", '{"text": "a string with \\"quote\\" and } brace"}', ), ( - '@save: {"path": "hello.py", "content": "def main():\n print(\\"Hello, World!\\")\n \nif __name__ == \\"__main__\\":\n main()"}', + '@save(tool_uid): {"path": "hello.py", "content": "def main():\n print(\\"Hello, World!\\")\n \nif __name__ == \\"__main__\\":\n main()"}', "save", '{"path": "hello.py", "content": "def main():\n print(\\"Hello, World!\\")\n \nif __name__ == \\"__main__\\":\n main()"}', ),