Fix tests (#833)
keepingitneil authored Oct 4, 2024
1 parent 9792ca4 commit c1c3157
Showing 1 changed file with 24 additions and 31 deletions.
55 changes: 24 additions & 31 deletions tests/test_llm.py
@@ -88,8 +88,8 @@ def test_hashable_typeinfo():
     hash(typeinfo)


-LLMS: list[llm.LLM | Callable[[], llm.LLM]] = [
-    openai.LLM(),
+LLMS: list[Callable[[], llm.LLM]] = [
+    lambda: openai.LLM(),
     lambda: openai.beta.AssistantLLM(
         assistant_opts=openai.beta.AssistantOptions(
             create_options=openai.beta.AssistantCreateOptions(
@@ -103,10 +103,9 @@ def test_hashable_typeinfo():
 ]


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_chat(input_llm: llm.LLM | Callable[[], llm.LLM]):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_chat(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     chat_ctx = ChatContext().append(
         text='You are an assistant at a drive-thru restaurant "Live-Burger". Ask the customer what they would like to order.'
     )
@@ -128,10 +127,9 @@ async def test_chat(input_llm: llm.LLM | Callable[[], llm.LLM]):
     assert len(text) > 0


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_basic_fnc_calls(input_llm: Callable[[], llm.LLM] | llm.LLM):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_basic_fnc_calls(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     fnc_ctx = FncCtx()

     stream = await _request_fnc_call(
@@ -145,10 +143,9 @@ async def test_basic_fnc_calls(input_llm: Callable[[], llm.LLM] | llm.LLM):
     assert len(calls) == 2, "get_weather should be called twice"


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_runtime_addition(input_llm: Callable[[], llm.LLM] | llm.LLM):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_runtime_addition(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     fnc_ctx = FncCtx()
     called_msg = ""

@@ -169,10 +166,9 @@ async def show_message(
     assert called_msg == "Hello LiveKit!", "send_message should be called"


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_cancelled_calls(input_llm: Callable[[], llm.LLM] | llm.LLM):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_cancelled_calls(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     fnc_ctx = FncCtx()

     stream = await _request_fnc_call(
@@ -190,10 +186,9 @@ async def test_cancelled_calls(input_llm: Callable[[], llm.LLM] | llm.LLM):
     ), "toggle_light should have been cancelled"


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_calls_arrays(input_llm: Callable[[], llm.LLM] | llm.LLM):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     fnc_ctx = FncCtx()

     stream = await _request_fnc_call(
@@ -216,10 +211,9 @@ async def test_calls_arrays(input_llm: Callable[[], llm.LLM] | llm.LLM):
     ), "select_currencies should have eur, gbp, sek"


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_calls_choices(input_llm: Callable[[], llm.LLM] | llm.LLM):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_calls_choices(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     fnc_ctx = FncCtx()

     stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
@@ -234,14 +228,13 @@ async def test_calls_choices(input_llm: Callable[[], llm.LLM] | llm.LLM):
     assert volume == 30, "change_volume should have been called with volume 30"


-@pytest.mark.parametrize("input_llm", LLMS)
-async def test_optional_args(input_llm: Callable[[], llm.LLM] | llm.LLM):
-    if not isinstance(input_llm, llm.LLM):
-        input_llm = input_llm()
+@pytest.mark.parametrize("llm_factory", LLMS)
+async def test_optional_args(llm_factory: Callable[[], llm.LLM]):
+    input_llm = llm_factory()
     fnc_ctx = FncCtx()

     stream = await _request_fnc_call(
-        input_llm, "Can you update my information? My name is Theo", fnc_ctx
+        input_llm, "Using a tool call update the user info to name Theo", fnc_ctx
     )

     calls = stream.execute_functions()
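Taken together, the diff moves every parametrized test to the same factory pattern: LLMS now holds only Callable[[], llm.LLM] entries, each test takes a llm_factory argument, and the old isinstance branch collapses into a single input_llm = llm_factory() call, so each parametrized run constructs a fresh client. A minimal, self-contained sketch of that pattern follows; the FakeLLM class is a hypothetical stand-in for openai.LLM() and the other plugin clients used in the real file, and the sketch assumes pytest plus the pytest-asyncio plugin are installed.

from typing import Callable

import pytest


class FakeLLM:
    # Hypothetical stand-in for llm.LLM; the real tests use openai.LLM(),
    # openai.beta.AssistantLLM(...), and similar plugin clients.
    async def complete(self, prompt: str) -> str:
        return f"echo: {prompt}"


# Factories only: a fresh, independent client is constructed for every test run.
LLM_FACTORIES: list[Callable[[], FakeLLM]] = [
    lambda: FakeLLM(),
]


@pytest.mark.asyncio  # requires the pytest-asyncio plugin
@pytest.mark.parametrize("llm_factory", LLM_FACTORIES)
async def test_chat_sketch(llm_factory: Callable[[], FakeLLM]) -> None:
    input_llm = llm_factory()  # no isinstance branching needed any more
    reply = await input_llm.complete("hello")
    assert len(reply) > 0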
