Skip to content

Commit

Permalink
Merge pull request #33 from microsoft/dev/python-39-backport
Browse files: browse the repository at this point in the history
Backporting to add Python 3.9 + 3.10 compatibility
  • Loading branch information
t-schn authored May 30, 2024
2 parents (e906fed + fd4d54b), commit 24c5dc4
Show file tree
Hide file tree
Showing 18 changed files with 87 additions and 57 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sammo"
version = "0.1.6"
version = "0.1.7"
description = "A flexible, easy-to-use library for running and optimizing prompts for Large Language Models (LLMs)."
authors = ["Tobias Schnabel"]
license = "MIT"
Expand All @@ -12,7 +12,7 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.11,<3.12"
python = "^3.9,<3.12"
beartype = "^0.15"
benepar = {version = "^0.2", optional = true}
filelock = "^3.12"
Expand All @@ -30,6 +30,8 @@ xmltodict = "^0.13"
PyYAML = "^6.0"
aiohttp = "^3.6"
diskcache = "^5.2"
quattro = "^24"
async-timeout = "^4.0.3"

[tool.poetry.extras]
parser = ["benepar"]
Expand Down
5 changes: 3 additions & 2 deletions sammo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
import logging
import beartype
from beartype.typing import Union
import sammo.utils as utils
from pathlib import Path

Expand All @@ -10,9 +11,9 @@

@beartype.beartype
def setup_logger(
default_level: int | str = "DEBUG",
default_level: Union[int, str] = "DEBUG",
log_prompts_to_file: bool = False,
prompt_level: int | str = "DEBUG",
prompt_level: Union[int, str] = "DEBUG",
prompt_logfile_name: str = None,
) -> logging.Logger:
if log_prompts_to_file:
Expand Down
7 changes: 4 additions & 3 deletions sammo/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import abc
import copy
import re

from beartype.typing import Callable, Self
from beartype.typing import Callable, Any
from frozendict import frozendict
import pyglove as pg
import pybars
Expand Down Expand Up @@ -194,7 +195,7 @@ def __init__(self, query, child_selector=None):
self.child_selector = child_selector

@classmethod
def from_path(cls, path_descriptor: str | dict | Self):
def from_path(cls, path_descriptor: str | dict | Any):
if isinstance(path_descriptor, CompiledQuery):
return path_descriptor
elif isinstance(path_descriptor, str):
Expand Down Expand Up @@ -236,7 +237,7 @@ class Component:

NEEDS_SCHEDULING = False

def __init__(self, child: Self | str, name: str | None = None):
def __init__(self, child: Any | str, name: str | None = None):
if name is None:
self._name = self.__class__.__name__
else:
Expand Down
11 changes: 9 additions & 2 deletions sammo/compactbars.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Provides a way of displaying multiple progress bars in a single line. Works in both interactive and non-interactive
environments.
"""
from __future__ import annotations
import collections
import datetime
import io
Expand All @@ -12,6 +13,7 @@
import sys
import time
from beartype import beartype
from beartype.typing import Union

from sammo import utils

Expand Down Expand Up @@ -191,7 +193,7 @@ class CompactProgressBars:
:param refresh_interval: The minimum time interval between display refreshes.
"""

def __init__(self, width: int | None = None, refresh_interval: float = 1 / 50):
def __init__(self, width: Union[int, None] = None, refresh_interval: float = 1 / 50):
self._bars = collections.OrderedDict()
self._printer = LinePrinter()
self._last_update = 0
Expand All @@ -215,7 +217,12 @@ def _should_refresh(self) -> bool:
return False

def get(
self, id: str, total: int | None = None, position: int | None = None, display_name: str | None = None, **kwargs
self,
id: str,
total: Union[int, None] = None,
position: Union[int, None] = None,
display_name: Union[str, None] = None,
**kwargs,
) -> SubProgressBar:
"""
Gets existing or creates a new progress bar given an id.
Expand Down
21 changes: 12 additions & 9 deletions sammo/components.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import asyncio
import quattro
import logging
import math
import warnings

import beartype
from beartype.typing import Callable, Literal
from beartype.typing import Union as TUnion
from frozendict import frozendict

import sammo.utils as utils
Expand Down Expand Up @@ -49,8 +52,8 @@ def __init__(
self,
child: ScalarComponent,
name=None,
system_prompt: str | None = None,
history: ScalarComponent | None = None,
system_prompt: TUnion[str, None] = None,
history: TUnion[ScalarComponent, None] = None,
seed=0,
randomness: float = 0,
max_tokens=None,
Expand Down Expand Up @@ -162,7 +165,7 @@ async def _call(self, runner: Runner, context: dict, dynamic_context: frozendict
if not isinstance(collection, list):
collection = [collection]

async with asyncio.TaskGroup() as tg:
async with quattro.TaskGroup() as tg:
for x in collection:
tasks.append(
tg.create_task(
Expand Down Expand Up @@ -235,10 +238,10 @@ def __init__(
def run(
self,
runner: Runner,
data: DataTable | list | None = None,
progress_callback: Callable | bool = True,
data: TUnion[DataTable, list, None] = None,
progress_callback: TUnion[Callable, bool] = True,
priority: int = 0,
on_error: Literal["raise", "empty_result", "backoff"] | None = None,
on_error: TUnion[Literal["raise", "empty_result", "backoff"], None] = None,
) -> DataTable:
"""Synchronous version of `arun`."""
return utils.sync(self.arun(runner, data, progress_callback, priority, on_error))
Expand All @@ -254,10 +257,10 @@ def n_minibatches(self, table: DataTable) -> int:
async def arun(
self,
runner: Runner,
data: DataTable | list | None = None,
progress_callback: Callable | bool = True,
data: TUnion[DataTable, list, None] = None,
progress_callback: TUnion[Callable, bool] = True,
priority: int = 0,
on_error: Literal["raise", "empty_result", "backoff"] | None = None,
on_error: TUnion[Literal["raise", "empty_result", "backoff"], None] = None,
):
"""
Run the component asynchronously and return a DataTable with the results.
Expand Down
27 changes: 14 additions & 13 deletions sammo/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
DataTables are the primary data structure used in SAMMO.
They are essentially a wrapper around a list of inputs and outputs (labels), with some additional functionality.
"""
from __future__ import annotations
import copy
import hashlib
import math

from beartype import beartype
from beartype.typing import Callable, Iterator, Self
from beartype.typing import Callable, Iterator, Union
import more_itertools
import orjson
import pyglove as pg
Expand All @@ -36,8 +37,8 @@ class DataTable(pg.JSONConvertible):
def __init__(
self,
inputs: list,
outputs: list | None = None,
constants: dict | None = None,
outputs: Union[list, None] = None,
constants: Union[dict, None] = None,
seed=42,
):
inputs = DataTable._ensure_list(inputs)
Expand Down Expand Up @@ -73,7 +74,7 @@ def outputs(self):
return self._outputs

@property
def constants(self) -> dict | None:
def constants(self) -> Union[dict, None]:
"""Access constants."""
return self._data["constants"]

Expand Down Expand Up @@ -101,9 +102,9 @@ def from_json(cls, json_value, **kwargs):
def from_pandas(
cls,
df: "pandas.DataFrame",
output_fields: list[str] | str = "output",
input_fields: list[str] | str | None = None,
constants: dict | None = None,
output_fields: Union[list[str], str] = "output",
input_fields: Union[list[str], str, None] = None,
constants: Union[dict, None] = None,
seed=42,
):
"""Create a DataTable from a pandas DataFrame.
Expand All @@ -130,8 +131,8 @@ def _slice_to_explicit_idx(self, key: slice):
def from_records(
cls,
records: list[dict],
output_fields: list[str] | str = "output",
input_fields: list[str] | str | None = None,
output_fields: Union[list[str], str] = "output",
input_fields: Union[list[str], str, None] = None,
**kwargs,
):
if len(records) == 0:
Expand Down Expand Up @@ -195,7 +196,7 @@ def to_string(self, max_rows: int = 10, max_col_width: int = 60, max_cell_length
table = "<empty DataTable>"
return f"{table}\nConstants: {DataTable._truncate(self.constants, max_col_width)}"

def _to_explicit_idx(self, key: int | slice | list[int]):
def _to_explicit_idx(self, key: Union[int, slice, list[int]]):
if isinstance(key, int):
return [key]
elif isinstance(key, slice):
Expand All @@ -208,7 +209,7 @@ def __getitem__(self, key):
new_outputs = [self._data["outputs"][i] for i in idx]
return DataTable(new_inputs, new_outputs, self.constants, self._seed)

def sample(self, k: int, seed: int | None = None) -> Self:
def sample(self, k: int, seed: Union[int, None] = None):
"""Sample rows without replacement.
:param k: Number of rows to sample.
Expand All @@ -224,7 +225,7 @@ def sample(self, k: int, seed: int | None = None) -> Self:

return self[selected_idx]

def shuffle(self, seed: int | None = None) -> Self:
def shuffle(self, seed: Union[int, None] = None):
"""Shuffle rows.
:param seed: Random seed. If not provided, instance seed is used.
Expand All @@ -242,7 +243,7 @@ def random_split(self, *sizes: int, seed=None) -> tuple:
splits = [slice(sum(sizes[:i]), sum(sizes[: i + 1])) for i in range(len(sizes))]
return tuple(sampled[split] for split in splits)

def copy(self) -> Self:
def copy(self):
return copy.deepcopy(self)

def get_minibatch_iterator(self, minibatch_size):
Expand Down
1 change: 1 addition & 0 deletions sammo/dataformatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
get_extractor method that can be used to parse the LLM responses in this format.
"""
from __future__ import annotations
import collections
import json

Expand Down
1 change: 1 addition & 0 deletions sammo/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Common formats such as JSON, XML, or Markdown are supported and require no data format specification. If validation is
required, it should happen downstream of the extraction step.
"""
from __future__ import annotations
import abc
import ast
import fractions
Expand Down
1 change: 1 addition & 0 deletions sammo/instructions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import hashlib
import math

Expand Down
4 changes: 3 additions & 1 deletion sammo/mutators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import abc
import asyncio
import quattro
import collections
import logging
import random
Expand Down Expand Up @@ -838,7 +840,7 @@ async def mutate(
else:
selected_mutators = collections.Counter(selected_mutators)
tasks = list()
async with asyncio.TaskGroup() as tg:
async with quattro.TaskGroup() as tg:
for i, (mut, n_mut) in enumerate(selected_mutators.items()):
mut.objective = self._objective
tasks.append(tg.create_task(mut.mutate(candidate, data, runner, n_mut, random_state + i)))
Expand Down
17 changes: 9 additions & 8 deletions sammo/runners.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import abc
import base64
import re
import warnings
from abc import abstractmethod
import asyncio
import async_timeout
from collections.abc import MutableMapping
import json
import logging
Expand All @@ -16,7 +18,7 @@
import aiohttp
import orjson
from beartype import beartype
from beartype.typing import Literal
from beartype.typing import Literal, Union

from sammo import PROMPT_LOGGER_NAME
from sammo.base import LLMResult, Costs, Runner
Expand Down Expand Up @@ -48,7 +50,6 @@ async def generate_text(self, prompt: str, *args, **kwargs):
return LLMResult(self.return_value)


@beartype
class BaseRunner(Runner):
"""Base class for OpenAI API runners.
Expand All @@ -73,13 +74,13 @@ def __init__(
self,
model_id: str,
api_config: dict | str | pathlib.Path,
cache: None | MutableMapping | str | os.PathLike = None,
equivalence_class: str | Literal["major", "exact"] = "major",
rate_limit: AtMost | list[AtMost] | Throttler | int = 2,
cache: Union[None, MutableMapping, str, os.PathLike] = None,
equivalence_class: Union[str, Literal["major", "exact"]] = "major",
rate_limit: Union[AtMost, list[AtMost], Throttler, int] = 2,
max_retries: int = 50,
max_context_window: int | None = None,
max_context_window: Union[int, None] = None,
retry: bool = True,
timeout: float | int = 60,
timeout: Union[float, int] = 60,
max_timeout_retries: int = 1,
use_cached_timeouts: bool = True,
):
Expand Down Expand Up @@ -150,7 +151,7 @@ async def _execute_request(self, request, fingerprint, priority=0):

try:
job_handle = await self._throttler.wait_in_line(priority)
async with asyncio.timeout(self._timeout):
async with async_timeout.timeout(self._timeout):
json = await self._call_backend(request)
response_obj = self._llm_result(request, json, fingerprint)
response_obj.retries = cur_try
Expand Down
2 changes: 1 addition & 1 deletion sammo/runners_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from asyncio import TaskGroup
from quattro import TaskGroup
from unittest.mock import AsyncMock, MagicMock

import pytest
Expand Down
Loading

0 comments on commit 24c5dc4

Please sign in to comment.