From 232c031deea6b81a007447a37db2b0e5eb8ce613 Mon Sep 17 00:00:00 2001
From: Xander Song
Date: Mon, 28 Oct 2024 16:44:06 -0700
Subject: [PATCH] feat(openai): add llm provider and system attributes (#1082)

---
 .../instrumentation/openai/_request.py        | 36 ++++++++--
 .../openai/test_instrumentor.py               | 71 ++++++++++++++++---
 2 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py
index 8489ace73..c30045f49 100644
--- a/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py
+++ b/python/instrumentation/openinference-instrumentation-openai/src/openinference/instrumentation/openai/_request.py
@@ -1,9 +1,11 @@
 import logging
 from abc import ABC
 from contextlib import contextmanager
+from itertools import chain
 from types import ModuleType
 from typing import Any, Awaitable, Callable, Iterable, Iterator, Mapping, Tuple
 
+from httpx import URL
 from opentelemetry import context as context_api
 from opentelemetry import trace as trace_api
 from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
@@ -30,7 +32,12 @@
     _io_value_and_type,
 )
 from openinference.instrumentation.openai._with_span import _WithSpan
-from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
+from openinference.semconv.trace import (
+    OpenInferenceLLMProviderValues,
+    OpenInferenceLLMSystemValues,
+    OpenInferenceSpanKindValues,
+    SpanAttributes,
+)
 
 __all__ = (
     "_Request",
@@ -115,12 +122,21 @@ def _get_span_kind(self, cast_to: type) -> str:
             else OpenInferenceSpanKindValues.LLM.value
         )
 
+    def _get_attributes_from_instance(self, instance: Any) -> Iterator[Tuple[str, AttributeValue]]:
+        if not isinstance(base_url := getattr(instance, "base_url", None), URL):
+            return
+        if base_url.host.endswith("api.openai.com"):
+            yield SpanAttributes.LLM_PROVIDER, OpenInferenceLLMProviderValues.OPENAI.value
+        if base_url.host.endswith("openai.azure.com"):
+            yield SpanAttributes.LLM_PROVIDER, OpenInferenceLLMProviderValues.AZURE.value
+
     def _get_attributes_from_request(
         self,
         cast_to: type,
         request_parameters: Mapping[str, Any],
     ) -> Iterator[Tuple[str, AttributeValue]]:
         yield SpanAttributes.OPENINFERENCE_SPAN_KIND, self._get_span_kind(cast_to=cast_to)
+        yield SpanAttributes.LLM_SYSTEM, OpenInferenceLLMSystemValues.OPENAI.value
         try:
             yield from _as_input_attributes(
                 _io_value_and_type(request_parameters),
@@ -259,9 +275,12 @@ def __call__(
             return wrapped(*args, **kwargs)
         with self._start_as_current_span(
             span_name=span_name,
-            attributes=self._get_attributes_from_request(
-                cast_to=cast_to,
-                request_parameters=request_parameters,
+            attributes=chain(
+                self._get_attributes_from_instance(instance),
+                self._get_attributes_from_request(
+                    cast_to=cast_to,
+                    request_parameters=request_parameters,
+                ),
             ),
             context_attributes=get_attributes_from_context(),
             extra_attributes=self._get_extra_attributes_from_request(
@@ -313,9 +332,12 @@ async def __call__(
             return await wrapped(*args, **kwargs)
         with self._start_as_current_span(
             span_name=span_name,
-            attributes=self._get_attributes_from_request(
-                cast_to=cast_to,
-                request_parameters=request_parameters,
+            attributes=chain(
+                self._get_attributes_from_instance(instance),
+                self._get_attributes_from_request(
+                    cast_to=cast_to,
+                    request_parameters=request_parameters,
+                ),
             ),
             context_attributes=get_attributes_from_context(),
             extra_attributes=self._get_extra_attributes_from_request(
diff --git a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
index 3aad7f733..61e5af613 100644
--- a/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-openai/tests/openinference/instrumentation/openai/test_instrumentor.py
@@ -20,6 +20,7 @@
     Union,
     cast,
 )
+from urllib.parse import urljoin
 
 import pytest
 from httpx import AsyncByteStream, Response
@@ -36,6 +37,8 @@
     ImageAttributes,
     MessageAttributes,
     MessageContentAttributes,
+    OpenInferenceLLMProviderValues,
+    OpenInferenceLLMSystemValues,
     OpenInferenceMimeTypeValues,
     OpenInferenceSpanKindValues,
     SpanAttributes,
@@ -48,13 +51,24 @@
 logger.handlers.clear()
 logger.addHandler(logging.StreamHandler())
 
+_OPENAI_BASE_URL = "https://api.openai.com/v1/"
+_AZURE_BASE_URL = "https://aoairesource.openai.azure.com"
+
 
+@pytest.mark.parametrize(
+    "base_url",
+    (
+        pytest.param(_OPENAI_BASE_URL, id="openai-base-url"),
+        pytest.param(_AZURE_BASE_URL, id="azure-base-url"),
+    ),
+)
 @pytest.mark.parametrize("is_async", [False, True])
 @pytest.mark.parametrize("is_raw", [False, True])
 @pytest.mark.parametrize("is_stream", [False, True])
 @pytest.mark.parametrize("status_code", [200, 400])
 @pytest.mark.parametrize("use_context_attributes", [False, True])
 def test_chat_completions(
+    base_url: str,
     is_async: bool,
     is_raw: bool,
     is_stream: bool,
@@ -83,7 +97,7 @@ def test_chat_completions(
         "temperature": random.random(),
         "n": len(output_messages),
     }
-    url = "https://api.openai.com/v1/chat/completions"
+    url = urljoin(base_url, "chat/completions")
     respx_kwargs: Dict[str, Any] = {
         **(
             {"stream": MockAsyncByteStream(chat_completion_mock_stream[0])}
@@ -104,9 +118,9 @@ def test_chat_completions(
     create_kwargs = {"messages": input_messages, **invocation_parameters}
     openai = import_module("openai")
     completions = (
-        openai.AsyncOpenAI(api_key="sk-").chat.completions
+        openai.AsyncOpenAI(api_key="sk-", base_url=base_url).chat.completions
         if is_async
-        else openai.OpenAI(api_key="sk-").chat.completions
+        else openai.OpenAI(api_key="sk-", base_url=base_url).chat.completions
     )
     create = completions.with_raw_response.create if is_raw else completions.create
 
@@ -176,6 +190,10 @@ async def task() -> None:
             assert event.name == "exception"
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
     assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
+    assert attributes.pop(LLM_PROVIDER, None) == (
+        LLM_PROVIDER_OPENAI if base_url.startswith(_OPENAI_BASE_URL) else LLM_PROVIDER_AZURE
+    )
+    assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
     assert isinstance(attributes.pop(INPUT_VALUE, None), str)
     assert (
         OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None))
@@ -221,12 +239,20 @@ async def task() -> None:
     assert attributes == {}  # test should account for all span attributes
 
 
+@pytest.mark.parametrize(
+    "base_url",
+    (
+        pytest.param(_OPENAI_BASE_URL, id="openai-base-url"),
+        pytest.param(_AZURE_BASE_URL, id="azure-base-url"),
+    ),
+)
 @pytest.mark.parametrize("is_async", [False, True])
 @pytest.mark.parametrize("is_raw", [False, True])
 @pytest.mark.parametrize("is_stream", [False, True])
 @pytest.mark.parametrize("status_code", [200, 400])
 @pytest.mark.parametrize("use_context_attributes", [False, True])
 def test_completions(
+    base_url: str,
     is_async: bool,
     is_raw: bool,
     is_stream: bool,
@@ -253,7 +279,7 @@ def test_completions(
         "temperature": random.random(),
         "n": len(output_texts),
     }
-    url = "https://api.openai.com/v1/completions"
+    url = urljoin(base_url, "completions")
     respx_kwargs: Dict[str, Any] = {
         **(
             {"stream": MockAsyncByteStream(completion_mock_stream[0])}
@@ -274,9 +300,9 @@ def test_completions(
     create_kwargs = {"prompt": prompt, **invocation_parameters}
     openai = import_module("openai")
     completions = (
-        openai.AsyncOpenAI(api_key="sk-").completions
+        openai.AsyncOpenAI(api_key="sk-", base_url=base_url).completions
         if is_async
-        else openai.OpenAI(api_key="sk-").completions
+        else openai.OpenAI(api_key="sk-", base_url=base_url).completions
     )
     create = completions.with_raw_response.create if is_raw else completions.create
 
@@ -331,6 +357,10 @@ async def task() -> None:
             assert event.name == "exception"
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
     assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
+    assert attributes.pop(LLM_PROVIDER, None) == (
+        LLM_PROVIDER_OPENAI if base_url.startswith(_OPENAI_BASE_URL) else LLM_PROVIDER_AZURE
+    )
+    assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
     assert (
         json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None)))
         == invocation_parameters
@@ -365,12 +395,20 @@ async def task() -> None:
     assert attributes == {}  # test should account for all span attributes
 
 
+@pytest.mark.parametrize(
+    "base_url",
+    (
+        pytest.param(_OPENAI_BASE_URL, id="openai-base-url"),
+        pytest.param(_AZURE_BASE_URL, id="azure-base-url"),
+    ),
+)
 @pytest.mark.parametrize("is_async", [False, True])
 @pytest.mark.parametrize("is_raw", [False, True])
 @pytest.mark.parametrize("status_code", [200, 400])
 @pytest.mark.parametrize("encoding_format", ["float", "base64"])
 @pytest.mark.parametrize("input_text", ["hello", ["hello", "world"]])
 def test_embeddings(
+    base_url: str,
     is_async: bool,
     is_raw: bool,
     encoding_format: str,
@@ -390,7 +428,7 @@ def test_embeddings(
         "total_tokens": random.randint(10, 100),
     }
     output_embeddings = [("AACAPwAAAEA=", (1.0, 2.0)), ((2.0, 3.0), (2.0, 3.0))]
-    url = "https://api.openai.com/v1/embeddings"
+    url = urljoin(base_url, "embeddings")
     respx_mock.post(url).mock(
         return_value=Response(
             status_code=status_code,
@@ -408,9 +446,9 @@ def test_embeddings(
     create_kwargs = {"input": input_text, **invocation_parameters}
     openai = import_module("openai")
     completions = (
-        openai.AsyncOpenAI(api_key="sk-").embeddings
+        openai.AsyncOpenAI(api_key="sk-", base_url=base_url).embeddings
         if is_async
-        else openai.OpenAI(api_key="sk-").embeddings
+        else openai.OpenAI(api_key="sk-", base_url=base_url).embeddings
     )
     create = completions.with_raw_response.create if is_raw else completions.create
     with suppress(openai.BadRequestError):
@@ -442,6 +480,10 @@ async def task() -> None:
     assert (
         attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.EMBEDDING.value
     )
+    assert attributes.pop(LLM_PROVIDER, None) == (
+        LLM_PROVIDER_OPENAI if base_url.startswith(_OPENAI_BASE_URL) else LLM_PROVIDER_AZURE
+    )
+    assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
     assert (
         json.loads(cast(str, attributes.pop(LLM_INVOCATION_PARAMETERS, None)))
         == invocation_parameters
@@ -576,6 +618,8 @@ async def task() -> None:
             assert event.name == "exception"
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
     assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
+    assert attributes.pop(LLM_PROVIDER, None) == LLM_PROVIDER_OPENAI
+    assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
     assert isinstance(attributes.pop(INPUT_VALUE, None), str)
     assert (
         OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None))
@@ -689,6 +733,8 @@ def test_chat_completions_with_config_hiding_hiding_inputs(
     assert not span.status.description
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
     assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
+    assert attributes.pop(LLM_PROVIDER, None) == LLM_PROVIDER_OPENAI
+    assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
     if hide_inputs:
         assert attributes.pop(INPUT_VALUE, None) == REDACTED_VALUE
     else:
@@ -792,6 +838,8 @@ def test_chat_completions_with_config_hiding_hiding_outputs(
     assert not span.status.description
     attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
     assert attributes.pop(OPENINFERENCE_SPAN_KIND, None) == OpenInferenceSpanKindValues.LLM.value
+    assert attributes.pop(LLM_PROVIDER, None) == LLM_PROVIDER_OPENAI
+    assert attributes.pop(LLM_SYSTEM, None) == LLM_SYSTEM_OPENAI
     assert isinstance(attributes.pop(INPUT_VALUE, None), str)
     assert (
         OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE, None))
@@ -1300,3 +1348,8 @@ def tool_call_function_arguments(prefix: str, i: int, j: int) -> str:
 USER_ID = SpanAttributes.USER_ID
 METADATA = SpanAttributes.METADATA
 TAG_TAGS = SpanAttributes.TAG_TAGS
+LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
+LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
+LLM_PROVIDER_OPENAI = OpenInferenceLLMProviderValues.OPENAI.value
+LLM_PROVIDER_AZURE = OpenInferenceLLMProviderValues.AZURE.value
+LLM_SYSTEM_OPENAI = OpenInferenceLLMSystemValues.OPENAI.value