feat(openai): add llm provider and system attributes (#1082)
axiomofjoy authored Oct 28, 2024
1 parent 32756ed commit 232c031
Showing 2 changed files with 91 additions and 16 deletions.
@@ -1,9 +1,11 @@
import logging
from abc import ABC
from contextlib import contextmanager
from itertools import chain
from types import ModuleType
from typing import Any, Awaitable, Callable, Iterable, Iterator, Mapping, Tuple

from httpx import URL
from opentelemetry import context as context_api
from opentelemetry import trace as trace_api
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
@@ -30,7 +32,12 @@
_io_value_and_type,
)
from openinference.instrumentation.openai._with_span import _WithSpan
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from openinference.semconv.trace import (
OpenInferenceLLMProviderValues,
OpenInferenceLLMSystemValues,
OpenInferenceSpanKindValues,
SpanAttributes,
)

__all__ = (
"_Request",
@@ -115,12 +122,21 @@ def _get_span_kind(self, cast_to: type) -> str:
else OpenInferenceSpanKindValues.LLM.value
)

def _get_attributes_from_instance(self, instance: Any) -> Iterator[Tuple[str, AttributeValue]]:
if not isinstance(base_url := getattr(instance, "base_url", None), URL):
return
if base_url.host.endswith("api.openai.com"):
yield SpanAttributes.LLM_PROVIDER, OpenInferenceLLMProviderValues.OPENAI.value
if base_url.host.endswith("openai.azure.com"):
yield SpanAttributes.LLM_PROVIDER, OpenInferenceLLMProviderValues.AZURE.value

def _get_attributes_from_request(
self,
cast_to: type,
request_parameters: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
yield SpanAttributes.OPENINFERENCE_SPAN_KIND, self._get_span_kind(cast_to=cast_to)
yield SpanAttributes.LLM_SYSTEM, OpenInferenceLLMSystemValues.OPENAI.value
try:
yield from _as_input_attributes(
_io_value_and_type(request_parameters),
@@ -259,9 +275,12 @@ def __call__(
return wrapped(*args, **kwargs)
with self._start_as_current_span(
span_name=span_name,
attributes=self._get_attributes_from_request(
cast_to=cast_to,
request_parameters=request_parameters,
attributes=chain(
self._get_attributes_from_instance(instance),
self._get_attributes_from_request(
cast_to=cast_to,
request_parameters=request_parameters,
),
),
context_attributes=get_attributes_from_context(),
extra_attributes=self._get_extra_attributes_from_request(
@@ -313,9 +332,12 @@ async def __call__(
return await wrapped(*args, **kwargs)
with self._start_as_current_span(
span_name=span_name,
attributes=self._get_attributes_from_request(
cast_to=cast_to,
request_parameters=request_parameters,
attributes=chain(
self._get_attributes_from_instance(instance),
self._get_attributes_from_request(
cast_to=cast_to,
request_parameters=request_parameters,
),
),
context_attributes=get_attributes_from_context(),
extra_attributes=self._get_extra_attributes_from_request(
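The hunk above derives the provider from the client instance: _get_attributes_from_instance inspects the base_url host and yields the OpenAI or Azure provider attribute, while _get_attributes_from_request now also yields the OpenAI system attribute unconditionally. A minimal standalone sketch of that host check, assuming the provider enum values are the literal strings "openai" and "azure" (detect_provider is a hypothetical helper, shown only to illustrate the logic):

from typing import Optional

from httpx import URL

def detect_provider(base_url: str) -> Optional[str]:
    # Match on the host suffix, mirroring _get_attributes_from_instance above.
    host = URL(base_url).host
    if host.endswith("api.openai.com"):
        return "openai"  # OpenInferenceLLMProviderValues.OPENAI.value (assumed literal)
    if host.endswith("openai.azure.com"):
        return "azure"  # OpenInferenceLLMProviderValues.AZURE.value (assumed literal)
    return None  # unknown hosts yield no provider attribute

assert detect_provider("https://api.openai.com/v1/") == "openai"
assert detect_provider("https://aoairesource.openai.azure.com") == "azure"
assert detect_provider("http://localhost:8080/v1") is None

The second changed file, below, is the test module, parametrized to exercise both base URLs.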
@@ -20,6 +20,7 @@
Union,
cast,
)
from urllib.parse import urljoin

import pytest
from httpx import AsyncByteStream, Response
@@ -36,6 +37,8 @@
ImageAttributes,
MessageAttributes,
MessageContentAttributes,
OpenInferenceLLMProviderValues,
OpenInferenceLLMSystemValues,
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
@@ -48,13 +51,24 @@
logger.handlers.clear()
logger.addHandler(logging.StreamHandler())

_OPENAI_BASE_URL = "https://api.openai.com/v1/"
_AZURE_BASE_URL = "https://aoairesource.openai.azure.com"


@pytest.mark.parametrize(
"base_url",
(
pytest.param(_OPENAI_BASE_URL, id="openai-base-url"),
pytest.param(_AZURE_BASE_URL, id="azure-base-url"),
),
)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("is_raw", [False, True])
@pytest.mark.parametrize("is_stream", [False, True])
@pytest.mark.parametrize("status_code", [200, 400])
@pytest.mark.parametrize("use_context_attributes", [False, True])
def test_chat_completions(
base_url: str,
is_async: bool,
is_raw: bool,
is_stream: bool,
@@ -83,7 +97,7 @@ def test_chat_completions(
"temperature": random.random(),
"n": len(output_messages),
}
url = "https://api.openai.com/v1/chat/completions"
url = urljoin(base_url, "chat/completions")
respx_kwargs: Dict[str, Any] = {
**(
{"stream": MockAsyncByteStream(chat_completion_mock_stream[0])}
@@ -104,9 +118,9 @@
create_kwargs = {"messages": input_messages, **invocation_parameters}
openai = import_module("openai")
completions = (
openai.AsyncOpenAI(api_key="sk-").chat.completions
openai.AsyncOpenAI(api_key="sk-", base_url=base_url).chat.completions
if is_async
else openai.OpenAI(api_key="sk-").chat.completions
else openai.OpenAI(api_key="sk-", base_url=base_url).chat.completions
)
create = completions.with_raw_response.create if is_raw else completions.create

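A detail the parametrization relies on: urljoin builds the endpoint correctly for both fixture URLs, since the OpenAI base ends in a trailing slash and the Azure base has an empty path. A quick standard-library illustration (not part of the diff):

from urllib.parse import urljoin

assert urljoin("https://api.openai.com/v1/", "chat/completions") == "https://api.openai.com/v1/chat/completions"
assert urljoin("https://aoairesource.openai.azure.com", "chat/completions") == "https://aoairesource.openai.azure.com/chat/completions"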
@@ -176,6 +190,10 @@ async def task() -> None:
assert event.name == "exception"
attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
assert attributes.pop(LLM_PROVIDER, None) == (
LLM_PROVIDER_OPENAI if base_url.startswith(_OPENAI_BASE_URL) else LLM_PROVIDER_AZURE
)
assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
assert isinstance(attributes.pop(INPUT_VALUE, None), str)
assert (
OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None))
@@ -221,12 +239,20 @@ async def task() -> None:
assert attributes == {} # test should account for all span attributes


@pytest.mark.parametrize(
"base_url",
(
pytest.param(_OPENAI_BASE_URL, id="openai-base-url"),
pytest.param(_AZURE_BASE_URL, id="azure-base-url"),
),
)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("is_raw", [False, True])
@pytest.mark.parametrize("is_stream", [False, True])
@pytest.mark.parametrize("status_code", [200, 400])
@pytest.mark.parametrize("use_context_attributes", [False, True])
def test_completions(
base_url: str,
is_async: bool,
is_raw: bool,
is_stream: bool,
@@ -253,7 +279,7 @@ def test_completions(
"temperature": random.random(),
"n": len(output_texts),
}
url = "https://api.openai.com/v1/completions"
url = urljoin(base_url, "completions")
respx_kwargs: Dict[str, Any] = {
**(
{"stream": MockAsyncByteStream(completion_mock_stream[0])}
@@ -274,9 +300,9 @@
create_kwargs = {"prompt": prompt, **invocation_parameters}
openai = import_module("openai")
completions = (
openai.AsyncOpenAI(api_key="sk-").completions
openai.AsyncOpenAI(api_key="sk-", base_url=base_url).completions
if is_async
else openai.OpenAI(api_key="sk-").completions
else openai.OpenAI(api_key="sk-", base_url=base_url).completions
)
create = completions.with_raw_response.create if is_raw else completions.create

@@ -331,6 +357,10 @@ async def task() -> None:
assert event.name == "exception"
attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
assert attributes.pop(LLM_PROVIDER, None) == (
LLM_PROVIDER_OPENAI if base_url.startswith(_OPENAI_BASE_URL) else LLM_PROVIDER_AZURE
)
assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
assert (
json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None)))
== invocation_parameters
@@ -365,12 +395,20 @@ async def task() -> None:
assert attributes == {} # test should account for all span attributes


@pytest.mark.parametrize(
"base_url",
(
pytest.param(_OPENAI_BASE_URL, id="openai-base-url"),
pytest.param(_AZURE_BASE_URL, id="azure-base-url"),
),
)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("is_raw", [False, True])
@pytest.mark.parametrize("status_code", [200, 400])
@pytest.mark.parametrize("encoding_format", ["float", "base64"])
@pytest.mark.parametrize("input_text", ["hello", ["hello", "world"]])
def test_embeddings(
base_url: str,
is_async: bool,
is_raw: bool,
encoding_format: str,
@@ -390,7 +428,7 @@
"total_tokens": random.randint(10, 100),
}
output_embeddings = [("AACAPwAAAEA=", (1.0, 2.0)), ((2.0, 3.0), (2.0, 3.0))]
url = "https://api.openai.com/v1/embeddings"
url = urljoin(base_url, "embeddings")
respx_mock.post(url).mock(
return_value=Response(
status_code=status_code,
@@ -408,9 +446,9 @@
create_kwargs = {"input": input_text, **invocation_parameters}
openai = import_module("openai")
completions = (
openai.AsyncOpenAI(api_key="sk-").embeddings
openai.AsyncOpenAI(api_key="sk-", base_url=base_url).embeddings
if is_async
else openai.OpenAI(api_key="sk-").embeddings
else openai.OpenAI(api_key="sk-", base_url=base_url).embeddings
)
create = completions.with_raw_response.create if is_raw else completions.create
with suppress(openai.BadRequestError):
@@ -442,6 +480,10 @@ async def task() -> None:
assert (
attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.EMBEDDING.value
)
assert attributes.pop(LLM_PROVIDER, None) == (
LLM_PROVIDER_OPENAI if base_url.startswith(_OPENAI_BASE_URL) else LLM_PROVIDER_AZURE
)
assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
assert (
json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None)))
== invocation_parameters
@@ -576,6 +618,8 @@ async def task() -> None:
assert event.name == "exception"
attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
assert attributes.pop(LLM_PROVIDER, None) == LLM_PROVIDER_OPENAI
assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
assert isinstance(attributes.pop(INPUT_VALUE, None), str)
assert (
OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None))
@@ -689,6 +733,8 @@ def test_chat_completions_with_config_hiding_hiding_inputs(
assert not span.status.description
attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
assert attributes.pop(LLM_PROVIDER, None) == LLM_PROVIDER_OPENAI
assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
if hide_inputs:
assert attributes.pop(INPUT_VALUE, None) == REDACTED_VALUE
else:
@@ -792,6 +838,8 @@ def test_chat_completions_with_config_hiding_hiding_outputs(
assert not span.status.description
attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
assert attributes.pop(LLM_PROVIDER, None) == LLM_PROVIDER_OPENAI
assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
assert isinstance(attributes.pop(INPUT_VALUE, None), str)
assert (
OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None))
@@ -1300,3 +1348,8 @@ def tool_call_function_arguments(prefix: str, i: int, j: int) -> str:
USER_ID = SpanAttributes.USER_ID
METADATA = SpanAttributes.METADATA
TAG_TAGS = SpanAttributes.TAG_TAGS
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
LLM_PROVIDER_OPENAI = OpenInferenceLLMProviderValues.OPENAI.value
LLM_PROVIDER_AZURE = OpenInferenceLLMProviderValues.AZURE.value
LLM_SYSTEM_OPENAI = OpenInferenceLLMSystemValues.OPENAI.value
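
Taken together, the tests assert that every span now carries both attributes. A hedged end-to-end sketch of how they surface, assuming the OpenInference OpenAI instrumentor and an in-memory OpenTelemetry exporter (the request itself would need to be mocked, as in the tests above):

from openinference.instrumentation.openai import OpenAIInstrumentor
from openinference.semconv.trace import (
    OpenInferenceLLMProviderValues,
    OpenInferenceLLMSystemValues,
    SpanAttributes,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)

# ... issue a (mocked) chat.completions.create call against either base URL ...

span = exporter.get_finished_spans()[0]
assert span.attributes[SpanAttributes.LLM_SYSTEM] == OpenInferenceLLMSystemValues.OPENAI.value
assert span.attributes[SpanAttributes.LLM_PROVIDER] in (
    OpenInferenceLLMProviderValues.OPENAI.value,
    OpenInferenceLLMProviderValues.AZURE.value,
)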
