
Commit

FLAML moved to optional dependencies (#598)
* flaml moved to optional dependencies

* tests fixed
davorrunje authored Jan 21, 2025
1 parent 6db520f commit a5cd8f6
Showing 6 changed files with 71 additions and 84 deletions.
56 changes: 33 additions & 23 deletions autogen/oai/completion.py
@@ -20,8 +20,13 @@
 null_handler = logging.NullHandler()
 flaml_logger.addHandler(null_handler)
 
-from flaml import BlendSearch, tune
-from flaml.tune.space import is_constant
+from ..import_utils import optional_import_block, require_optional_import
+
+with optional_import_block() as result:
+    from flaml import BlendSearch, tune
+    from flaml.tune.space import is_constant
+
+FLAML_INSTALLED = result.is_successful
 
 # Restore logging by removing the NullHandler
 flaml_logger.removeHandler(null_handler)
@@ -111,26 +116,30 @@ class Completion(OpenAICompletion):
         "gpt-4-32k-0613": (0.06, 0.12),
     }
 
-    default_search_space = {
-        "model": tune.choice(
-            [
-                "text-ada-001",
-                "text-babbage-001",
-                "text-davinci-003",
-                "gpt-3.5-turbo",
-                "gpt-4",
-            ]
-        ),
-        "temperature_or_top_p": tune.choice(
-            [
-                {"temperature": tune.uniform(0, 2)},
-                {"top_p": tune.uniform(0, 1)},
-            ]
-        ),
-        "max_tokens": tune.lograndint(50, 1000),
-        "n": tune.randint(1, 100),
-        "prompt": "{prompt}",
-    }
+    default_search_space = (
+        {
+            "model": tune.choice(
+                [
+                    "text-ada-001",
+                    "text-babbage-001",
+                    "text-davinci-003",
+                    "gpt-3.5-turbo",
+                    "gpt-4",
+                ]
+            ),
+            "temperature_or_top_p": tune.choice(
+                [
+                    {"temperature": tune.uniform(0, 2)},
+                    {"top_p": tune.uniform(0, 1)},
+                ]
+            ),
+            "max_tokens": tune.lograndint(50, 1000),
+            "n": tune.randint(1, 100),
+            "prompt": "{prompt}",
+        }
+        if FLAML_INSTALLED
+        else {}
+    )
 
     cache_seed = 41
     cache_path = f".cache/{cache_seed}"
@@ -525,6 +534,7 @@ def _eval(cls, config: dict, prune=True, eval_only=False):
         return result
 
     @classmethod
+    @require_optional_import("flaml", "flaml")
     def tune(
         cls,
         data: list[dict],
@@ -1213,5 +1223,5 @@ class ChatCompletion(Completion):
     """`(openai<1)` A class for OpenAI API ChatCompletion. Share the same API as Completion."""
 
     default_search_space = Completion.default_search_space.copy()
-    default_search_space["model"] = tune.choice(["gpt-3.5-turbo", "gpt-4"])
+    default_search_space["model"] = tune.choice(["gpt-3.5-turbo", "gpt-4"]) if FLAML_INSTALLED else {}
     openai_completion_class = not ERROR and openai.ChatCompletion
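
For orientation, a minimal sketch of how the optional-import pattern introduced in this file is consumed. Only the calls visible in the diff (optional_import_block, result.is_successful, require_optional_import) come from the change itself; the helper function and the printed messages below are hypothetical.

from autogen.import_utils import optional_import_block, require_optional_import

with optional_import_block() as result:
    from flaml import tune  # import is attempted; a failure is recorded instead of raised

FLAML_INSTALLED = result.is_successful  # True only if the import above succeeded


@require_optional_import("flaml", "flaml")  # raises an actionable error at call time if flaml is missing
def build_default_search_space() -> dict:  # hypothetical helper, not part of the commit
    # Safe to reference `tune` here: the decorator guarantees flaml is importable.
    return {"max_tokens": tune.lograndint(50, 1000), "n": tune.randint(1, 100)}


if FLAML_INSTALLED:
    print(build_default_search_space())
else:
    print("flaml is not installed; tuning-related defaults fall back to {}")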
12 changes: 8 additions & 4 deletions pyproject.toml
@@ -55,12 +55,9 @@ dependencies = [
     "openai>=1.58",
     "diskcache",
     "termcolor",
-    "flaml",
-    # numpy is installed by flaml, but we want to pin the version to below 2.x (see https://github.com/microsoft/autogen/issues/1960)
-    "numpy>=2.1; python_version>='3.13'",  # numpy 2.1+ required for Python 3.13
-    "numpy>=1.24.0,<2.0.0; python_version<'3.13'",  # numpy 1.24+ for older Python versions
     "python-dotenv",
     "tiktoken",
+    "numpy",
     # Disallowing 2.6.0 can be removed when this is fixed https://github.com/pydantic/pydantic/issues/8705
     "pydantic>=2.6.1,<3",
     "docker",
@@ -72,6 +69,13 @@ dependencies = [
 
 [project.optional-dependencies]
 
+flaml = [
+    "flaml",
+    # numpy is installed by flaml, but we want to pin the version to below 2.x (see https://github.com/microsoft/autogen/issues/1960)
+    "numpy>=2.1; python_version>='3.13'",  # numpy 2.1+ required for Python 3.13
+    "numpy>=1.24.0,<2.0.0; python_version<'3.13'",  # numpy 1.24+ for older Python versions
+]
+
 # public distributions
 jupyter-executor = [
     "jupyter-kernel-gateway",
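
Because flaml is no longer a core dependency, callers that need tuning now have to install the new flaml extra and handle its absence at runtime. A hedged sketch of the kind of guard downstream code might use; the extra name mirrors the pyproject change above, while the distribution name to install is a placeholder that depends on how this project is published.

# Hypothetical guard: fail with an actionable message when the optional extra is missing.
try:
    import flaml  # noqa: F401
except ImportError as exc:
    raise ImportError(
        "flaml is now an optional dependency; install it via the 'flaml' extra, "
        "e.g. pip install '<package-name>[flaml]'"
    ) from exc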
24 changes: 6 additions & 18 deletions test/interop/langchain/test_langchain.py
@@ -2,24 +2,24 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import sys
 from unittest.mock import MagicMock
 
 import pytest
-from langchain.tools import tool as langchain_tool
 from pydantic import BaseModel, Field
 
 from autogen import AssistantAgent, UserProxyAgent
+from autogen.import_utils import optional_import_block, skip_on_missing_imports
 from autogen.interop import Interoperable
 from autogen.interop.langchain import LangChainInteroperability
 
 from ...conftest import Credentials
 
+with optional_import_block():
+    from langchain.tools import tool as langchain_tool
+
 
-# skip if python version is not >= 3.9
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-)
+@skip_on_missing_imports("langchain", "interop-langchain")
 class TestLangChainInteroperability:
     @pytest.fixture(autouse=True)
     def setup(self) -> None:
@@ -76,10 +76,7 @@ def test_get_unsupported_reason(self) -> None:
         assert LangChainInteroperability.get_unsupported_reason() is None
 
 
-# skip if python version is not >= 3.9
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-)
+@skip_on_missing_imports("langchain", "interop-langchain")
 class TestLangChainInteroperabilityWithoutPydanticInput:
     @pytest.fixture(autouse=True)
     def setup(self) -> None:
@@ -129,12 +126,3 @@ def test_with_llm(self, credentials_gpt_4o: Credentials) -> None:
         user_proxy.initiate_chat(recipient=chatbot, message="search for LangChain, Use max 100 characters", max_turns=5)
 
         self.mock.assert_called()
-
-
-@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported")
-class TestLangChainInteroperabilityIfNotSupported:
-    def test_get_unsupported_reason(self) -> None:
-        assert (
-            LangChainInteroperability.get_unsupported_reason()
-            == "This submodule is only supported for Python versions 3.9 and above"
-        )
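
The remaining test files below follow the same two-part pattern: optional imports are wrapped in optional_import_block(), and test classes or methods are gated with skip_on_missing_imports. A condensed, hypothetical example of that shape (the class and test names are illustrative only):

from autogen.import_utils import optional_import_block, skip_on_missing_imports

with optional_import_block():
    from langchain.tools import tool as langchain_tool  # only bound if langchain is installed


@skip_on_missing_imports("langchain", "interop-langchain")
class TestNeedsLangChain:
    def test_tool_is_importable(self) -> None:
        # The class is skipped entirely when the 'interop-langchain' extra is missing,
        # so referencing langchain_tool here is safe.
        assert callable(langchain_tool)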
30 changes: 8 additions & 22 deletions test/interop/pydantic_ai/test_pydantic_ai.py
@@ -3,26 +3,25 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import random
-import sys
 from inspect import signature
 from typing import Any, Optional
 
 import pytest
 from pydantic import BaseModel
-from pydantic_ai import RunContext
-from pydantic_ai.tools import Tool as PydanticAITool
 
 from autogen import AssistantAgent, UserProxyAgent
+from autogen.import_utils import optional_import_block, skip_on_missing_imports
 from autogen.interop import Interoperable
 from autogen.interop.pydantic_ai import PydanticAIInteroperability
 
 from ...conftest import Credentials
 
+with optional_import_block():
+    from pydantic_ai import RunContext
+    from pydantic_ai.tools import Tool as PydanticAITool
+
 
-# skip if python version is not >= 3.9
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-)
+@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
 class TestPydanticAIInteroperabilityWithotContext:
     @pytest.fixture(autouse=True)
     def setup(self) -> None:
@@ -66,9 +65,7 @@ def test_with_llm(self, credentials_gpt_4o: Credentials) -> None:
         assert False, "No tool response found in chat messages"
 
 
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-)
+@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
 class TestPydanticAIInteroperabilityDependencyInjection:
     def test_dependency_injection(self) -> None:
         def f(
@@ -127,9 +124,7 @@ def f(
         assert pydantic_ai_tool.current_retry == 3
 
 
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-)
+@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
 class TestPydanticAIInteroperabilityWithContext:
     @pytest.fixture(autouse=True)
     def setup(self) -> None:
@@ -210,12 +205,3 @@ def test_with_llm(self, credentials_gpt_4o: Credentials) -> None:
                 return
 
         assert False, "No tool response found in chat messages"
-
-
-@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported")
-class TestPydanticAIInteroperabilityIfNotSupported:
-    def test_get_unsupported_reason(self) -> None:
-        assert (
-            PydanticAIInteroperability.get_unsupported_reason()
-            == "This submodule is only supported for Python versions 3.9 and above"
-        )
14 changes: 5 additions & 9 deletions test/interop/pydantic_ai/test_pydantic_ai_tool.py
@@ -2,19 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import sys
-
 import pytest
-from pydantic_ai.tools import Tool as PydanticAITool
 
 from autogen import AssistantAgent
+from autogen.import_utils import optional_import_block, skip_on_missing_imports
 from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool
 
+with optional_import_block():
+    from pydantic_ai.tools import Tool as PydanticAITool
+
 
-# skip if python version is not >= 3.9
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-)
+@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
 class TestPydanticAITool:
     def test_register_for_llm(self) -> None:
         def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:  # type: ignore[misc]
19 changes: 11 additions & 8 deletions test/interop/test_interoperability.py
@@ -7,12 +7,20 @@
 
 import pytest
 
+from autogen.import_utils import optional_import_block, skip_on_missing_imports
 from autogen.interop import Interoperability
 
 from ..conftest import MOCK_OPEN_AI_API_KEY
 
+with optional_import_block():
+    from crewai_tools import FileReadTool
+
+with optional_import_block():
+    pass  # type: ignore[import]
+
 
 class TestInteroperability:
+    @skip_on_missing_imports(["crewai_tools", "langchain", "pydantic_ai"], "interop")
     def test_supported_types(self) -> None:
         actual = Interoperability.get_supported_types()
 
@@ -28,12 +36,9 @@ def test_supported_types(self) -> None:
         if sys.version_info >= (3, 13):
             assert actual == ["langchain", "pydanticai"]
 
-    @pytest.mark.skipif(
-        sys.version_info < (3, 10) or sys.version_info >= (3, 13), reason="Only Python 3.10, 3.11, 3.12 are supported"
-    )
+    @skip_on_missing_imports("crewai_tools", "interop-crewai")
     def test_crewai(self, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
-        from crewai_tools import FileReadTool
 
         crewai_tool = FileReadTool()
 
@@ -56,9 +61,7 @@ def test_crewai(self, monkeypatch: pytest.MonkeyPatch) -> None:
 
         assert tool.func(args=args) == "Hello, World!"
 
-    @pytest.mark.skipif(
-        sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
-    )
-    @pytest.mark.skip(reason="This test is not yet implemented")
+    @pytest.mark.skip("This test is not yet implemented")
+    @skip_on_missing_imports("langchain", "interop-langchain")
     def test_langchain(self) -> None:
         raise NotImplementedError("This test is not yet implemented")
