diff --git a/autogen/agentchat/assistant_agent.py b/autogen/agentchat/assistant_agent.py
index 963cd88eb..dade51c76 100644
--- a/autogen/agentchat/assistant_agent.py
+++ b/autogen/agentchat/assistant_agent.py
@@ -6,8 +6,7 @@
 # SPDX-License-Identifier: MIT
 from typing import Callable, Literal, Optional, Union
 
-from autogen.runtime_logging import log_new_agent, logging_enabled
-
+from ..runtime_logging import log_new_agent, logging_enabled
 from .conversable_agent import ConversableAgent
diff --git a/autogen/agentchat/contrib/capabilities/generate_images.py b/autogen/agentchat/contrib/capabilities/generate_images.py
index 6c9a3a398..67429f88e 100644
--- a/autogen/agentchat/contrib/capabilities/generate_images.py
+++ b/autogen/agentchat/contrib/capabilities/generate_images.py
@@ -7,14 +7,17 @@
 import re
 from typing import Any, Literal, Optional, Protocol, Union
 
-from PIL.Image import Image
 from openai import OpenAI
 
-from autogen import Agent, ConversableAgent, code_utils
-from autogen.agentchat.contrib import img_utils
-from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
-from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
-from autogen.cache import AbstractCache
+from .... import Agent, ConversableAgent, code_utils
+from ....cache import AbstractCache
+from ....import_utils import optional_import_block, require_optional_import
+from .. import img_utils
+from ..capabilities.agent_capability import AgentCapability
+from ..text_analyzer_agent import TextAnalyzerAgent
+
+with optional_import_block():
+    from PIL.Image import Image
 
 SYSTEM_MESSAGE = "You've been given the special ability to generate images."
 DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
@@ -34,7 +37,7 @@ class ImageGenerator(Protocol):
     NOTE: Current implementation does not allow you to edit a previously existing image.
     """
 
-    def generate_image(self, prompt: str) -> Image:
+    def generate_image(self, prompt: str) -> "Image":
         """Generates an image based on the provided prompt.
 
         Args:
@@ -62,6 +65,7 @@ def cache_key(self, prompt: str) -> str:
         ...
 
 
+@require_optional_import("PIL", "unknown")
 class DalleImageGenerator:
     """Generates images using OpenAI's DALL-E models.
 
@@ -94,7 +98,7 @@ def __init__(
         self._num_images = num_images
         self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])
 
-    def generate_image(self, prompt: str) -> Image:
+    def generate_image(self, prompt: str) -> "Image":
         response = self._dalle_client.images.generate(
             model=self._model,
             prompt=prompt,
@@ -114,6 +118,7 @@ def cache_key(self, prompt: str) -> str:
         return ",".join([str(k) for k in keys])
 
 
+@require_optional_import("PIL", "unknown")
 class ImageGeneration(AgentCapability):
     """This capability allows a ConversableAgent to generate images based on the message received from other Agents.
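Aside: the hunks above introduce the pattern this whole PR applies. Hard imports of optional dependencies move inside `optional_import_block()` (which swallows the ImportError), the objects that need them are guarded with `require_optional_import`, and annotations become strings so the module still imports cleanly when the extra is absent. A minimal sketch of the pattern, not repository code; the function `blur` is hypothetical, while the two helpers are the real ones from autogen/import_utils.py shown later in this diff:

from autogen.import_utils import optional_import_block, require_optional_import

with optional_import_block():
    # Any ImportError raised here is swallowed, so importing this module
    # never fails just because Pillow is missing.
    from PIL import Image, ImageFilter


@require_optional_import("PIL", "unknown")  # same dep target the diff uses for PIL
def blur(image: "Image.Image") -> "Image.Image":
    # The quoted annotation is deliberate: `Image` may not exist at import
    # time, and a string annotation is never evaluated on module import.
    return image.filter(ImageFilter.GaussianBlur(radius=2))

This is why the signatures in these hunks change from `-> Image:` to `-> "Image":` at the same time the PIL import moves under the guard.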
@@ -253,7 +258,7 @@ def _extract_prompt(self, last_message) -> str:
         analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
         return self._extract_analysis(analysis)
 
-    def _cache_get(self, prompt: str) -> Optional[Image]:
+    def _cache_get(self, prompt: str) -> Optional["Image"]:
         if self._cache:
             key = self._image_generator.cache_key(prompt)
             cached_value = self._cache.get(key)
@@ -261,7 +266,7 @@ def _cache_get(self, prompt: str) -> Optional[Image]:
             if cached_value:
                 return img_utils.get_pil_image(cached_value)
 
-    def _cache_set(self, prompt: str, image: Image):
+    def _cache_set(self, prompt: str, image: "Image"):
         if self._cache:
             key = self._image_generator.cache_key(prompt)
             self._cache.set(key, img_utils.pil_to_data_uri(image))
@@ -272,7 +277,7 @@ def _extract_analysis(self, analysis: Union[str, dict, None]) -> str:
         else:
             return code_utils.content_str(analysis)
 
-    def _generate_content_message(self, prompt: str, image: Image) -> dict[str, Any]:
+    def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
         return {
             "content": [
                 {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
diff --git a/autogen/agentchat/contrib/img_utils.py b/autogen/agentchat/contrib/img_utils.py
index 7dcd52f53..e1ca180b8 100644
--- a/autogen/agentchat/contrib/img_utils.py
+++ b/autogen/agentchat/contrib/img_utils.py
@@ -13,7 +13,11 @@
 from typing import Union
 
 import requests
-from PIL import Image
+
+from ...import_utils import optional_import_block, require_optional_import
+
+with optional_import_block():
+    from PIL import Image
 
 from autogen.agentchat import utils
 
@@ -37,7 +41,8 @@
 }
 
 
-def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
+@require_optional_import("PIL", "unknown")
+def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
     """Loads an image from a file and returns a PIL Image object.
 
     Parameters:
@@ -75,7 +80,8 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
     return image.convert("RGB")
 
 
-def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
+@require_optional_import("PIL", "unknown")
+def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> bytes:
     """Loads an image and returns its data either as raw bytes or in base64-encoded format.
 
     This function first loads an image from the specified file, URL, or base64 string using
@@ -105,6 +111,7 @@ def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
         return content
 
 
+@require_optional_import("PIL", "unknown")
 def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
     """Formats the input prompt by replacing image tags and returns the new prompt along with image locations.
 
@@ -149,7 +156,8 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str,
     return new_prompt, images
 
 
-def pil_to_data_uri(image: Image.Image) -> str:
+@require_optional_import("PIL", "unknown")
+def pil_to_data_uri(image: "Image.Image") -> str:
     """Converts a PIL Image object to a data URI.
 
     Parameters:
@@ -184,6 +192,7 @@ def _get_mime_type_from_data_uri(base64_image):
     return data_uri
 
 
+@require_optional_import("PIL", "unknown")
 def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict]]:
     """Formats the input prompt by replacing image tags and returns a list of text and images.
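Aside: an illustrative sketch of what the guard means for callers of the img_utils functions above; this is not repository code, and the exact exception type and message raised by `require_optional_import` are an assumption not shown in this diff:

from autogen.agentchat.contrib.img_utils import get_pil_image

# Importing img_utils now always succeeds; only *calling* a guarded function
# without Pillow installed fails.
try:
    image = get_pil_image("https://example.com/cat.png")
except ImportError as e:  # assumed failure mode, raised with an install hint
    print(f"Optional dependency missing: {e}")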
@@ -251,7 +260,8 @@ def extract_img_paths(paragraph: str) -> list:
     return img_paths
 
 
-def _to_pil(data: str) -> Image.Image:
+@require_optional_import("PIL", "unknown")
+def _to_pil(data: str) -> "Image.Image":
     """Converts a base64 encoded image data string to a PIL Image object.
 
     This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
@@ -266,6 +276,7 @@ def _to_pil(data: str) -> Image.Image:
     return Image.open(BytesIO(base64.b64decode(data)))
 
 
+@require_optional_import("PIL", "unknown")
 def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
     """Converts the PIL image URLs in the messages to base64 encoded data URIs.
 
@@ -321,8 +332,9 @@ def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
     return new_messages
 
 
+@require_optional_import("PIL", "unknown")
 def num_tokens_from_gpt_image(
-    image_data: Union[str, Image.Image], model: str = "gpt-4-vision", low_quality: bool = False
+    image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
 ) -> int:
     """Calculate the number of tokens required to process an image based on its dimensions
     after scaling for different GPT models. Supports "gpt-4-vision", "gpt-4o", and "gpt-4o-mini".
diff --git a/autogen/agentchat/realtime_agent/audio_adapters/twilio_audio_adapter.py b/autogen/agentchat/realtime_agent/audio_adapters/twilio_audio_adapter.py
index 69add567f..1ae1eac05 100644
--- a/autogen/agentchat/realtime_agent/audio_adapters/twilio_audio_adapter.py
+++ b/autogen/agentchat/realtime_agent/audio_adapters/twilio_audio_adapter.py
@@ -11,7 +11,7 @@
 from ..realtime_observer import RealtimeObserver
 
 if TYPE_CHECKING:
-    from fastapi.websockets import WebSocket
+    from ..websockets import WebSocketProtocol as WebSocket
 
 
 LOG_EVENT_TYPES = [
@@ -142,5 +142,5 @@ async def initialize_session(self) -> None:
 
 if TYPE_CHECKING:
 
-    def twilio_audio_adapter(websocket: WebSocket) -> RealtimeObserver:
+    def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
         return TwilioAudioAdapter(websocket)
diff --git a/autogen/agentchat/realtime_agent/audio_adapters/websocket_audio_adapter.py b/autogen/agentchat/realtime_agent/audio_adapters/websocket_audio_adapter.py
index f537563ae..e56978b74 100644
--- a/autogen/agentchat/realtime_agent/audio_adapters/websocket_audio_adapter.py
+++ b/autogen/agentchat/realtime_agent/audio_adapters/websocket_audio_adapter.py
@@ -8,12 +8,10 @@
 from typing import TYPE_CHECKING, Optional
 
 from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
+from ..realtime_observer import RealtimeObserver
 
 if TYPE_CHECKING:
-    from fastapi.websockets import WebSocket
-
-
-from ..realtime_observer import RealtimeObserver
+    from ..websockets import WebSocketProtocol as WebSocket
 
 LOG_EVENT_TYPES = [
     "error",
@@ -135,5 +133,5 @@ async def run_loop(self) -> None:
 
 if TYPE_CHECKING:
 
-    def websocket_audio_adapter(websocket: WebSocket) -> RealtimeObserver:
+    def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
         return WebSocketAudioAdapter(websocket)
diff --git a/autogen/agentchat/realtime_agent/clients/oai/rtc_client.py b/autogen/agentchat/realtime_agent/clients/oai/rtc_client.py
index e4938198a..f7edf7e1c 100644
--- a/autogen/agentchat/realtime_agent/clients/oai/rtc_client.py
+++ b/autogen/agentchat/realtime_agent/clients/oai/rtc_client.py
@@ -15,8 +15,8 @@
 from .utils import parse_oai_message
 
 if TYPE_CHECKING:
+    from ...websockets import WebSocketProtocol as WebSocket
     from ..realtime_client import RealtimeClientProtocol
-    from ..websockets import WebSocketProtocol as WebSocket
 
 __all__ = ["OpenAIRealtimeWebRTCClient"]
diff --git a/autogen/agentchat/realtime_agent/clients/websockets.py b/autogen/agentchat/realtime_agent/websockets.py
similarity index 80%
rename from autogen/agentchat/realtime_agent/clients/websockets.py
rename to autogen/agentchat/realtime_agent/websockets.py
index 3198224c0..3353fb1cb 100644
--- a/autogen/agentchat/realtime_agent/clients/websockets.py
+++ b/autogen/agentchat/realtime_agent/websockets.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Protocol, runtime_checkable
+from typing import Any, AsyncIterator, Protocol, runtime_checkable
 
 __all__ = ["WebSocketProtocol"]
 
@@ -16,3 +16,5 @@ async def send_json(self, data: Any, mode: str = "text") -> None: ...
     async def receive_json(self, mode: str = "text") -> Any: ...
 
     async def receive_text(self) -> str: ...
+
+    def iter_text(self) -> AsyncIterator[str]: ...
diff --git a/autogen/cache/cache.py b/autogen/cache/cache.py
index a4206f246..63e2c7e4d 100644
--- a/autogen/cache/cache.py
+++ b/autogen/cache/cache.py
@@ -6,18 +6,12 @@
 # SPDX-License-Identifier: MIT
 from __future__ import annotations
 
-import sys
 from types import TracebackType
 from typing import Any
 
 from .abstract_cache_base import AbstractCache
 from .cache_factory import CacheFactory
 
-if sys.version_info >= (3, 11):
-    pass
-else:
-    pass
-
 
 class Cache(AbstractCache):
     """A wrapper class for managing cache configuration and instances.
diff --git a/autogen/cache/cosmos_db_cache.py b/autogen/cache/cosmos_db_cache.py
index de558986e..64fe262bc 100644
--- a/autogen/cache/cosmos_db_cache.py
+++ b/autogen/cache/cosmos_db_cache.py
@@ -9,20 +9,24 @@
 import pickle
 from typing import Any, Optional, TypedDict, Union
 
-from azure.cosmos import CosmosClient, PartitionKey
-from azure.cosmos.exceptions import CosmosResourceNotFoundError
+from ..import_utils import optional_import_block, require_optional_import
+from .abstract_cache_base import AbstractCache
 
-from autogen.cache.abstract_cache_base import AbstractCache
+with optional_import_block():
+    from azure.cosmos import CosmosClient, PartitionKey
+    from azure.cosmos.exceptions import CosmosResourceNotFoundError
 
 
+@require_optional_import("azure", "cosmosdb")
 class CosmosDBConfig(TypedDict, total=False):
     connection_string: str
     database_id: str
     container_id: str
     cache_seed: Optional[Union[str, int]]
-    client: Optional[CosmosClient]
+    client: Optional["CosmosClient"]
 
 
+@require_optional_import("azure", "cosmosdb")
 class CosmosDBCache(AbstractCache):
     """Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.
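Aside: the audio adapters and the WebRTC client above now annotate against `WebSocketProtocol` (just extended with `iter_text`) instead of FastAPI's concrete `WebSocket`. Because the protocol is structural, any object with matching methods conforms, so FastAPI sockets and test fakes alike can be passed in. An illustrative sketch, not repository code; `FakeWebSocket` is hypothetical:

from typing import Any, AsyncIterator

from autogen.agentchat.realtime_agent.websockets import WebSocketProtocol


class FakeWebSocket:
    """In-memory stand-in for tests; note it never declares the protocol."""

    def __init__(self, incoming: list[str]) -> None:
        self._incoming = incoming

    async def send_json(self, data: Any, mode: str = "text") -> None: ...

    async def receive_json(self, mode: str = "text") -> Any:
        return {}

    async def receive_text(self) -> str:
        return self._incoming.pop(0)

    async def iter_text(self) -> AsyncIterator[str]:
        for text in self._incoming:
            yield text


ws: WebSocketProtocol = FakeWebSocket(["hello"])  # type-checks via structural match
assert isinstance(ws, WebSocketProtocol)  # holds if the protocol is @runtime_checkable, as its imports suggest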
@@ -75,7 +79,7 @@ def from_connection_string(cls, seed: Union[str, int], connection_string: str, d
         return cls(str(seed), config)
 
     @classmethod
-    def from_existing_client(cls, seed: Union[str, int], client: CosmosClient, database_id: str, container_id: str):
+    def from_existing_client(cls, seed: Union[str, int], client: "CosmosClient", database_id: str, container_id: str):
         config = {"client": client, "database_id": database_id, "container_id": container_id}
         return cls(str(seed), config)
 
diff --git a/autogen/cache/redis_cache.py b/autogen/cache/redis_cache.py
index 6453285e6..cc618cad0 100644
--- a/autogen/cache/redis_cache.py
+++ b/autogen/cache/redis_cache.py
@@ -9,16 +9,19 @@
 from types import TracebackType
 from typing import Any, Optional, Union
 
-import redis
-
-from .abstract_cache_base import AbstractCache
-
 if sys.version_info >= (3, 11):
     from typing import Self
 else:
     from typing_extensions import Self
 
+from ..import_utils import optional_import_block, require_optional_import
+from .abstract_cache_base import AbstractCache
+
+with optional_import_block():
+    import redis
+
 
+@require_optional_import("redis", "redis")
 class RedisCache(AbstractCache):
     """Implementation of AbstractCache using the Redis database.
diff --git a/autogen/code_utils.py b/autogen/code_utils.py
index b282234bd..07d0bc24b 100644
--- a/autogen/code_utils.py
+++ b/autogen/code_utils.py
@@ -20,8 +20,7 @@
 
 import docker
 
-from autogen import oai
-
+from . import oai
 from .types import UserMessageImageContentPart, UserMessageTextContentPart
 
 SENTINEL = object()
diff --git a/autogen/coding/jupyter/docker_jupyter_server.py b/autogen/coding/jupyter/docker_jupyter_server.py
index f52dd81d2..2191a1b22 100644
--- a/autogen/coding/jupyter/docker_jupyter_server.py
+++ b/autogen/coding/jupyter/docker_jupyter_server.py
@@ -17,14 +17,12 @@
 
 import docker
 
-from ..docker_commandline_code_executor import _wait_for_ready
-
 if sys.version_info >= (3, 11):
     from typing import Self
 else:
     from typing_extensions import Self
 
-
+from ..docker_commandline_code_executor import _wait_for_ready
 from .base import JupyterConnectable, JupyterConnectionInfo
 from .jupyter_client import JupyterClient
 
diff --git a/autogen/coding/jupyter/embedded_ipython_code_executor.py b/autogen/coding/jupyter/embedded_ipython_code_executor.py
index e5f4fc230..b02c3af78 100644
--- a/autogen/coding/jupyter/embedded_ipython_code_executor.py
+++ b/autogen/coding/jupyter/embedded_ipython_code_executor.py
@@ -13,16 +13,20 @@
 from queue import Empty
 from typing import Any
 
-from jupyter_client import KernelManager  # type: ignore[attr-defined]
-from jupyter_client.kernelspec import KernelSpecManager
 from pydantic import BaseModel, Field, field_validator
 
+from ...import_utils import optional_import_block, require_optional_import
 from ..base import CodeBlock, CodeExtractor, IPythonCodeResult
 from ..markdown_code_extractor import MarkdownCodeExtractor
 
+with optional_import_block():
+    from jupyter_client import KernelManager  # type: ignore[attr-defined]
+    from jupyter_client.kernelspec import KernelSpecManager
+
 __all__ = "EmbeddedIPythonCodeExecutor"
 
 
+@require_optional_import("jupyter_client", "jupyter-executor")
 class EmbeddedIPythonCodeExecutor(BaseModel):
     """(Experimental) A code executor class that executes code statefully using
     an embedded IPython kernel managed by this class.
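Aside: an illustrative usage sketch, not repository code, showing the intended effect of the class-level guard on `EmbeddedIPythonCodeExecutor` above: the module now imports without jupyter-client installed, and the missing dependency only surfaces when the class is used. The constructor argument and the exact exception type are assumptions:

# Succeeds even without jupyter-client, since its imports sit inside optional_import_block().
from autogen.coding.jupyter.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor

try:
    executor = EmbeddedIPythonCodeExecutor(timeout=60)
except ImportError:
    print("Install the extra first: pip install ag2[jupyter-executor]")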
diff --git a/autogen/coding/jupyter/jupyter_client.py b/autogen/coding/jupyter/jupyter_client.py
index 764e308fc..49b8f62dc 100644
--- a/autogen/coding/jupyter/jupyter_client.py
+++ b/autogen/coding/jupyter/jupyter_client.py
@@ -6,7 +6,10 @@
 # SPDX-License-Identifier: MIT
 from __future__ import annotations
 
+import datetime
+import json
 import sys
+import uuid
 from dataclasses import dataclass
 from types import TracebackType
 from typing import Any, cast
@@ -16,17 +19,17 @@
 else:
     from typing_extensions import Self
 
-import datetime
-import json
-import uuid
-
 import requests
-import websocket
 from requests.adapters import HTTPAdapter, Retry
-from websocket import WebSocket
 
+from ...import_utils import optional_import_block, require_optional_import
 from .base import JupyterConnectionInfo
 
+with optional_import_block():
+    import websocket
+    from websocket import WebSocket
+
 
 class JupyterClient:
     def __init__(self, connection_info: JupyterConnectionInfo):
@@ -90,12 +93,14 @@ def restart_kernel(self, kernel_id: str) -> None:
         )
         response.raise_for_status()
 
+    @require_optional_import("websocket", "jupyter-executor")
     def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
         ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
         ws = websocket.create_connection(ws_url, header=self._get_headers())
         return JupyterKernelClient(ws)
 
 
+@require_optional_import("websocket", "jupyter-executor")
 class JupyterKernelClient:
     """(Experimental) A client for communicating with a Jupyter kernel."""
 
@@ -110,7 +115,7 @@ class DataItem:
         output: str
         data_items: list[DataItem]
 
-    def __init__(self, websocket: WebSocket):
+    def __init__(self, websocket: "WebSocket"):
         self._session_id: str = uuid.uuid4().hex
         self._websocket: WebSocket = websocket
 
diff --git a/autogen/coding/jupyter/jupyter_code_executor.py b/autogen/coding/jupyter/jupyter_code_executor.py
index a29778e81..6e0ef6aa8 100644
--- a/autogen/coding/jupyter/jupyter_code_executor.py
+++ b/autogen/coding/jupyter/jupyter_code_executor.py
@@ -13,16 +13,14 @@
 from types import TracebackType
 from typing import Optional, Union
 
-from autogen.coding.utils import silence_pip
-
 if sys.version_info >= (3, 11):
     from typing import Self
 else:
     from typing_extensions import Self
 
-
 from ..base import CodeBlock, CodeExecutor, CodeExtractor, IPythonCodeResult
 from ..markdown_code_extractor import MarkdownCodeExtractor
+from ..utils import silence_pip
 from .base import JupyterConnectable, JupyterConnectionInfo
 from .jupyter_client import JupyterClient
 
diff --git a/autogen/coding/local_commandline_code_executor.py b/autogen/coding/local_commandline_code_executor.py
index 98419d308..f2a288fdd 100644
--- a/autogen/coding/local_commandline_code_executor.py
+++ b/autogen/coding/local_commandline_code_executor.py
@@ -18,15 +18,14 @@
 
 from typing_extensions import ParamSpec
 
-from autogen.coding.func_with_reqs import (
+from ..code_utils import PYTHON_VARIANTS, TIMEOUT_MSG, WIN32, _cmd
+from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
+from .func_with_reqs import (
     FunctionWithRequirements,
     FunctionWithRequirementsStr,
     _build_python_functions_file,
     to_stub,
 )
-
-from ..code_utils import PYTHON_VARIANTS, TIMEOUT_MSG, WIN32, _cmd
-from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
 from .markdown_code_extractor import MarkdownCodeExtractor
 from .utils import _get_file_name_from_content, silence_pip
 
diff --git a/autogen/graph_utils.py b/autogen/graph_utils.py
index fe77b0301..edf536ef3 100644
--- a/autogen/graph_utils.py
+++ b/autogen/graph_utils.py
@@ -7,7 +7,7 @@
 import logging
 from typing import Optional
 
-from autogen.agentchat import Agent
+from .agentchat import Agent
 
 
 def has_self_loops(allowed_speaker_transitions: dict) -> bool:
diff --git a/autogen/import_utils.py b/autogen/import_utils.py
index 0c179df09..8a526f64b 100644
--- a/autogen/import_utils.py
+++ b/autogen/import_utils.py
@@ -10,7 +10,7 @@
 from logging import getLogger
 from typing import Any, Callable, Generator, Iterable, Optional, Type, TypeVar, Union
 
-__all__ = ["optional_import_block", "require_optional_import"]
+__all__ = ["optional_import_block", "require_optional_import", "skip_on_missing_imports"]
 
 logger = getLogger(__name__)
 
@@ -262,3 +262,33 @@ def decorator(o: T) -> T:
         return patch_object(o, missing_modules=missing_modules, dep_target=dep_target)
 
     return decorator
+
+
+def skip_on_missing_imports(modules: Union[str, Iterable[str]], dep_target: str) -> Callable[[T], T]:
+    """Decorator to skip a test if an optional module is missing.
+
+    Args:
+        modules: Module name or an iterable of module names
+        dep_target: Target name for pip installation (e.g. 'test' in pip install ag2[test])
+    """
+
+    missing_modules = get_missing_imports(modules)
+
+    if not missing_modules:
+
+        def decorator(o: T) -> T:
+            return o
+    else:
+
+        def decorator(o: T) -> T:
+            import pytest
+
+            @pytest.mark.skip(
+                f"Missing module{'s' if len(missing_modules) > 1 else ''}: {', '.join(missing_modules)}. Install using 'pip install ag2[{dep_target}]'"
+            )
+            def _skip(*args, **kwargs):
+                pass
+
+            return _skip
+
+    return decorator
diff --git a/autogen/io/base.py b/autogen/io/base.py
index be49bc6e9..8b5a59404 100644
--- a/autogen/io/base.py
+++ b/autogen/io/base.py
@@ -10,7 +10,7 @@
 from contextvars import ContextVar
 from typing import Any, Optional, Protocol, runtime_checkable
 
-from autogen.messages.base_message import BaseMessage
+from ..messages.base_message import BaseMessage
 
 __all__ = ("IOStream", "InputStream", "OutputStream")
 
diff --git a/autogen/io/console.py b/autogen/io/console.py
index eb6e2989f..679a34009 100644
--- a/autogen/io/console.py
+++ b/autogen/io/console.py
@@ -7,9 +7,8 @@
 import getpass
 from typing import Any
 
-from autogen.messages.base_message import BaseMessage
-from autogen.messages.print_message import PrintMessage
-
+from ..messages.base_message import BaseMessage
+from ..messages.print_message import PrintMessage
 from .base import IOStream
 
 __all__ = ("IOConsole",)
diff --git a/autogen/logger/base_logger.py b/autogen/logger/base_logger.py
index 3943265bf..1dea88bd3 100644
--- a/autogen/logger/base_logger.py
+++ b/autogen/logger/base_logger.py
@@ -15,7 +15,7 @@
 from openai.types.chat import ChatCompletion
 
 if TYPE_CHECKING:
-    from autogen import Agent, ConversableAgent, OpenAIWrapper
+    from .. import Agent, ConversableAgent, OpenAIWrapper
 
 F = TypeVar("F", bound=Callable[..., Any])
 ConfigItem = dict[str, Union[str, list[str]]]
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
index 3b0433147..de351d57d 100644
--- a/autogen/logger/file_logger.py
+++ b/autogen/logger/file_logger.py
@@ -16,22 +16,20 @@
 from openai import AzureOpenAI, OpenAI
 from openai.types.chat import ChatCompletion
 
-from autogen.logger.base_logger import BaseLogger
-from autogen.logger.logger_utils import get_current_ts, to_dict
-
-from .base_logger import LLMConfig
+from .base_logger import BaseLogger, LLMConfig
+from .logger_utils import get_current_ts, to_dict
 
 if TYPE_CHECKING:
-    from autogen import Agent, ConversableAgent, OpenAIWrapper
-    from autogen.oai.anthropic import AnthropicClient
-    from autogen.oai.bedrock import BedrockClient
-    from autogen.oai.cerebras import CerebrasClient
-    from autogen.oai.cohere import CohereClient
-    from autogen.oai.gemini import GeminiClient
-    from autogen.oai.groq import GroqClient
-    from autogen.oai.mistral import MistralAIClient
-    from autogen.oai.ollama import OllamaClient
-    from autogen.oai.together import TogetherClient
+    from .. import Agent, ConversableAgent, OpenAIWrapper
+    from ..oai.anthropic import AnthropicClient
+    from ..oai.bedrock import BedrockClient
+    from ..oai.cerebras import CerebrasClient
+    from ..oai.cohere import CohereClient
+    from ..oai.gemini import GeminiClient
+    from ..oai.groq import GroqClient
+    from ..oai.mistral import MistralAIClient
+    from ..oai.ollama import OllamaClient
+    from ..oai.together import TogetherClient
 
 logger = logging.getLogger(__name__)
 
diff --git a/autogen/logger/logger_factory.py b/autogen/logger/logger_factory.py
index 667020310..4275c6578 100644
--- a/autogen/logger/logger_factory.py
+++ b/autogen/logger/logger_factory.py
@@ -6,9 +6,9 @@
 # SPDX-License-Identifier: MIT
 from typing import Any, Literal, Optional
 
-from autogen.logger.base_logger import BaseLogger
-from autogen.logger.file_logger import FileLogger
-from autogen.logger.sqlite_logger import SqliteLogger
+from .base_logger import BaseLogger
+from .file_logger import FileLogger
+from .sqlite_logger import SqliteLogger
 
 __all__ = ("LoggerFactory",)
 
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 5c4d0d00a..9c250ae74 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -17,22 +17,20 @@
 from openai import AzureOpenAI, OpenAI
 from openai.types.chat import ChatCompletion
 
-from autogen.logger.base_logger import BaseLogger
-from autogen.logger.logger_utils import get_current_ts, to_dict
-
-from .base_logger import LLMConfig
+from .base_logger import BaseLogger, LLMConfig
+from .logger_utils import get_current_ts, to_dict
 
 if TYPE_CHECKING:
-    from autogen import Agent, ConversableAgent, OpenAIWrapper
-    from autogen.oai.anthropic import AnthropicClient
-    from autogen.oai.bedrock import BedrockClient
-    from autogen.oai.cerebras import CerebrasClient
-    from autogen.oai.cohere import CohereClient
-    from autogen.oai.gemini import GeminiClient
-    from autogen.oai.groq import GroqClient
-    from autogen.oai.mistral import MistralAIClient
-    from autogen.oai.ollama import OllamaClient
-    from autogen.oai.together import TogetherClient
+    from .. import Agent, ConversableAgent, OpenAIWrapper
+    from ..oai.anthropic import AnthropicClient
+    from ..oai.bedrock import BedrockClient
+    from ..oai.cerebras import CerebrasClient
+    from ..oai.cohere import CohereClient
+    from ..oai.gemini import GeminiClient
+    from ..oai.groq import GroqClient
+    from ..oai.mistral import MistralAIClient
+    from ..oai.ollama import OllamaClient
+    from ..oai.together import TogetherClient
 
 logger = logging.getLogger(__name__)
 lock = threading.Lock()
diff --git a/autogen/math_utils.py b/autogen/math_utils.py
index 6b427f9f1..197dfd38d 100644
--- a/autogen/math_utils.py
+++ b/autogen/math_utils.py
@@ -6,7 +6,7 @@
 # SPDX-License-Identifier: MIT
 from typing import Optional
 
-from autogen import DEFAULT_MODEL, oai
+from . import DEFAULT_MODEL, oai
 
 _MATH_PROMPT = "{problem} Solve the problem carefully. Simplify your answer as much as possible. Put the final answer in \\boxed{{}}."
 _MATH_CONFIG = {
diff --git a/autogen/oai/__init__.py b/autogen/oai/__init__.py
index 68113589b..7e4d258b2 100644
--- a/autogen/oai/__init__.py
+++ b/autogen/oai/__init__.py
@@ -4,10 +4,10 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
-from autogen.cache.cache import Cache
-from autogen.oai.client import ModelClient, OpenAIWrapper
-from autogen.oai.completion import ChatCompletion, Completion
-from autogen.oai.openai_utils import (
+from ..cache.cache import Cache
+from .client import ModelClient, OpenAIWrapper
+from .completion import ChatCompletion, Completion
+from .openai_utils import (
     config_list_from_dotenv,
     config_list_from_json,
     config_list_from_models,
diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py
index b3ed79887..ec755c74d 100644
--- a/autogen/oai/anthropic.py
+++ b/autogen/oai/anthropic.py
@@ -78,19 +78,22 @@
 import warnings
 from typing import Any, Optional, Type
 
-from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
-from anthropic import __version__ as anthropic_version
-from anthropic.types import Message, TextBlock, ToolUseBlock
 from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel
 
-from autogen.oai.client_utils import FormatterProtocol, validate_parameter
+from ..import_utils import optional_import_block, require_optional_import
+from .client_utils import FormatterProtocol, validate_parameter
 
-TOOL_ENABLED = anthropic_version >= "0.23.1"
-if TOOL_ENABLED:
-    pass
+with optional_import_block():
+    from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
+    from anthropic import __version__ as anthropic_version
+    from anthropic.types import Message, TextBlock, ToolUseBlock
+
+    TOOL_ENABLED = anthropic_version >= "0.23.1"
+    if TOOL_ENABLED:
+        pass
 
 
 ANTHROPIC_PRICING_1k = {
@@ -106,6 +109,7 @@
 }
 
 
+@require_optional_import("anthropic", "anthropic")
 class AnthropicClient:
     def __init__(self, **kwargs: Any):
         """Initialize the Anthropic API client.
@@ -393,7 +397,7 @@ def _add_response_format_to_system(self, params: dict[str, Any]):
         # Add formatting to last user message
         params["system"] += "\n\n" + format_content
 
-    def _extract_json_response(self, response: Message) -> Any:
+    def _extract_json_response(self, response: "Message") -> Any:
         """Extract and validate JSON response from the output for structured outputs.
         Args:
@@ -435,6 +439,7 @@ def _format_json_response(response: Any) -> str:
     return response.format() if isinstance(response, FormatterProtocol) else response
 
 
+@require_optional_import("anthropic", "anthropic")
 def oai_messages_to_anthropic_messages(params: dict[str, Any]) -> list[dict[str, Any]]:
     """Convert messages from OAI format to Anthropic format.
 
     We correct for any specific role orders and types, etc.
diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py
index 3c5b994f5..d4d73f609 100644
--- a/autogen/oai/bedrock.py
+++ b/autogen/oai/bedrock.py
@@ -38,16 +38,20 @@
 import warnings
 from typing import Any, Literal
 
-import boto3
 import requests
-from botocore.config import Config
 from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
 from openai.types.completion_usage import CompletionUsage
 
-from autogen.oai.client_utils import validate_parameter
+from ..import_utils import optional_import_block, require_optional_import
+from .client_utils import validate_parameter
 
+with optional_import_block():
+    import boto3
+    from botocore.config import Config
 
+
+@require_optional_import("boto3", "bedrock")
 class BedrockClient:
     """Client for Amazon's Bedrock Converse API."""
 
diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py
index ffdfd88c4..baf049fc7 100644
--- a/autogen/oai/cerebras.py
+++ b/autogen/oai/cerebras.py
@@ -29,12 +29,15 @@
 import warnings
 from typing import Any
 
-from cerebras.cloud.sdk import Cerebras, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
 from openai.types.completion_usage import CompletionUsage
 
-from autogen.oai.client_utils import should_hide_tools, validate_parameter
+from ..import_utils import optional_import_block, require_optional_import
+from .client_utils import should_hide_tools, validate_parameter
+
+with optional_import_block():
+    from cerebras.cloud.sdk import Cerebras, Stream
 
 CEREBRAS_PRICING_1K = {
     # Convert pricing per million to per thousand tokens.
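Aside: note the design choice visible in the Cerebras hunk above (and repeated for Cohere, Groq, Mistral, Ollama, and Together below): the guard is applied to the `create()` method rather than to the whole class, so a client object can still be constructed and configured without the SDK installed; only the call that actually needs it is gated. An illustrative sketch, not repository code; `ExampleClient` is hypothetical:

from autogen.import_utils import require_optional_import


class ExampleClient:
    def __init__(self, api_key: str):
        self.api_key = api_key  # no optional dependency needed here

    @require_optional_import("cerebras", "cerebras")
    def create(self, params: dict):
        # Only this call path requires the SDK to be importable.
        from cerebras.cloud.sdk import Cerebras

        return Cerebras(api_key=self.api_key).chat.completions.create(**params)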
@@ -111,6 +114,7 @@ def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: return cerebras_params + @require_optional_import("cerebras", "cerebras") def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 8cb4adf18..34cebd063 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -37,13 +37,17 @@ import warnings from typing import Any -from cohere import Client as Cohere -from cohere.types import ToolParameterDefinitionsValue, ToolResult from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage -from autogen.oai.client_utils import logging_formatter, validate_parameter +from ..import_utils import optional_import_block, require_optional_import +from .client_utils import logging_formatter, validate_parameter + +with optional_import_block(): + from cohere import Client as Cohere + from cohere.types import ToolParameterDefinitionsValue, ToolResult + logger = logging.getLogger(__name__) if not logger.handlers: @@ -151,6 +155,7 @@ def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: return cohere_params + @require_optional_import("cohere", "cohere") def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) client_name = params.get("client_name") or "autogen-cohere" @@ -262,6 +267,7 @@ def create(self, params: dict) -> ChatCompletion: return response_oai +@require_optional_import("cohere", "cohere") def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> list[dict[str, Any]]: temp_tool_results = [] @@ -278,6 +284,7 @@ def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_t return temp_tool_results +@require_optional_import("cohere", "cohere") def oai_messages_to_cohere_messages( messages: list[dict[str, Any]], params: dict[str, Any], cohere_params: dict[str, Any] ) -> tuple[list[dict[str, Any]], str, str]: diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 453de92d2..9f03f0e65 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -52,40 +52,44 @@ from io import BytesIO from typing import Any, Optional, Type -import google.generativeai as genai import requests -import vertexai -from PIL import Image -from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool -from google.ai.generativelanguage_v1beta.types import Schema -from google.auth.credentials import Credentials -from google.generativeai.types import GenerateContentResponse -from jsonschema import ValidationError from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel -from vertexai.generative_models import ( - Content as VertexAIContent, -) -from vertexai.generative_models import FunctionDeclaration as vaiFunctionDeclaration -from vertexai.generative_models import ( - GenerationResponse as VertexAIGenerationResponse, -) -from vertexai.generative_models import GenerativeModel -from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold -from vertexai.generative_models import HarmCategory as VertexAIHarmCategory -from vertexai.generative_models import Part as VertexAIPart -from 
vertexai.generative_models import SafetySetting as VertexAISafetySetting -from vertexai.generative_models import ( - Tool as vaiTool, -) - -from autogen.oai.client_utils import FormatterProtocol + +from ..import_utils import optional_import_block, require_optional_import +from .client_utils import FormatterProtocol + +with optional_import_block(): + import google.generativeai as genai + import vertexai + from PIL import Image + from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool + from google.ai.generativelanguage_v1beta.types import Schema + from google.auth.credentials import Credentials + from google.generativeai.types import GenerateContentResponse + from jsonschema import ValidationError + from vertexai.generative_models import ( + Content as VertexAIContent, + ) + from vertexai.generative_models import FunctionDeclaration as vaiFunctionDeclaration + from vertexai.generative_models import ( + GenerationResponse as VertexAIGenerationResponse, + ) + from vertexai.generative_models import GenerativeModel + from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold + from vertexai.generative_models import HarmCategory as VertexAIHarmCategory + from vertexai.generative_models import Part as VertexAIPart + from vertexai.generative_models import SafetySetting as VertexAISafetySetting + from vertexai.generative_models import ( + Tool as vaiTool, + ) logger = logging.getLogger(__name__) +@require_optional_import(["google", "vertexai", "PIL", "jsonschema"], "gemini") class GeminiClient: """Client for Google's Gemini API.""" @@ -446,7 +450,7 @@ def _oai_content_to_gemini_content(self, message: dict[str, Any]) -> tuple[list, else: raise Exception("Unable to convert content to Gemini format.") - def _concat_parts(self, parts: list[Part]) -> list: + def _concat_parts(self, parts: list["Part"]) -> list: """Concatenate parts with the same type. If two adjacent parts both have the "text" attribute, then it will be joined into one part. 
""" @@ -563,7 +567,7 @@ def _convert_json_response(self, response: str) -> Any: f"Failed to parse response as valid JSON matching the schema for Structured Output: {str(e)}" ) - def _tools_to_gemini_tools(self, tools: list[dict[str, Any]]) -> list[Tool]: + def _tools_to_gemini_tools(self, tools: list[dict[str, Any]]) -> list["Tool"]: """Create Gemini tools (as typically requires Callables)""" functions = [] for tool in tools: @@ -583,7 +587,7 @@ def _tools_to_gemini_tools(self, tools: list[dict[str, Any]]) -> list[Tool]: return [Tool(function_declarations=functions)] @staticmethod - def _create_gemini_function_declaration(tool: dict) -> FunctionDeclaration: + def _create_gemini_function_declaration(tool: dict) -> "FunctionDeclaration": function_declaration = FunctionDeclaration() function_declaration.name = tool["function"]["name"] function_declaration.description = tool["function"]["description"] @@ -595,7 +599,7 @@ def _create_gemini_function_declaration(tool: dict) -> FunctionDeclaration: return function_declaration @staticmethod - def _create_gemini_function_declaration_schema(json_data) -> Schema: + def _create_gemini_function_declaration_schema(json_data) -> "Schema": """Recursively creates Schema objects for FunctionDeclaration.""" param_schema = Schema() param_type = json_data["type"] @@ -705,6 +709,7 @@ def _to_json_or_str(data: str) -> dict | str: return data +@require_optional_import(["PIL"], "gemini") def get_image_data(image_file: str, use_b64=True) -> bytes: if image_file.startswith("http://") or image_file.startswith("https://"): response = requests.get(image_file) diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py index 59d48b7e4..61310de22 100644 --- a/autogen/oai/groq.py +++ b/autogen/oai/groq.py @@ -29,12 +29,15 @@ import warnings from typing import Any -from groq import Groq, Stream from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage -from autogen.oai.client_utils import should_hide_tools, validate_parameter +from ..import_utils import optional_import_block, require_optional_import +from .client_utils import should_hide_tools, validate_parameter + +with optional_import_block(): + from groq import Groq, Stream # Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K) GROQ_PRICING_1K = { @@ -126,6 +129,7 @@ def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: return groq_params + @require_optional_import("groq", "groq") def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py index 24adbc11d..9299c8b38 100644 --- a/autogen/oai/mistral.py +++ b/autogen/oai/mistral.py @@ -31,25 +31,29 @@ import warnings from typing import Any, Union -# Mistral libraries -# pip install mistralai -from mistralai import ( - AssistantMessage, - Function, - FunctionCall, - Mistral, - SystemMessage, - ToolCall, - ToolMessage, - UserMessage, -) from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage -from autogen.oai.client_utils import should_hide_tools, validate_parameter - - +from ..import_utils import optional_import_block, require_optional_import +from .client_utils import should_hide_tools, validate_parameter + +with optional_import_block(): + # Mistral 
libraries + # pip install mistralai + from mistralai import ( + AssistantMessage, + Function, + FunctionCall, + Mistral, + SystemMessage, + ToolCall, + ToolMessage, + UserMessage, + ) + + +@require_optional_import("mistralai", "mistral") class MistralAIClient: """Client for Mistral.AI's API.""" @@ -80,6 +84,7 @@ def message_retrieval(self, response: ChatCompletion) -> Union[list[str], list[C def cost(self, response) -> float: return response.cost + @require_optional_import("mistralai", "mistral") def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" mistral_params = {} @@ -169,6 +174,7 @@ def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: return mistral_params + @require_optional_import("mistralai", "mistral") def create(self, params: dict[str, Any]) -> ChatCompletion: # 1. Parse parameters to Mistral.AI API's parameters mistral_params = self.parse_params(params) @@ -232,6 +238,7 @@ def get_usage(response: ChatCompletion) -> dict: } +@require_optional_import("mistralai", "mistral") def tool_def_to_mistral(tool_definitions: list[dict[str, Any]]) -> list[dict[str, Any]]: """Converts AutoGen tool definition to a mistral tool format""" mistral_tools = [] diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 2b1bcf8b5..b4aa3dc8b 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -30,15 +30,18 @@ import warnings from typing import Any, Optional, Type -import ollama -from fix_busted_json import repair_json -from ollama import Client from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel -from autogen.oai.client_utils import FormatterProtocol, should_hide_tools, validate_parameter +from ..import_utils import optional_import_block, require_optional_import +from .client_utils import FormatterProtocol, should_hide_tools, validate_parameter + +with optional_import_block(): + import ollama + from fix_busted_json import repair_json + from ollama import Client class OllamaClient: @@ -176,6 +179,7 @@ def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: return ollama_params + @require_optional_import(["ollama", "fix_busted_json"], "ollama") def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) @@ -497,6 +501,7 @@ def _format_json_response(response: Any, original_answer: str) -> str: return response.format() if isinstance(response, FormatterProtocol) else original_answer +@require_optional_import("fix_busted_json", "ollama") def response_to_tool_call(response_string: str) -> Any: """Attempts to convert the response to an object, aimed to align with function format `[{},{}]`""" # We try and detect the list[dict] format: diff --git a/autogen/oai/together.py b/autogen/oai/together.py index f66fa948c..caedaa4de 100644 --- a/autogen/oai/together.py +++ b/autogen/oai/together.py @@ -38,9 +38,12 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage -from together import Together -from autogen.oai.client_utils import should_hide_tools, validate_parameter +from ..import_utils import optional_import_block, require_optional_import +from 
.client_utils import should_hide_tools, validate_parameter + +with optional_import_block(): + from together import Together class TogetherClient: @@ -129,6 +132,7 @@ def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: return together_params + @require_optional_import("together", "together") def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 7fe4ac521..65c79e186 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -6,29 +6,30 @@ # SPDX-License-Identifier: MIT import glob import hashlib +import logging import os import re from typing import Callable, Union from urllib.parse import urlparse -import chromadb -import markdownify import requests -from bs4 import BeautifulSoup -if chromadb.__version__ < "0.4.15": - from chromadb.api import API -else: - from chromadb.api import ClientAPI as API # noqa: N814 -import logging +from .import_utils import optional_import_block, require_optional_import +from .token_count_utils import count_token -import chromadb.utils.embedding_functions as ef -import pypdf -from chromadb.api.types import QueryResult +with optional_import_block(): + import chromadb + import markdownify + from bs4 import BeautifulSoup -from autogen.token_count_utils import count_token + if chromadb.__version__ < "0.4.15": + from chromadb.api import API + else: + from chromadb.api import ClientAPI as API # noqa: N814 + import chromadb.utils.embedding_functions as ef + import pypdf + from chromadb.api.types import QueryResult -from .import_utils import optional_import_block with optional_import_block() as result: from unstructured.partition.auto import partition @@ -136,6 +137,7 @@ def split_text_to_chunks( return chunks +@require_optional_import("pypdf", "retrievechat") def extract_text_from_pdf(file: str) -> str: """Extract text from PDF files""" text = "" @@ -251,6 +253,7 @@ def get_files_from_dir(dir_path: Union[str, list[str]], types: list = TEXT_FORMA return files +@require_optional_import(["markdownify", "bs4"], "retrievechat") def parse_html_to_markdown(html: str, url: str = None) -> str: """Parse HTML to markdown.""" soup = BeautifulSoup(html, "html.parser") @@ -339,10 +342,11 @@ def is_url(string: str): return False +@require_optional_import("chromadb", "retrievechat") def create_vector_db_from_dir( dir_path: Union[str, list[str]], max_tokens: int = 4000, - client: API = None, + client: "API" = None, db_path: str = "tmp/chromadb.db", collection_name: str = "all-my-documents", get_or_create: bool = False, @@ -354,7 +358,7 @@ def create_vector_db_from_dir( custom_text_types: list[str] = TEXT_FORMATS, recursive: bool = True, extra_docs: bool = False, -) -> API: +) -> "API": """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to a single file. We support chromadb compatible APIs to create the vector db, this function is not required if you prepared your own vector db. @@ -431,16 +435,17 @@ def create_vector_db_from_dir( return client +@require_optional_import("chromadb", "retrievechat") def query_vector_db( query_texts: list[str], n_results: int = 10, - client: API = None, + client: "API" = None, db_path: str = "tmp/chromadb.db", collection_name: str = "all-my-documents", search_string: str = "", embedding_model: str = "all-MiniLM-L6-v2", embedding_function: Callable = None, -) -> QueryResult: +) -> "QueryResult": """Query a vector db. 
We support chromadb compatible APIs, it's not required if you prepared your own vector db and query function. diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index 1f597cde3..af2f7021f 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -14,20 +14,20 @@ from openai import AzureOpenAI, OpenAI from openai.types.chat import ChatCompletion -from autogen.logger.base_logger import BaseLogger, LLMConfig -from autogen.logger.logger_factory import LoggerFactory +from .logger.base_logger import BaseLogger, LLMConfig +from .logger.logger_factory import LoggerFactory if TYPE_CHECKING: - from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.anthropic import AnthropicClient - from autogen.oai.bedrock import BedrockClient - from autogen.oai.cerebras import CerebrasClient - from autogen.oai.cohere import CohereClient - from autogen.oai.gemini import GeminiClient - from autogen.oai.groq import GroqClient - from autogen.oai.mistral import MistralAIClient - from autogen.oai.ollama import OllamaClient - from autogen.oai.together import TogetherClient + from . import Agent, ConversableAgent, OpenAIWrapper + from .oai.anthropic import AnthropicClient + from .oai.bedrock import BedrockClient + from .oai.cerebras import CerebrasClient + from .oai.cohere import CohereClient + from .oai.gemini import GeminiClient + from .oai.groq import GroqClient + from .oai.mistral import MistralAIClient + from .oai.ollama import OllamaClient + from .oai.together import TogetherClient logger = logging.getLogger(__name__) diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py index aa3c1c3d7..8ba5534d9 100644 --- a/autogen/token_count_utils.py +++ b/autogen/token_count_utils.py @@ -11,14 +11,16 @@ import tiktoken +from .agentchat.contrib.img_utils import num_tokens_from_gpt_image from .import_utils import optional_import_block +# if PIL is not imported, we will redefine num_tokens_from_gpt_image to return 0 tokens for images +# Otherwise, it would raise an ImportError with optional_import_block() as result: - from autogen.agentchat.contrib.img_utils import num_tokens_from_gpt_image + import PIL # noqa: F401 -img_util_imported = result.is_successful - -if not result.is_successful: +pil_imported = result.is_successful +if not pil_imported: def num_tokens_from_gpt_image(*args, **kwargs): return 0 @@ -181,7 +183,7 @@ def _num_token_from_messages(messages: Union[list, dict], model="gpt-3.5-turbo-0 num_tokens += len(encoding.encode(part["text"])) if "image_url" in part: assert "url" in part["image_url"] - if not img_util_imported and not logger.img_dependency_warned: + if not pil_imported and not logger.img_dependency_warned: logger.warning( "img_utils or PIL not imported. Skipping image token count." 
"Please install autogen with [lmm] option.", diff --git a/autogen/tools/dependency_injection.py b/autogen/tools/dependency_injection.py index 3cfe72e5d..a9aec1bcb 100644 --- a/autogen/tools/dependency_injection.py +++ b/autogen/tools/dependency_injection.py @@ -12,7 +12,7 @@ from fast_depends import inject from fast_depends.dependencies import model -from autogen.agentchat import Agent +from ..agentchat import Agent if TYPE_CHECKING: from ..agentchat.conversable_agent import ConversableAgent diff --git a/notebook/agentchat_realtime_swarm_webrtc.ipynb b/notebook/agentchat_realtime_swarm_webrtc.ipynb index 79aa00e05..7f53aab68 100644 --- a/notebook/agentchat_realtime_swarm_webrtc.ipynb +++ b/notebook/agentchat_realtime_swarm_webrtc.ipynb @@ -443,6 +443,7 @@ "\n", "app = FastAPI(lifespan=lifespan)\n", "\n", + "\n", "@app.get(\"/\", response_class=JSONResponse)\n", "async def index_page():\n", " return {\"message\": \"WebRTC AG2 Server is running!\"}" diff --git a/notebook/agentchat_realtime_websocket.ipynb b/notebook/agentchat_realtime_websocket.ipynb index 1f2e9054c..051ebaa21 100644 --- a/notebook/agentchat_realtime_websocket.ipynb +++ b/notebook/agentchat_realtime_websocket.ipynb @@ -184,6 +184,7 @@ "\n", "app = FastAPI(lifespan=lifespan)\n", "\n", + "\n", "@app.get(\"/\", response_class=JSONResponse)\n", "async def index_page():\n", " return {\"message\": \"WebSocket Audio Stream Server is running!\"}" diff --git a/test/agentchat/contrib/capabilities/test_image_generation_capability.py b/test/agentchat/contrib/capabilities/test_image_generation_capability.py index be09064d0..4002c1286 100644 --- a/test/agentchat/contrib/capabilities/test_image_generation_capability.py +++ b/test/agentchat/contrib/capabilities/test_image_generation_capability.py @@ -12,22 +12,20 @@ import pytest from autogen import code_utils +from autogen.agentchat.contrib.capabilities import generate_images +from autogen.agentchat.contrib.img_utils import get_pil_image from autogen.agentchat.conversable_agent import ConversableAgent from autogen.agentchat.user_proxy_agent import UserProxyAgent from autogen.cache.cache import Cache +from autogen.import_utils import optional_import_block, skip_on_missing_imports from autogen.oai import openai_utils -try: +from ....conftest import MOCK_OPEN_AI_API_KEY + +with optional_import_block() as result: from PIL import Image - from autogen.agentchat.contrib.capabilities import generate_images - from autogen.agentchat.contrib.img_utils import get_pil_image -except ImportError: - skip_requirement = True -else: - skip_requirement = False - -from ....conftest import MOCK_OPEN_AI_API_KEY +skip_requirement = not result.is_successful filter_dict = {"model": ["gpt-4o-mini"]} @@ -92,7 +90,7 @@ def image_gen_capability(): @pytest.mark.openai -@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.") +@skip_on_missing_imports("PIL", "unknown") def test_dalle_image_generator(dalle_config: dict[str, Any]): """Tests DalleImageGenerator capability to generate images by calling the OpenAI API.""" dalle_generator = dalle_image_generator(dalle_config, RESOLUTIONS[0], QUALITIES[0]) diff --git a/test/agentchat/contrib/capabilities/test_teachable_agent.py b/test/agentchat/contrib/capabilities/test_teachable_agent.py index e3285e27b..dd626ba4c 100755 --- a/test/agentchat/contrib/capabilities/test_teachable_agent.py +++ b/test/agentchat/contrib/capabilities/test_teachable_agent.py @@ -10,15 +10,14 @@ from autogen import ConversableAgent from autogen.formatting_utils import colored 
diff --git a/test/agentchat/contrib/capabilities/test_teachable_agent.py b/test/agentchat/contrib/capabilities/test_teachable_agent.py
index e3285e27b..dd626ba4c 100755
--- a/test/agentchat/contrib/capabilities/test_teachable_agent.py
+++ b/test/agentchat/contrib/capabilities/test_teachable_agent.py
@@ -10,15 +10,14 @@
 from autogen import ConversableAgent
 from autogen.formatting_utils import colored
+from autogen.import_utils import optional_import_block

 from ....conftest import Credentials

-try:
+with optional_import_block() as result:
     from autogen.agentchat.contrib.capabilities.teachability import Teachability
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful

 # Specify the model to use by uncommenting one of the following lines.
diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py
index 64e415549..24f9ca8fc 100644
--- a/test/agentchat/contrib/capabilities/test_transforms.py
+++ b/test/agentchat/contrib/capabilities/test_transforms.py
@@ -17,6 +17,7 @@
     TextMessageContentName,
 )
 from autogen.agentchat.contrib.capabilities.transforms_util import count_text_tokens
+from autogen.import_utils import optional_import_block


 class _MockTextCompressor:
@@ -104,12 +105,11 @@ def get_messages_with_names_post_filtered() -> list[dict]:
 def get_text_compressors() -> list[TextCompressor]:
     compressors: list[TextCompressor] = [_MockTextCompressor()]
-    try:
+    with optional_import_block() as result:
         from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua

+    if result.is_successful:
         compressors.append(LLMLingua())
-    except ImportError:
-        pass

     return compressors
diff --git a/test/agentchat/contrib/capabilities/test_vision_capability.py b/test/agentchat/contrib/capabilities/test_vision_capability.py
index 765ec2ef9..3006b6786 100644
--- a/test/agentchat/contrib/capabilities/test_vision_capability.py
+++ b/test/agentchat/contrib/capabilities/test_vision_capability.py
@@ -10,15 +10,14 @@
 import pytest

 from autogen.agentchat.conversable_agent import ConversableAgent
+from autogen.import_utils import optional_import_block

-try:
+with optional_import_block() as result:
     from PIL import Image  # noqa: F401

     from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability
-except ImportError:
-    skip_test = True
-else:
-    skip_test = False
+
+skip_test = not result.is_successful


 @pytest.fixture
diff --git a/test/agentchat/contrib/graph_rag/test_falkor_graph_rag.py b/test/agentchat/contrib/graph_rag/test_falkor_graph_rag.py
index d72ad34c0..030153164 100644
--- a/test/agentchat/contrib/graph_rag/test_falkor_graph_rag.py
+++ b/test/agentchat/contrib/graph_rag/test_falkor_graph_rag.py
@@ -9,16 +9,16 @@
 import pytest
 from graphrag_sdk import Attribute, AttributeType, Entity, Ontology, Relation

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from autogen.agentchat.contrib.graph_rag.document import Document, DocumentType
     from autogen.agentchat.contrib.graph_rag.falkor_graph_query_engine import (
         FalkorGraphQueryEngine,
         GraphStoreQueryResult,
     )
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful

 reason = "do not run on MacOS or windows OR dependency is not installed"
diff --git a/test/agentchat/contrib/graph_rag/test_native_neo4j_graph_rag.py b/test/agentchat/contrib/graph_rag/test_native_neo4j_graph_rag.py
index 58962ccdc..12932c551 100644
--- a/test/agentchat/contrib/graph_rag/test_native_neo4j_graph_rag.py
+++ b/test/agentchat/contrib/graph_rag/test_native_neo4j_graph_rag.py
@@ -7,19 +7,20 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 from ....conftest import reason

-try:
+with optional_import_block() as result:
     from autogen.agentchat.contrib.graph_rag.document import Document, DocumentType
     from autogen.agentchat.contrib.graph_rag.neo4j_native_graph_query_engine import (
         GraphStoreQueryResult,
         Neo4jNativeGraphQueryEngine,
     )
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful
+

 # Configure the logging
 logging.basicConfig(level=logging.INFO)
diff --git a/test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py b/test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py
index 27c3c966b..22e8d49d5 100644
--- a/test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py
+++ b/test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py
@@ -10,19 +10,20 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 from ....conftest import reason

-try:
+with optional_import_block() as result:
     from autogen.agentchat.contrib.graph_rag.document import Document, DocumentType
     from autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine import (
         GraphStoreQueryResult,
         Neo4jGraphQueryEngine,
     )
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful
+

 # Configure the logging
 logging.basicConfig(level=logging.INFO)
diff --git a/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py b/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py
index 713a0c1f0..c79d9020d 100644
--- a/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py
+++ b/test/agentchat/contrib/retrievechat/test_pgvector_retrievechat.py
@@ -12,20 +12,18 @@
 from sentence_transformers import SentenceTransformer

 from autogen import AssistantAgent
+from autogen.import_utils import optional_import_block

 from ....conftest import Credentials

-try:
+with optional_import_block() as result:
     import pgvector  # noqa: F401

     from autogen.agentchat.contrib.retrieve_user_proxy_agent import (
         RetrieveUserProxyAgent,
     )
-except ImportError:
-    skip = True
-else:
-    skip = False
+skip = not result.is_successful

 test_dir = os.path.join(os.path.dirname(__file__), "../../..", "test_files")
diff --git a/test/agentchat/contrib/retrievechat/test_qdrant_retrievechat.py b/test/agentchat/contrib/retrievechat/test_qdrant_retrievechat.py
index 6f5e82292..696b20b80 100755
--- a/test/agentchat/contrib/retrievechat/test_qdrant_retrievechat.py
+++ b/test/agentchat/contrib/retrievechat/test_qdrant_retrievechat.py
@@ -12,10 +12,11 @@
 import pytest

 from autogen import AssistantAgent
+from autogen.import_utils import optional_import_block

 from ....conftest import Credentials

-try:
+with optional_import_block() as result:
     import fastembed  # noqa: F401
     from qdrant_client import QdrantClient

@@ -25,16 +26,12 @@
         query_qdrant,
     )

-    QDRANT_INSTALLED = True
-except ImportError:
-    QDRANT_INSTALLED = False
+QDRANT_INSTALLED = result.is_successful

-try:
+with optional_import_block() as result:
     import openai  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.openai
diff --git a/test/agentchat/contrib/retrievechat/test_retrievechat.py b/test/agentchat/contrib/retrievechat/test_retrievechat.py
index 143e633d1..e40e97d16 100755
--- a/test/agentchat/contrib/retrievechat/test_retrievechat.py
+++ b/test/agentchat/contrib/retrievechat/test_retrievechat.py
@@ -10,9 +10,11 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 from ....conftest import Credentials, reason

-try:
+with optional_import_block() as result:
     import chromadb
     import openai  # noqa: F401
     from chromadb.utils import embedding_functions as ef
@@ -21,10 +23,8 @@
     from autogen.agentchat.contrib.retrieve_user_proxy_agent import (
         RetrieveUserProxyAgent,
     )
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful

 reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
diff --git a/test/agentchat/contrib/test_agent_builder.py b/test/agentchat/contrib/test_agent_builder.py
index a4ae12231..c816fb95a 100755
--- a/test/agentchat/contrib/test_agent_builder.py
+++ b/test/agentchat/contrib/test_agent_builder.py
@@ -12,16 +12,15 @@
 import pytest

 from autogen.agentchat.contrib.captainagent.agent_builder import AgentBuilder
+from autogen.import_utils import optional_import_block

 from ...conftest import KEY_LOC, OAI_CONFIG_LIST

-try:
+with optional_import_block() as result:
     import chromadb  # noqa: F401
     import huggingface_hub  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful

 here = os.path.abspath(os.path.dirname(__file__))
diff --git a/test/agentchat/contrib/test_captainagent.py b/test/agentchat/contrib/test_captainagent.py
index 18f589c19..66ab74edd 100644
--- a/test/agentchat/contrib/test_captainagent.py
+++ b/test/agentchat/contrib/test_captainagent.py
@@ -7,16 +7,15 @@
 from autogen import UserProxyAgent
 from autogen.agentchat.contrib.captainagent.captainagent import CaptainAgent
+from autogen.import_utils import optional_import_block

 from ...conftest import KEY_LOC, OAI_CONFIG_LIST, Credentials, reason

-try:
+with optional_import_block() as result:
     import chromadb  # noqa: F401
     import huggingface_hub  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.openai
diff --git a/test/agentchat/contrib/test_img_utils.py b/test/agentchat/contrib/test_img_utils.py
index bb4518bd4..b6fcf9b14 100755
--- a/test/agentchat/contrib/test_img_utils.py
+++ b/test/agentchat/contrib/test_img_utils.py
@@ -14,7 +14,9 @@
 import pytest
 import requests

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     import numpy as np
     from PIL import Image

@@ -28,10 +30,8 @@
         message_formatter_pil_to_b64,
         num_tokens_from_gpt_image,
     )
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 base64_encoded_image = (
diff --git a/test/agentchat/contrib/test_llamaindex_conversable_agent.py b/test/agentchat/contrib/test_llamaindex_conversable_agent.py
index a7bb83aae..b20e9fe5d 100644
--- a/test/agentchat/contrib/test_llamaindex_conversable_agent.py
+++ b/test/agentchat/contrib/test_llamaindex_conversable_agent.py
@@ -13,21 +13,17 @@
 from autogen import GroupChat, GroupChatManager
 from autogen.agentchat.contrib.llamaindex_conversable_agent import LLamaIndexConversableAgent
 from autogen.agentchat.conversable_agent import ConversableAgent
+from autogen.import_utils import optional_import_block

-from ...conftest import MOCK_OPEN_AI_API_KEY, reason
+from ...conftest import MOCK_OPEN_AI_API_KEY

-skip_reasons = [reason]
-try:
+with optional_import_block() as result:
     from llama_index.core.agent import ReActAgent
     from llama_index.core.chat_engine.types import AgentChatResponse
     from llama_index.llms.openai import OpenAI

-    skip_for_dependencies = False
-    skip_reason = ""
-except ImportError as e:
-    skip_for_dependencies = True
-    skip_reason = f"dependency not installed: {e.msg}"
-    pass
+skip_for_dependencies = not result.is_successful
+skip_reason = "" if result.is_successful else "dependency not installed"

 openai_key = MOCK_OPEN_AI_API_KEY
diff --git a/test/agentchat/contrib/test_llava.py b/test/agentchat/contrib/test_llava.py
index 7e967d5a2..38041203a 100755
--- a/test/agentchat/contrib/test_llava.py
+++ b/test/agentchat/contrib/test_llava.py
@@ -11,15 +11,14 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 from ...conftest import MOCK_OPEN_AI_API_KEY

-try:
+with optional_import_block() as result:
     from autogen.agentchat.contrib.llava_agent import LLaVAAgent, _llava_call_binary_with_config, llava_call
-except ImportError:
-    skip = True
-else:
-    skip = False
+skip = not result.is_successful


 @pytest.mark.skipif(skip, reason="dependency is not installed")
diff --git a/test/agentchat/contrib/test_lmm.py b/test/agentchat/contrib/test_lmm.py
index de3987004..a65962d3d 100755
--- a/test/agentchat/contrib/test_lmm.py
+++ b/test/agentchat/contrib/test_lmm.py
@@ -13,16 +13,15 @@
 import autogen
 from autogen.agentchat.conversable_agent import ConversableAgent
+from autogen.import_utils import optional_import_block

 from ...conftest import MOCK_OPEN_AI_API_KEY

-try:
+with optional_import_block() as result:
     from autogen.agentchat.contrib.img_utils import get_pil_image
     from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 base64_encoded_image = (
diff --git a/test/agentchat/contrib/test_reasoning_agent.py b/test/agentchat/contrib/test_reasoning_agent.py
index db9a80ddf..3aeb76e6c 100644
--- a/test/agentchat/contrib/test_reasoning_agent.py
+++ b/test/agentchat/contrib/test_reasoning_agent.py
@@ -14,20 +14,16 @@
 import pytest

 from autogen.agentchat.contrib.reasoning_agent import ReasoningAgent, ThinkNode, visualize_tree
+from autogen.import_utils import optional_import_block

 sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-from ...conftest import reason

-skip_reasons = [reason]
-try:
+
+with optional_import_block() as result:
     from graphviz import Digraph  # noqa: F401

-    skip_for_dependencies = False
-    skip_reason = ""
-except ImportError as e:
-    skip_for_dependencies = True
-    skip_reason = f"dependency not installed: {e.msg}"
-    pass
+skip_for_dependencies = not result.is_successful
+skip_reason = "" if result.is_successful else "dependency not installed"

 here = os.path.abspath(os.path.dirname(__file__))
diff --git a/test/agentchat/contrib/test_web_surfer.py b/test/agentchat/contrib/test_web_surfer.py
index 0eac02bca..edd9c38d0 100755
--- a/test/agentchat/contrib/test_web_surfer.py
+++ b/test/agentchat/contrib/test_web_surfer.py
@@ -12,6 +12,7 @@
 import pytest

 from autogen import UserProxyAgent
+from autogen.import_utils import optional_import_block

 from ...conftest import MOCK_OPEN_AI_API_KEY, Credentials

@@ -19,12 +20,11 @@
 BLOG_POST_TITLE = "Does Model and Inference Parameter Matter in LLM Applications? - A Case Study for MATH - AG2"
 BING_QUERY = "Microsoft"

-try:
+with optional_import_block() as result:
     from autogen.agentchat.contrib.web_surfer import WebSurferAgent
-except ImportError:
-    skip_all = True
-else:
-    skip_all = False
+
+skip_all = not result.is_successful
+

 try:
     BING_API_KEY = os.environ["BING_API_KEY"]
diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py
index c9193576d..3e2e1ca0a 100644
--- a/test/agentchat/contrib/vectordb/test_chromadb.py
+++ b/test/agentchat/contrib/vectordb/test_chromadb.py
@@ -9,18 +9,18 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

-try:
+with optional_import_block() as result:
     import chromadb
     import chromadb.errors
     import sentence_transformers  # noqa: F401

     from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.skipif(skip, reason="dependency is not installed")
diff --git a/test/agentchat/contrib/vectordb/test_mongodb.py b/test/agentchat/contrib/vectordb/test_mongodb.py
index 6de1418c7..66853fe8e 100644
--- a/test/agentchat/contrib/vectordb/test_mongodb.py
+++ b/test/agentchat/contrib/vectordb/test_mongodb.py
@@ -12,13 +12,15 @@
 import pytest

 from autogen.agentchat.contrib.vectordb.base import Document
+from autogen.import_utils import optional_import_block

-try:
+with optional_import_block() as result:
     import pymongo  # noqa: F401
     import sentence_transformers  # noqa: F401

     from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB
-except ImportError:
+
+if not result.is_successful:
     # To display warning in pyproject.toml [tool.pytest.ini_options] set log_cli = true
     logger = logging.getLogger(__name__)
     logger.warning(f"skipping {__name__}. It requires one to pip install pymongo or the extra [retrievechat-mongodb]")
diff --git a/test/agentchat/contrib/vectordb/test_pgvectordb.py b/test/agentchat/contrib/vectordb/test_pgvectordb.py
index 0fe572bc5..9083ad954 100644
--- a/test/agentchat/contrib/vectordb/test_pgvectordb.py
+++ b/test/agentchat/contrib/vectordb/test_pgvectordb.py
@@ -10,18 +10,18 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 from ....conftest import reason

-try:
+with optional_import_block() as result:
     import pgvector  # noqa: F401
     import psycopg
     import sentence_transformers  # noqa: F401

     from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful

 reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
diff --git a/test/agentchat/contrib/vectordb/test_qdrant.py b/test/agentchat/contrib/vectordb/test_qdrant.py
index c8b3b9fdb..c7dd1385b 100644
--- a/test/agentchat/contrib/vectordb/test_qdrant.py
+++ b/test/agentchat/contrib/vectordb/test_qdrant.py
@@ -9,18 +9,18 @@
 import pytest

+from autogen.import_utils import optional_import_block
+
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

-try:
+with optional_import_block() as result:
     import uuid

     from qdrant_client import QdrantClient

     from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.skipif(skip, reason="dependency is not installed")
diff --git a/test/agentchat/test_cache_agent.py b/test/agentchat/test_cache_agent.py
index 61a504048..6349d08f2 100644
--- a/test/agentchat/test_cache_agent.py
+++ b/test/agentchat/test_cache_agent.py
@@ -13,22 +13,19 @@
 import autogen
 from autogen.agentchat import AssistantAgent, UserProxyAgent
 from autogen.cache import Cache
+from autogen.import_utils import optional_import_block

 from ..conftest import Credentials

-try:
+with optional_import_block() as result:
     from openai import OpenAI  # noqa: F401
-except ImportError:
-    skip_tests = True
-else:
-    skip_tests = False

-try:
+skip_tests = not result.is_successful
+
+with optional_import_block() as result:
     import redis  # noqa: F401
-except ImportError:
-    skip_redis_tests = True
-else:
-    skip_redis_tests = False
+
+skip_redis_tests = not result.is_successful


 @pytest.mark.openai
diff --git a/test/agentchat/test_function_call.py b/test/agentchat/test_function_call.py
index 95cc7c422..10ecccd47 100755
--- a/test/agentchat/test_function_call.py
+++ b/test/agentchat/test_function_call.py
@@ -13,16 +13,15 @@
 import pytest

 import autogen
+from autogen.import_utils import optional_import_block
 from autogen.math_utils import eval_math_responses

 from ..conftest import Credentials, reason

-try:
+with optional_import_block() as result:
     from openai import OpenAI  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.openai
diff --git a/test/agentchat/test_math_user_proxy_agent.py b/test/agentchat/test_math_user_proxy_agent.py
index d1491d0b9..a5781bc94 100755
--- a/test/agentchat/test_math_user_proxy_agent.py
+++ b/test/agentchat/test_math_user_proxy_agent.py
@@ -15,15 +15,14 @@
     _add_print_to_last_line,
     _remove_print,
 )
+from autogen.import_utils import optional_import_block

 from ..conftest import Credentials

-try:
+with optional_import_block() as result:
     from openai import OpenAI  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.openai
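Several files above (test_qdrant_retrievechat.py, test_cache_agent.py) probe two independent dependency groups with back-to-back blocks. Each flag has to be read off the result object before the next with statement rebinds the name. A sketch of that shape, with hypothetical package names:

    from autogen.import_utils import optional_import_block

    with optional_import_block() as result:
        import first_extra_dep  # noqa: F401  (hypothetical)

    # Capture the flag now: the next `with` statement rebinds `result`.
    skip_first = not result.is_successful

    with optional_import_block() as result:
        import second_extra_dep  # noqa: F401  (hypothetical)

    skip_second = not result.is_successful

Both flags then feed the existing @pytest.mark.skipif markers exactly as before.
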
diff --git a/test/agentchat/test_tool_calls.py b/test/agentchat/test_tool_calls.py
index 83db77998..905475481 100755
--- a/test/agentchat/test_tool_calls.py
+++ b/test/agentchat/test_tool_calls.py
@@ -13,17 +13,16 @@
 import pytest

 import autogen
+from autogen.import_utils import optional_import_block
 from autogen.math_utils import eval_math_responses
 from autogen.oai.client import TOOL_ENABLED

 from ..conftest import Credentials

-try:
+with optional_import_block() as result:
     from openai import OpenAI  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.openai
diff --git a/test/cache/test_cache.py b/test/cache/test_cache.py
index 2635c05e8..4ecf0fa56 100755
--- a/test/cache/test_cache.py
+++ b/test/cache/test_cache.py
@@ -9,15 +9,13 @@
 import unittest
 from unittest.mock import ANY, MagicMock, patch

-try:
-    from azure.cosmos import CosmosClient
+from autogen.cache.cache import Cache
+from autogen.import_utils import optional_import_block

-    skip_azure = False
-except ImportError:
-    CosmosClient = object
-    skip_azure = True
+with optional_import_block() as result:
+    from azure.cosmos import CosmosClient

-from autogen.cache.cache import Cache
+skip_azure = not result.is_successful


 class TestCache(unittest.TestCase):
diff --git a/test/cache/test_cosmos_db_cache.py b/test/cache/test_cosmos_db_cache.py
index 5f24bc80d..268237629 100644
--- a/test/cache/test_cosmos_db_cache.py
+++ b/test/cache/test_cosmos_db_cache.py
@@ -10,16 +10,14 @@
 import unittest
 from unittest.mock import MagicMock, patch

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from azure.cosmos.exceptions import CosmosResourceNotFoundError

     from autogen.cache.cosmos_db_cache import CosmosDBCache

-    skip_test = False
-except ImportError:
-    CosmosResourceNotFoundError = Exception
-    CosmosDBCache = object
-    skip_test = True
+skip_test = not result.is_successful


 class TestCosmosDBCache(unittest.TestCase):
diff --git a/test/cache/test_redis_cache.py b/test/cache/test_redis_cache.py
index e135e2d50..16a1efb09 100755
--- a/test/cache/test_redis_cache.py
+++ b/test/cache/test_redis_cache.py
@@ -12,12 +12,14 @@
 import pytest

-try:
-    from autogen.cache.redis_cache import RedisCache
+from autogen.cache.redis_cache import RedisCache
+from autogen.import_utils import optional_import_block

-    skip_redis_tests = False
-except ImportError:
-    skip_redis_tests = True
+with optional_import_block() as result:
+    import redis  # noqa: F401
+
+
+skip_redis_tests = not result.is_successful


 @pytest.mark.redis
diff --git a/test/coding/test_embedded_ipython_code_executor.py b/test/coding/test_embedded_ipython_code_executor.py
index e1e37bdfd..dd8e0ff29 100644
--- a/test/coding/test_embedded_ipython_code_executor.py
+++ b/test/coding/test_embedded_ipython_code_executor.py
@@ -20,6 +20,7 @@
 )
 from autogen.coding.base import CodeBlock, CodeExecutor
 from autogen.coding.factory import CodeExecutorFactory
+from autogen.import_utils import optional_import_block

 from ..conftest import MOCK_OPEN_AI_API_KEY

@@ -28,7 +29,8 @@
 else:
     skip_docker_test = False

-try:
+classes_to_test = []
+with optional_import_block() as result:
     from autogen.coding.jupyter import (
         DockerJupyterServer,
         EmbeddedIPythonCodeExecutor,
         JupyterCodeExecutor,
         LocalJupyterServer,
     )

+if result.is_successful:
+
     class DockerJupyterExecutor(JupyterCodeExecutor):
         def __init__(self, **kwargs):
             jupyter_server = DockerJupyterServer()
@@ -55,12 +59,12 @@ def __init__(self, **kwargs):
     if not skip_docker_test:
         classes_to_test.append(DockerJupyterExecutor)

-    skip = False
-    skip_reason = ""
-except ImportError as e:
-    skip = True
-    skip_reason = "Dependencies for EmbeddedIPythonCodeExecutor or LocalJupyterCodeExecutor not installed. " + e.msg
-    classes_to_test = []
+skip = not result.is_successful
+skip_reason = (
+    ""
+    if result.is_successful
+    else "Dependencies for EmbeddedIPythonCodeExecutor or LocalJupyterCodeExecutor not installed."
+)


 @pytest.mark.parametrize("cls", classes_to_test)
diff --git a/test/coding/test_user_defined_functions.py b/test/coding/test_user_defined_functions.py
index a81e30049..ec4134f6f 100644
--- a/test/coding/test_user_defined_functions.py
+++ b/test/coding/test_user_defined_functions.py
@@ -9,16 +9,15 @@
 import pytest

 from autogen.coding.base import CodeBlock
+from autogen.coding.func_with_reqs import FunctionWithRequirements, with_requirements
 from autogen.coding.local_commandline_code_executor import LocalCommandLineCodeExecutor
+from autogen.import_utils import optional_import_block

-try:
+with optional_import_block() as result:
     import pandas
-except ImportError:
-    skip = True
-else:
-    skip = False

-from autogen.coding.func_with_reqs import FunctionWithRequirements, with_requirements
+skip = not result.is_successful
+

 classes_to_test = [LocalCommandLineCodeExecutor]
diff --git a/test/oai/test_anthropic.py b/test/oai/test_anthropic.py
index c6e3d5bb0..4b15cd32d 100644
--- a/test/oai/test_anthropic.py
+++ b/test/oai/test_anthropic.py
@@ -10,17 +10,14 @@
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from anthropic.types import Message, TextBlock

     from autogen.oai.anthropic import AnthropicClient, _calculate_cost

-    skip = False
-except ImportError:
-    AnthropicClient = object
-    Message = object
-    TextBlock = object
-    skip = True
+skip = not result.is_successful

 from typing import List
diff --git a/test/oai/test_bedrock.py b/test/oai/test_bedrock.py
index 54a553f62..cb4eb2986 100644
--- a/test/oai/test_bedrock.py
+++ b/test/oai/test_bedrock.py
@@ -8,14 +8,14 @@
 import pytest

-try:
-    from autogen.oai.bedrock import BedrockClient, oai_messages_to_bedrock_messages
-
-    skip = False
-except ImportError:
-    BedrockClient = object
-    InternalServerError = object
-    skip = True
+from autogen.import_utils import optional_import_block
+from autogen.oai.bedrock import BedrockClient, oai_messages_to_bedrock_messages
+
+with optional_import_block() as result:
+    import boto3  # noqa: F401
+    from botocore.config import Config  # noqa: F401
+
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/oai/test_cerebras.py b/test/oai/test_cerebras.py
index 3cfe50fc8..7b134f573 100644
--- a/test/oai/test_cerebras.py
+++ b/test/oai/test_cerebras.py
@@ -8,14 +8,13 @@
 import pytest

-try:
-    from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost
-
-    skip = False
-except ImportError:
-    CerebrasClient = object
-    InternalServerError = object
-    skip = True
+from autogen.import_utils import optional_import_block
+from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost
+
+with optional_import_block() as result:
+    from cerebras.cloud.sdk import Cerebras, Stream  # noqa: F401
+
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/oai/test_client.py b/test/oai/test_client.py
index ae969f37e..5f8217c18 100755
--- a/test/oai/test_client.py
+++ b/test/oai/test_client.py
@@ -15,21 +15,21 @@
 from autogen import OpenAIWrapper
 from autogen.cache.cache import Cache
+from autogen.import_utils import optional_import_block
 from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED, OpenAIClient

 from ..conftest import Credentials

 TOOL_ENABLED = False
-try:
+
+with optional_import_block() as result:
     import openai
     from openai import OpenAI  # noqa: F401

     if openai.__version__ >= "1.1.0":
         TOOL_ENABLED = True
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful


 @pytest.mark.openai
diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py
index e3b444a4f..9150703ce 100755
--- a/test/oai/test_client_stream.py
+++ b/test/oai/test_client_stream.py
@@ -12,15 +12,12 @@
 import pytest

 from autogen import OpenAIWrapper
+from autogen.import_utils import optional_import_block

 from ..conftest import Credentials, reason

-try:
+with optional_import_block() as result:
     from openai import OpenAI  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False

 # raises exception if openai>=1 is installed and something is wrong with imports
 # otherwise the test will be skipped
@@ -31,6 +28,8 @@
     ChoiceDeltaToolCallFunction,
 )

+skip = not result.is_successful
+

 @pytest.mark.openai
 @pytest.mark.skipif(skip, reason=reason)
diff --git a/test/oai/test_cohere.py b/test/oai/test_cohere.py
index 0e40eff72..bdb8b2c79 100644
--- a/test/oai/test_cohere.py
+++ b/test/oai/test_cohere.py
@@ -10,13 +10,12 @@
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from autogen.oai.cohere import CohereClient, calculate_cohere_cost

-    skip = False
-except ImportError:
-    CohereClient = object
-    skip = True
+skip = not result.is_successful

 reason = "Cohere dependency not installed!"
diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py
index 322508480..63c5be543 100644
--- a/test/oai/test_custom_client.py
+++ b/test/oai/test_custom_client.py
@@ -8,13 +8,13 @@
 import pytest

 from autogen import OpenAIWrapper
+from autogen.import_utils import optional_import_block

-try:
+with optional_import_block() as result:
     from openai import OpenAI  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful
+

 TEST_COST = 20000000
 TEST_CUSTOM_RESPONSE = "This is a custom response."
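In the provider client tests above (test_bedrock.py, test_cerebras.py, and test_ollama.py just below), the autogen-side import moves out of the guard entirely: only the vendor SDK import stays inside the block. This assumes the wrapper module itself stays importable when the SDK is absent, which elsewhere in this patch is arranged with require_optional_import. A sketch of that split, with hypothetical provider names:

    from autogen.import_utils import optional_import_block
    # Imported unconditionally; assumed not to raise without the vendor SDK.
    from autogen.oai.someprovider import SomeProviderClient  # hypothetical module

    with optional_import_block() as result:
        # Only the vendor SDK itself is treated as optional.
        import someprovider_sdk  # noqa: F401  (hypothetical vendor SDK)

    skip = not result.is_successful
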
diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py
index 497c084ba..27d6f3292 100644
--- a/test/oai/test_gemini.py
+++ b/test/oai/test_gemini.py
@@ -10,8 +10,11 @@
 from unittest.mock import MagicMock, patch

 import pytest
+from pydantic import BaseModel
+
+from autogen.import_utils import optional_import_block

-try:
+with optional_import_block() as result:
     import google.auth  # noqa: F401
     from google.api_core.exceptions import InternalServerError
     from google.auth.credentials import Credentials
@@ -24,17 +27,7 @@

     from autogen.oai.gemini import GeminiClient

-    skip = False
-except ImportError:
-    GeminiClient = object
-    VertexAIHarmBlockThreshold = object
-    VertexAIHarmCategory = object
-    VertexAISafetySetting = object
-    vertexai_global_config = object
-    InternalServerError = object
-    skip = True
-
-from pydantic import BaseModel
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/oai/test_groq.py b/test/oai/test_groq.py
index 1ca97ecdb..396040816 100644
--- a/test/oai/test_groq.py
+++ b/test/oai/test_groq.py
@@ -8,14 +8,12 @@
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from autogen.oai.groq import GroqClient, calculate_groq_cost

-    skip = False
-except ImportError:
-    GroqClient = object
-    InternalServerError = object
-    skip = True
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/oai/test_mistral.py b/test/oai/test_mistral.py
index a43e2b518..2aeac48c9 100644
--- a/test/oai/test_mistral.py
+++ b/test/oai/test_mistral.py
@@ -8,7 +8,9 @@
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from mistralai import (
         AssistantMessage,  # noqa: F401
         Function,  # noqa: F401
@@ -22,11 +24,7 @@

     from autogen.oai.mistral import MistralAIClient, calculate_mistral_cost

-    skip = False
-except ImportError:
-    MistralAIClient = object
-    InternalServerError = object
-    skip = True
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/oai/test_ollama.py b/test/oai/test_ollama.py
index 89fa55670..41f7882bc 100644
--- a/test/oai/test_ollama.py
+++ b/test/oai/test_ollama.py
@@ -8,17 +8,17 @@
 from unittest.mock import MagicMock, patch

 import pytest
+from pydantic import BaseModel

-try:
-    from autogen.oai.ollama import OllamaClient, response_to_tool_call
+from autogen.import_utils import optional_import_block
+from autogen.oai.ollama import OllamaClient, response_to_tool_call

-    skip = False
-except ImportError:
-    OllamaClient = object
-    InternalServerError = object
-    skip = True
+with optional_import_block() as result:
+    import ollama  # noqa: F401
+    from fix_busted_json import repair_json  # noqa: F401
+    from ollama import Client  # noqa: F401

-from pydantic import BaseModel
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/oai/test_together.py b/test/oai/test_together.py
index b19d3ee8c..876b2c2e1 100644
--- a/test/oai/test_together.py
+++ b/test/oai/test_together.py
@@ -8,16 +8,14 @@
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     from openai.types.chat.chat_completion import ChatCompletionMessage, Choice  # noqa: F401

     from autogen.oai.together import TogetherClient, calculate_together_cost

-    skip = False
-except ImportError:
-    TogetherClient = object
-    InternalServerError = object
-    skip = True
+skip = not result.is_successful


 # Fixtures for mock data
diff --git a/test/test_code_utils.py b/test/test_code_utils.py
index dbf93a949..25fd9c144 100755
--- a/test/test_code_utils.py
+++ b/test/test_code_utils.py
@@ -30,6 +30,7 @@
     infer_lang,
     is_docker_running,
 )
+from autogen.import_utils import optional_import_block

 from .conftest import Credentials

@@ -390,9 +391,9 @@ def test_create_virtual_env_with_extra_args():

 def _test_improve(credentials_all: Credentials):
-    try:
+    with optional_import_block() as result:
         import openai  # noqa: F401
-    except ImportError:
+    if not result.is_successful:
         return
     config_list = credentials_all.config_list
     improved, _ = improve_function(
diff --git a/test/test_notebook.py b/test/test_notebook.py
index e3268d289..adad2a7f1 100755
--- a/test/test_notebook.py
+++ b/test/test_notebook.py
@@ -11,12 +11,12 @@
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+
+with optional_import_block() as result:
     import openai  # noqa: F401
-except ImportError:
-    skip = True
-else:
-    skip = False
+
+skip = not result.is_successful

 here = os.path.abspath(os.path.dirname(__file__))

@@ -24,7 +24,7 @@
 def run_notebook(input_nb, output_nb="executed_openai_notebook.ipynb", save=False):
     import nbformat
-    from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
+    from nbconvert.preprocessors import ExecutePreprocessor

     try:
         nb_loc = os.path.join(here, os.pardir, "notebook")
@@ -44,8 +44,6 @@
                             nb_output_file.write(output["text"].strip() + "\n")
                         elif "data" in output and "text/plain" in output["data"]:
                             nb_output_file.write(output["data"]["text/plain"].strip() + "\n")
-    except CellExecutionError:
-        raise
     finally:
         if save:
             with open(os.path.join(here, output_nb), "w", encoding="utf-8") as nb_executed_file:
diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py
index 774b4163c..2c957f7e7 100755
--- a/test/test_retrieve_utils.py
+++ b/test/test_retrieve_utils.py
@@ -8,34 +8,36 @@
 """Unit test for retrieve_utils.py"""

+import os
+
 import pytest

-try:
+from autogen.import_utils import optional_import_block
+from autogen.retrieve_utils import (
+    create_vector_db_from_dir,
+    extract_text_from_pdf,
+    get_files_from_dir,
+    is_url,
+    parse_html_to_markdown,
+    query_vector_db,
+    split_files_to_chunks,
+    split_text_to_chunks,
+)
+from autogen.token_count_utils import count_token
+
+with optional_import_block() as result:
+    import bs4  # noqa: F401
     import chromadb
+    import markdownify  # noqa: F401
+    import pypdf  # noqa: F401

-    from autogen.retrieve_utils import (
-        create_vector_db_from_dir,
-        extract_text_from_pdf,
-        get_files_from_dir,
-        is_url,
-        parse_html_to_markdown,
-        query_vector_db,
-        split_files_to_chunks,
-        split_text_to_chunks,
-    )
-    from autogen.token_count_utils import count_token
-except ImportError:
-    skip = True
-else:
-    skip = False
-import os

-try:
+skip = not result.is_successful
+
+with optional_import_block() as result:
     from unstructured.partition.auto import partition  # noqa: F401

-    HAS_UNSTRUCTURED = True
-except ImportError:
-    HAS_UNSTRUCTURED = False
+HAS_UNSTRUCTURED = result.is_successful

 test_dir = os.path.join(os.path.dirname(__file__), "test_files")
 expected_text = """AutoGen is an advanced tool designed to assist developers in harnessing the capabilities
@@ -136,9 +138,9 @@ def test_query_vector_db(self):
         assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))

     def test_custom_vector_db(self):
-        try:
+        with optional_import_block() as result:
             import lancedb
-        except ImportError:
+        if not result.is_successful:
             return

         from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
diff --git a/test/test_token_count.py b/test/test_token_count.py
index 8dfa7bf83..8461be8fb 100755
--- a/test/test_token_count.py
+++ b/test/test_token_count.py
@@ -8,12 +8,12 @@
 import pytest

-try:
-    from autogen.agentchat.contrib.img_utils import num_tokens_from_gpt_image  # noqa: F401
+from autogen.import_utils import optional_import_block

-    img_util_imported = True
-except ImportError:
-    img_util_imported = False
+with optional_import_block() as result:
+    from PIL import Image  # noqa: F401
+
+img_util_imported = result.is_successful

 from autogen.token_count_utils import (