Skip to content

Commit

Permalink
Fix _string_metadata_to_description_field function for python version…
Browse files Browse the repository at this point in the history
… < 3.11 (#535)

* Test Websurfer with python 3.9-3.13 in CI

* CI update

* CI update

* Fix _string_metadata_to_description_field WIP

* Fix _string_metadata_to_description_field for optional parameters

* Cleanup

* CI testing

* Fix tests

* Refactoring

* Add comments
  • Loading branch information
rjambrecic authored Jan 20, 2025
1 parent b69245f commit d3cc09a
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-graph-rag-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
paths:
- "autogen/agentchat/contrib/graph_rag/**"
- "test/agentchat/contrib/graph_rag/**"
- ".github/workflows/contrib-tests.yml"
- ".github/workflows/contrib-test.yml"
- "pyproject.toml"

concurrency:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/contrib-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
- "test/agentchat/contrib/**"
- "test/test_browser_utils.py"
- "test/test_retrieve_utils.py"
- ".github/workflows/contrib-tests.yml"
- ".github/workflows/contrib-test.yml"
- "pyproject.toml"

concurrency:
Expand Down Expand Up @@ -249,7 +249,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.13"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
Expand Down
10 changes: 9 additions & 1 deletion autogen/tools/dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,19 @@ def _string_metadata_to_description_field(func: Callable[..., Any]) -> Callable[
type_hints = get_type_hints(func, include_extras=True)

for _, annotation in type_hints.items():
# Check if the annotation itself has metadata (using __metadata__)
if hasattr(annotation, "__metadata__"):
metadata = annotation.__metadata__
if metadata and isinstance(metadata[0], str):
# Replace string metadata with DescriptionField
# Replace string metadata with Field
annotation.__metadata__ = (Field(description=metadata[0]),)
# For Python < 3.11, annotations like `Optional` are stored as `Union`, so metadata
# would be in the first element of __args__ (e.g., `__args__[0]` for `int` in `Optional[int]`)
elif hasattr(annotation, "__args__") and hasattr(annotation.__args__[0], "__metadata__"):
metadata = annotation.__args__[0].__metadata__
if metadata and isinstance(metadata[0], str):
# Replace string metadata with Field
annotation.__args__[0].__metadata__ = (Field(description=metadata[0]),)
return func


Expand Down
16 changes: 7 additions & 9 deletions autogen/tools/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,15 @@ def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) ->
"""

def type2description(k: str, v: Union[Annotated[type[Any], str], type[Any]]) -> str:
if not hasattr(v, "__metadata__"):
return k

# handles Annotated
if hasattr(v, "__metadata__"):
retval = v.__metadata__[0]
if isinstance(retval, AG2Field):
return retval.description # type: ignore[return-value]
else:
raise ValueError(
f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}"
)
retval = v.__metadata__[0]
if isinstance(retval, AG2Field):
return retval.description # type: ignore[return-value]
else:
return k
raise ValueError(f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}")

schema = type2schema(v)
if k in default_values:
Expand Down
24 changes: 22 additions & 2 deletions test/agentchat/test_dependancy_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated, Any, Callable
from typing import Annotated, Any, Callable, Optional
from unittest.mock import MagicMock

import pytest
Expand All @@ -23,6 +23,7 @@ def f_with_annotated(
ctx: Annotated[MyContext, Depends(MyContext(b=2))],
chat_ctx: Annotated[ChatContext, "Chat context"],
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
assert isinstance(chat_ctx, ChatContext)
return a + ctx.b + c
Expand All @@ -33,6 +34,7 @@ async def f_with_annotated_async(
ctx: Annotated[MyContext, Depends(MyContext(b=2))],
chat_ctx: ChatContext,
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
assert isinstance(chat_ctx, ChatContext)
return a + ctx.b + c
Expand All @@ -43,6 +45,7 @@ def f_without_annotated(
chat_ctx: ChatContext,
ctx: MyContext = Depends(MyContext(b=3)),
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx.b + c

Expand All @@ -51,6 +54,7 @@ async def f_without_annotated_async(
a: int,
ctx: MyContext = Depends(MyContext(b=3)),
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx.b + c

Expand All @@ -59,6 +63,7 @@ def f_with_annotated_and_depends(
a: int,
ctx: MyContext = MyContext(b=4),
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx.b + c

Expand All @@ -67,6 +72,7 @@ async def f_with_annotated_and_depends_async(
a: int,
ctx: MyContext = MyContext(b=4),
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx.b + c

Expand All @@ -76,6 +82,7 @@ def f_with_multiple_depends(
ctx: Annotated[MyContext, Depends(MyContext(b=2))],
ctx2: Annotated[MyContext, Depends(MyContext(b=3))],
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx.b + ctx2.b + c

Expand All @@ -85,6 +92,7 @@ async def f_with_multiple_depends_async(
ctx: Annotated[MyContext, Depends(MyContext(b=2))],
ctx2: Annotated[MyContext, Depends(MyContext(b=3))],
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx.b + ctx2.b + c

Expand All @@ -93,6 +101,7 @@ def f_wihout_base_context(
a: int,
ctx: Annotated[int, Depends(lambda a: a + 2)],
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx + c

Expand All @@ -101,6 +110,7 @@ async def f_wihout_base_context_async(
a: int,
ctx: Annotated[int, Depends(lambda a: a + 2)],
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx + c

Expand All @@ -109,6 +119,7 @@ def f_with_default_depends(
a: int,
ctx: int = Depends(lambda a: a + 2),
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx + c

Expand All @@ -117,6 +128,7 @@ async def f_with_default_depends_async(
a: int,
ctx: int = Depends(lambda a: a + 2),
c: Annotated[int, "c description"] = 3,
d: Annotated[Optional[int], "d description"] = None,
) -> int:
return a + ctx + c

Expand All @@ -135,6 +147,11 @@ def expected_tools(self) -> list[dict[str, Any]]:
"properties": {
"a": {"type": "integer", "description": "a"},
"c": {"type": "integer", "description": "c description", "default": 3},
"d": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"description": "d description",
"default": None,
},
},
"required": ["a"],
},
Expand Down Expand Up @@ -224,7 +241,10 @@ async def login(

@user_proxy.register_for_execution()
@agent.register_for_llm(description="Login function")
def login(user: Annotated[UserContext, Depends(user)]) -> str:
def login(
user: Annotated[UserContext, Depends(user)],
additional_notes: Annotated[Optional[str], "Additional notes"] = None,
) -> str:
return _login(user)

user_proxy.initiate_chat(agent, message="Please login", max_turns=2)
Expand Down
33 changes: 15 additions & 18 deletions test/tools/test_dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Annotated, Callable, get_type_hints
import sys
from typing import Annotated, Callable, Optional, get_type_hints

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -189,28 +190,24 @@ def test_remove_injected_params_from_signature(self, test_func: Callable[..., in


def test_string_metadata_to_description_field() -> None:
def f(a: int, b: Annotated[int, "b description"]) -> int:
def f(
a: int,
b: Annotated[int, "b description"],
c: Annotated[Optional[int], "c description"] = None,
) -> int:
return a + b

type_hints = get_type_hints(f, include_extras=True)

params_with_string_metadata = []
for param, annotation in type_hints.items():
if hasattr(annotation, "__metadata__"):
metadata = annotation.__metadata__
if metadata and isinstance(metadata[0], str):
params_with_string_metadata.append(param)

assert params_with_string_metadata == ["b"]

f = _string_metadata_to_description_field(f)
type_hints = get_type_hints(f, include_extras=True)
for param, annotation in type_hints.items():
if hasattr(annotation, "__metadata__"):
metadata = annotation.__metadata__
if metadata and isinstance(metadata[0], str):
raise AssertionError("The string metadata should have been replaced with Pydantic's Field")

field_info = type_hints["b"].__metadata__[0]
assert isinstance(field_info, Field)
assert field_info.description == "b description"

if sys.version_info < (3, 11):
field_info = type_hints["c"].__args__[0].__metadata__[0]
else:
field_info = type_hints["c"].__metadata__[0]

assert isinstance(field_info, Field)
assert field_info.description == "c description"
7 changes: 7 additions & 0 deletions test/tools/test_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def test_get_parameter_json_schema() -> None:
"description": "parameter a",
"default": "3.14",
}
assert get_parameter_json_schema(
"d", Annotated[Optional[str], AG2Field(description="parameter d")], {"d": None}
) == {
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"description": "parameter d",
}

class B(BaseModel):
b: float
Expand Down

0 comments on commit d3cc09a

Please sign in to comment.