From d3cc09a19fed4bc8ac0bf7b48bea6afa44b52467 Mon Sep 17 00:00:00 2001 From: rjambrecic <32619626+rjambrecic@users.noreply.github.com> Date: Mon, 20 Jan 2025 16:10:08 +0100 Subject: [PATCH] Fix _string_metadata_to_description_field function for python version < 3.11 (#535) * Test Websurfer with python 3.9-3.13 in CI * CI update * CI update * Fix _string_metadata_to_description_field WIP * Fix _string_metadata_to_description_field for optional parameters * Cleanup * CI testing * Fix tests * Refactoring * Add comments --- .github/workflows/contrib-graph-rag-tests.yml | 2 +- .github/workflows/contrib-test.yml | 4 +-- autogen/tools/dependency_injection.py | 10 +++++- autogen/tools/function_utils.py | 16 ++++----- test/agentchat/test_dependancy_injection.py | 24 ++++++++++++-- test/tools/test_dependency_injection.py | 33 +++++++++---------- test/tools/test_function_utils.py | 7 ++++ 7 files changed, 63 insertions(+), 33 deletions(-) diff --git a/.github/workflows/contrib-graph-rag-tests.yml b/.github/workflows/contrib-graph-rag-tests.yml index f39c8543d..6ef48a88c 100644 --- a/.github/workflows/contrib-graph-rag-tests.yml +++ b/.github/workflows/contrib-graph-rag-tests.yml @@ -9,7 +9,7 @@ on: paths: - "autogen/agentchat/contrib/graph_rag/**" - "test/agentchat/contrib/graph_rag/**" - - ".github/workflows/contrib-tests.yml" + - ".github/workflows/contrib-test.yml" - "pyproject.toml" concurrency: diff --git a/.github/workflows/contrib-test.yml b/.github/workflows/contrib-test.yml index ba93a8363..56e6cc451 100644 --- a/.github/workflows/contrib-test.yml +++ b/.github/workflows/contrib-test.yml @@ -11,7 +11,7 @@ on: - "test/agentchat/contrib/**" - "test/test_browser_utils.py" - "test/test_retrieve_utils.py" - - ".github/workflows/contrib-tests.yml" + - ".github/workflows/contrib-test.yml" - "pyproject.toml" concurrency: @@ -249,7 +249,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 diff --git a/autogen/tools/dependency_injection.py b/autogen/tools/dependency_injection.py index 9af9f42c0..3cfe72e5d 100644 --- a/autogen/tools/dependency_injection.py +++ b/autogen/tools/dependency_injection.py @@ -157,11 +157,19 @@ def _string_metadata_to_description_field(func: Callable[..., Any]) -> Callable[ type_hints = get_type_hints(func, include_extras=True) for _, annotation in type_hints.items(): + # Check if the annotation itself has metadata (using __metadata__) if hasattr(annotation, "__metadata__"): metadata = annotation.__metadata__ if metadata and isinstance(metadata[0], str): - # Replace string metadata with DescriptionField + # Replace string metadata with Field annotation.__metadata__ = (Field(description=metadata[0]),) + # For Python < 3.11, annotations like `Optional` are stored as `Union`, so metadata + # would be in the first element of __args__ (e.g., `__args__[0]` for `int` in `Optional[int]`) + elif hasattr(annotation, "__args__") and hasattr(annotation.__args__[0], "__metadata__"): + metadata = annotation.__args__[0].__metadata__ + if metadata and isinstance(metadata[0], str): + # Replace string metadata with Field + annotation.__args__[0].__metadata__ = (Field(description=metadata[0]),) return func diff --git a/autogen/tools/function_utils.py b/autogen/tools/function_utils.py index 7d87f28eb..a0f1ead74 100644 --- a/autogen/tools/function_utils.py +++ b/autogen/tools/function_utils.py @@ -134,17 +134,15 @@ def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> """ def type2description(k: str, v: Union[Annotated[type[Any], str], type[Any]]) -> str: + if not hasattr(v, "__metadata__"): + return k + # handles Annotated - if hasattr(v, "__metadata__"): - retval = v.__metadata__[0] - if isinstance(retval, AG2Field): - return retval.description # type: ignore[return-value] - else: - raise ValueError( - f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}" - ) + retval = v.__metadata__[0] + if isinstance(retval, AG2Field): + return retval.description # type: ignore[return-value] else: - return k + raise ValueError(f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}") schema = type2schema(v) if k in default_values: diff --git a/test/agentchat/test_dependancy_injection.py b/test/agentchat/test_dependancy_injection.py index 19e726b6b..c69b92e8d 100644 --- a/test/agentchat/test_dependancy_injection.py +++ b/test/agentchat/test_dependancy_injection.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Annotated, Any, Callable +from typing import Annotated, Any, Callable, Optional from unittest.mock import MagicMock import pytest @@ -23,6 +23,7 @@ def f_with_annotated( ctx: Annotated[MyContext, Depends(MyContext(b=2))], chat_ctx: Annotated[ChatContext, "Chat context"], c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: assert isinstance(chat_ctx, ChatContext) return a + ctx.b + c @@ -33,6 +34,7 @@ async def f_with_annotated_async( ctx: Annotated[MyContext, Depends(MyContext(b=2))], chat_ctx: ChatContext, c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: assert isinstance(chat_ctx, ChatContext) return a + ctx.b + c @@ -43,6 +45,7 @@ def f_without_annotated( chat_ctx: ChatContext, ctx: MyContext = Depends(MyContext(b=3)), c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx.b + c @@ -51,6 +54,7 @@ async def f_without_annotated_async( a: int, ctx: MyContext = Depends(MyContext(b=3)), c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx.b + c @@ -59,6 +63,7 @@ def f_with_annotated_and_depends( a: int, ctx: MyContext = MyContext(b=4), c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx.b + c @@ -67,6 +72,7 @@ async def f_with_annotated_and_depends_async( a: int, ctx: MyContext = MyContext(b=4), c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx.b + c @@ -76,6 +82,7 @@ def f_with_multiple_depends( ctx: Annotated[MyContext, Depends(MyContext(b=2))], ctx2: Annotated[MyContext, Depends(MyContext(b=3))], c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx.b + ctx2.b + c @@ -85,6 +92,7 @@ async def f_with_multiple_depends_async( ctx: Annotated[MyContext, Depends(MyContext(b=2))], ctx2: Annotated[MyContext, Depends(MyContext(b=3))], c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx.b + ctx2.b + c @@ -93,6 +101,7 @@ def f_wihout_base_context( a: int, ctx: Annotated[int, Depends(lambda a: a + 2)], c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx + c @@ -101,6 +110,7 @@ async def f_wihout_base_context_async( a: int, ctx: Annotated[int, Depends(lambda a: a + 2)], c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx + c @@ -109,6 +119,7 @@ def f_with_default_depends( a: int, ctx: int = Depends(lambda a: a + 2), c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx + c @@ -117,6 +128,7 @@ async def f_with_default_depends_async( a: int, ctx: int = Depends(lambda a: a + 2), c: Annotated[int, "c description"] = 3, + d: Annotated[Optional[int], "d description"] = None, ) -> int: return a + ctx + c @@ -135,6 +147,11 @@ def expected_tools(self) -> list[dict[str, Any]]: "properties": { "a": {"type": "integer", "description": "a"}, "c": {"type": "integer", "description": "c description", "default": 3}, + "d": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "description": "d description", + "default": None, + }, }, "required": ["a"], }, @@ -224,7 +241,10 @@ async def login( @user_proxy.register_for_execution() @agent.register_for_llm(description="Login function") - def login(user: Annotated[UserContext, Depends(user)]) -> str: + def login( + user: Annotated[UserContext, Depends(user)], + additional_notes: Annotated[Optional[str], "Additional notes"] = None, + ) -> str: return _login(user) user_proxy.initiate_chat(agent, message="Please login", max_turns=2) diff --git a/test/tools/test_dependency_injection.py b/test/tools/test_dependency_injection.py index 303498b8e..21a405a9b 100644 --- a/test/tools/test_dependency_injection.py +++ b/test/tools/test_dependency_injection.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Annotated, Callable, get_type_hints +import sys +from typing import Annotated, Callable, Optional, get_type_hints import pytest from pydantic import BaseModel @@ -189,28 +190,24 @@ def test_remove_injected_params_from_signature(self, test_func: Callable[..., in def test_string_metadata_to_description_field() -> None: - def f(a: int, b: Annotated[int, "b description"]) -> int: + def f( + a: int, + b: Annotated[int, "b description"], + c: Annotated[Optional[int], "c description"] = None, + ) -> int: return a + b - type_hints = get_type_hints(f, include_extras=True) - - params_with_string_metadata = [] - for param, annotation in type_hints.items(): - if hasattr(annotation, "__metadata__"): - metadata = annotation.__metadata__ - if metadata and isinstance(metadata[0], str): - params_with_string_metadata.append(param) - - assert params_with_string_metadata == ["b"] - f = _string_metadata_to_description_field(f) type_hints = get_type_hints(f, include_extras=True) - for param, annotation in type_hints.items(): - if hasattr(annotation, "__metadata__"): - metadata = annotation.__metadata__ - if metadata and isinstance(metadata[0], str): - raise AssertionError("The string metadata should have been replaced with Pydantic's Field") field_info = type_hints["b"].__metadata__[0] assert isinstance(field_info, Field) assert field_info.description == "b description" + + if sys.version_info < (3, 11): + field_info = type_hints["c"].__args__[0].__metadata__[0] + else: + field_info = type_hints["c"].__metadata__[0] + + assert isinstance(field_info, Field) + assert field_info.description == "c description" diff --git a/test/tools/test_function_utils.py b/test/tools/test_function_utils.py index 708bf2185..c614bf2d0 100644 --- a/test/tools/test_function_utils.py +++ b/test/tools/test_function_utils.py @@ -84,6 +84,13 @@ def test_get_parameter_json_schema() -> None: "description": "parameter a", "default": "3.14", } + assert get_parameter_json_schema( + "d", Annotated[Optional[str], AG2Field(description="parameter d")], {"d": None} + ) == { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "description": "parameter d", + } class B(BaseModel): b: float