Skip to content

Commit

Permalink
polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Jan 21, 2025
2 parents 7195e36 + 20c0128 commit a2805e7
Show file tree
Hide file tree
Showing 88 changed files with 565 additions and 506 deletions.
3 changes: 1 addition & 2 deletions autogen/agentchat/assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
# SPDX-License-Identifier: MIT
from typing import Callable, Literal, Optional, Union

from autogen.runtime_logging import log_new_agent, logging_enabled

from ..runtime_logging import log_new_agent, logging_enabled
from .conversable_agent import ConversableAgent


Expand Down
27 changes: 16 additions & 11 deletions autogen/agentchat/contrib/capabilities/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
import re
from typing import Any, Literal, Optional, Protocol, Union

from PIL.Image import Image
from openai import OpenAI

from autogen import Agent, ConversableAgent, code_utils
from autogen.agentchat.contrib import img_utils
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
from autogen.cache import AbstractCache
from .... import Agent, ConversableAgent, code_utils
from ....cache import AbstractCache
from ....import_utils import optional_import_block, require_optional_import
from .. import img_utils
from ..capabilities.agent_capability import AgentCapability
from ..text_analyzer_agent import TextAnalyzerAgent

with optional_import_block():
from PIL.Image import Image

SYSTEM_MESSAGE = "You've been given the special ability to generate images."
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
Expand All @@ -34,7 +37,7 @@ class ImageGenerator(Protocol):
NOTE: Current implementation does not allow you to edit a previously existing image.
"""

def generate_image(self, prompt: str) -> Image:
def generate_image(self, prompt: str) -> "Image":
"""Generates an image based on the provided prompt.
Args:
Expand Down Expand Up @@ -62,6 +65,7 @@ def cache_key(self, prompt: str) -> str:
...


@require_optional_import("PIL", "unknown")
class DalleImageGenerator:
"""Generates images using OpenAI's DALL-E models.
Expand Down Expand Up @@ -94,7 +98,7 @@ def __init__(
self._num_images = num_images
self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])

def generate_image(self, prompt: str) -> Image:
def generate_image(self, prompt: str) -> "Image":
response = self._dalle_client.images.generate(
model=self._model,
prompt=prompt,
Expand All @@ -114,6 +118,7 @@ def cache_key(self, prompt: str) -> str:
return ",".join([str(k) for k in keys])


@require_optional_import("PIL", "unknown")
class ImageGeneration(AgentCapability):
"""This capability allows a ConversableAgent to generate images based on the message received from other Agents.
Expand Down Expand Up @@ -253,15 +258,15 @@ def _extract_prompt(self, last_message) -> str:
analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
return self._extract_analysis(analysis)

def _cache_get(self, prompt: str) -> Optional[Image]:
def _cache_get(self, prompt: str) -> Optional["Image"]:
if self._cache:
key = self._image_generator.cache_key(prompt)
cached_value = self._cache.get(key)

if cached_value:
return img_utils.get_pil_image(cached_value)

def _cache_set(self, prompt: str, image: Image):
def _cache_set(self, prompt: str, image: "Image"):
if self._cache:
key = self._image_generator.cache_key(prompt)
self._cache.set(key, img_utils.pil_to_data_uri(image))
Expand All @@ -272,7 +277,7 @@ def _extract_analysis(self, analysis: Union[str, dict, None]) -> str:
else:
return code_utils.content_str(analysis)

def _generate_content_message(self, prompt: str, image: Image) -> dict[str, Any]:
def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
return {
"content": [
{"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
Expand Down
24 changes: 18 additions & 6 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from typing import Union

import requests
from PIL import Image

from ...import_utils import optional_import_block, require_optional_import

with optional_import_block():
from PIL import Image

from autogen.agentchat import utils

Expand All @@ -37,7 +41,8 @@
}


def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
@require_optional_import("PIL", "unknown")
def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
"""Loads an image from a file and returns a PIL Image object.
Parameters:
Expand Down Expand Up @@ -75,7 +80,8 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
return image.convert("RGB")


def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
@require_optional_import("PIL", "unknown")
def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> bytes:
"""Loads an image and returns its data either as raw bytes or in base64-encoded format.
This function first loads an image from the specified file, URL, or base64 string using
Expand Down Expand Up @@ -105,6 +111,7 @@ def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
return content


@require_optional_import("PIL", "unknown")
def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
"""Formats the input prompt by replacing image tags and returns the new prompt along with image locations.
Expand Down Expand Up @@ -149,7 +156,8 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str,
return new_prompt, images


def pil_to_data_uri(image: Image.Image) -> str:
@require_optional_import("PIL", "unknown")
def pil_to_data_uri(image: "Image.Image") -> str:
"""Converts a PIL Image object to a data URI.
Parameters:
Expand Down Expand Up @@ -184,6 +192,7 @@ def _get_mime_type_from_data_uri(base64_image):
return data_uri


@require_optional_import("PIL", "unknown")
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict]]:
"""Formats the input prompt by replacing image tags and returns a list of text and images.
Expand Down Expand Up @@ -251,7 +260,8 @@ def extract_img_paths(paragraph: str) -> list:
return img_paths


def _to_pil(data: str) -> Image.Image:
@require_optional_import("PIL", "unknown")
def _to_pil(data: str) -> "Image.Image":
"""Converts a base64 encoded image data string to a PIL Image object.
This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
Expand All @@ -266,6 +276,7 @@ def _to_pil(data: str) -> Image.Image:
return Image.open(BytesIO(base64.b64decode(data)))


@require_optional_import("PIL", "unknown")
def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
"""Converts the PIL image URLs in the messages to base64 encoded data URIs.
Expand Down Expand Up @@ -321,8 +332,9 @@ def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
return new_messages


@require_optional_import("PIL", "unknown")
def num_tokens_from_gpt_image(
image_data: Union[str, Image.Image], model: str = "gpt-4-vision", low_quality: bool = False
image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
) -> int:
"""Calculate the number of tokens required to process an image based on its dimensions
after scaling for different GPT models. Supports "gpt-4-vision", "gpt-4o", and "gpt-4o-mini".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from fastapi.websockets import WebSocket
from ..websockets import WebSocketProtocol as WebSocket


LOG_EVENT_TYPES = [
Expand Down Expand Up @@ -142,5 +142,5 @@ async def initialize_session(self) -> None:

if TYPE_CHECKING:

def twilio_audio_adapter(websocket: WebSocket) -> RealtimeObserver:
def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
return TwilioAudioAdapter(websocket)
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
from typing import TYPE_CHECKING, Optional

from ..realtime_events import AudioDelta, RealtimeEvent, SpeechStarted
from ..realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from fastapi.websockets import WebSocket


from ..realtime_observer import RealtimeObserver
from ..websockets import WebSocketProtocol as WebSocket

LOG_EVENT_TYPES = [
"error",
Expand Down Expand Up @@ -135,5 +133,5 @@ async def run_loop(self) -> None:

if TYPE_CHECKING:

def websocket_audio_adapter(websocket: WebSocket) -> RealtimeObserver:
def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
return WebSocketAudioAdapter(websocket)
2 changes: 1 addition & 1 deletion autogen/agentchat/realtime_agent/clients/oai/rtc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from .utils import parse_oai_message

if TYPE_CHECKING:
from ...websockets import WebSocketProtocol as WebSocket
from ..realtime_client import RealtimeClientProtocol
from ..websockets import WebSocketProtocol as WebSocket

__all__ = ["OpenAIRealtimeWebRTCClient"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Protocol, runtime_checkable
from typing import Any, AsyncIterator, Protocol, runtime_checkable

__all__ = ["WebSocketProtocol"]

Expand All @@ -16,3 +16,5 @@ async def send_json(self, data: Any, mode: str = "text") -> None: ...
async def receive_json(self, mode: str = "text") -> Any: ...

async def receive_text(self) -> str: ...

def iter_text(self) -> AsyncIterator[str]: ...
6 changes: 0 additions & 6 deletions autogen/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations

import sys
from types import TracebackType
from typing import Any

from .abstract_cache_base import AbstractCache
from .cache_factory import CacheFactory

if sys.version_info >= (3, 11):
pass
else:
pass


class Cache(AbstractCache):
"""A wrapper class for managing cache configuration and instances.
Expand Down
14 changes: 9 additions & 5 deletions autogen/cache/cosmos_db_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,24 @@
import pickle
from typing import Any, Optional, TypedDict, Union

from azure.cosmos import CosmosClient, PartitionKey
from azure.cosmos.exceptions import CosmosResourceNotFoundError
from ..import_utils import optional_import_block, require_optional_import
from .abstract_cache_base import AbstractCache

from autogen.cache.abstract_cache_base import AbstractCache
with optional_import_block():
from azure.cosmos import CosmosClient, PartitionKey
from azure.cosmos.exceptions import CosmosResourceNotFoundError


@require_optional_import("azure", "cosmosdb")
class CosmosDBConfig(TypedDict, total=False):
connection_string: str
database_id: str
container_id: str
cache_seed: Optional[Union[str, int]]
client: Optional[CosmosClient]
client: Optional["CosmosClient"]


@require_optional_import("azure", "cosmosdb")
class CosmosDBCache(AbstractCache):
"""Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.
Expand Down Expand Up @@ -75,7 +79,7 @@ def from_connection_string(cls, seed: Union[str, int], connection_string: str, d
return cls(str(seed), config)

@classmethod
def from_existing_client(cls, seed: Union[str, int], client: CosmosClient, database_id: str, container_id: str):
def from_existing_client(cls, seed: Union[str, int], client: "CosmosClient", database_id: str, container_id: str):
config = {"client": client, "database_id": database_id, "container_id": container_id}
return cls(str(seed), config)

Expand Down
11 changes: 7 additions & 4 deletions autogen/cache/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
from types import TracebackType
from typing import Any, Optional, Union

import redis

from .abstract_cache_base import AbstractCache

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from ..import_utils import optional_import_block, require_optional_import
from .abstract_cache_base import AbstractCache

with optional_import_block():
import redis


@require_optional_import("redis", "redis")
class RedisCache(AbstractCache):
"""Implementation of AbstractCache using the Redis database.
Expand Down
3 changes: 1 addition & 2 deletions autogen/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

import docker

from autogen import oai

from . import oai
from .types import UserMessageImageContentPart, UserMessageTextContentPart

SENTINEL = object()
Expand Down
4 changes: 1 addition & 3 deletions autogen/coding/jupyter/docker_jupyter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@

import docker

from ..docker_commandline_code_executor import _wait_for_ready

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


from ..docker_commandline_code_executor import _wait_for_ready
from .base import JupyterConnectable, JupyterConnectionInfo
from .jupyter_client import JupyterClient

Expand Down
8 changes: 6 additions & 2 deletions autogen/coding/jupyter/embedded_ipython_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
from queue import Empty
from typing import Any

from jupyter_client import KernelManager # type: ignore[attr-defined]
from jupyter_client.kernelspec import KernelSpecManager
from pydantic import BaseModel, Field, field_validator

from ...import_utils import optional_import_block, require_optional_import
from ..base import CodeBlock, CodeExtractor, IPythonCodeResult
from ..markdown_code_extractor import MarkdownCodeExtractor

with optional_import_block():
from jupyter_client import KernelManager # type: ignore[attr-defined]
from jupyter_client.kernelspec import KernelSpecManager

# `__all__` must be a sequence of names; a bare string would be iterated
# character by character by `from ... import *` and fail on the first letter.
__all__ = ["EmbeddedIPythonCodeExecutor"]


@require_optional_import("jupyter_client", "jupyter-executor")
class EmbeddedIPythonCodeExecutor(BaseModel):
"""(Experimental) A code executor class that executes code statefully using an embedded
IPython kernel managed by this class.
Expand Down
Loading

0 comments on commit a2805e7

Please sign in to comment.