From 29b0ac9d643c47bb7c6fd62d4cf581dd8157291c Mon Sep 17 00:00:00 2001
From: Dustin Ngo
Date: Tue, 12 Mar 2024 17:40:14 -0400
Subject: [PATCH] fix: Add token counts to bedrock instrumentation (#270)

---
 .../instrumentation/bedrock/__init__.py  | 39 ++++++++++++-
 .../tests/test_instrumentor.py           | 56 ++++++++++++++++++-
 2 files changed, 90 insertions(+), 5 deletions(-)

diff --git a/python/instrumentation/openinference-instrumentation-bedrock/src/openinference/instrumentation/bedrock/__init__.py b/python/instrumentation/openinference-instrumentation-bedrock/src/openinference/instrumentation/bedrock/__init__.py
index 330a42e7d..09ffe5946 100644
--- a/python/instrumentation/openinference-instrumentation-bedrock/src/openinference/instrumentation/bedrock/__init__.py
+++ b/python/instrumentation/openinference-instrumentation-bedrock/src/openinference/instrumentation/bedrock/__init__.py
@@ -9,7 +9,10 @@
 from botocore.response import StreamingBody
 from openinference.instrumentation.bedrock.package import _instruments
 from openinference.instrumentation.bedrock.version import __version__
-from openinference.semconv.trace import MessageAttributes, SpanAttributes
+from openinference.semconv.trace import (
+    OpenInferenceSpanKindValues,
+    SpanAttributes,
+)
 from opentelemetry import context as context_api
 from opentelemetry import trace as trace_api
 from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
@@ -83,7 +86,14 @@ def _invocation_wrapper(wrapped_client: InstrumentedClient) -> Callable[..., Any
     @wraps(wrapped_client.invoke_model)
     def instrumented_response(*args: Any, **kwargs: Any) -> Dict[str, Any]:
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return wrapped_client._unwrapped_invoke_model(*args, **kwargs)  # type: ignore
+
         with tracer.start_as_current_span("bedrock.invoke_model") as span:
+            span.set_attribute(
+                SpanAttributes.OPENINFERENCE_SPAN_KIND,
+                OpenInferenceSpanKindValues.LLM.value,
+            )
             response = wrapped_client._unwrapped_invoke_model(*args, **kwargs)
             response["body"] = BufferedStreamingBody(
                 response["body"]._raw_stream, response["body"]._content_length
             )
@@ -96,11 +106,34 @@ def instrumented_response(*args: Any, **kwargs: Any) -> Dict[str, Any]:
             prompt = request_body.pop("prompt")
             invocation_parameters = json.dumps(request_body)
 
-            _set_span_attribute(span, SpanAttributes.LLM_PROMPTS, prompt)
+            _set_span_attribute(span, SpanAttributes.INPUT_VALUE, prompt)
             _set_span_attribute(
                 span, SpanAttributes.LLM_INVOCATION_PARAMETERS, invocation_parameters
             )
 
+            if metadata := response.get("ResponseMetadata"):
+                if headers := metadata.get("HTTPHeaders"):
+                    if input_token_count := headers.get("x-amzn-bedrock-input-token-count"):
+                        input_token_count = int(input_token_count)
+                        _set_span_attribute(
+                            span, SpanAttributes.LLM_TOKEN_COUNT_PROMPT, input_token_count
+                        )
+                    if response_token_count := headers.get("x-amzn-bedrock-output-token-count"):
+                        response_token_count = int(response_token_count)
+                        _set_span_attribute(
+                            span,
+                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+                            response_token_count,
+                        )
+                    if total_token_count := (
+                        input_token_count + response_token_count
+                        if input_token_count and response_token_count
+                        else None
+                    ):
+                        _set_span_attribute(
+                            span, SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_token_count
+                        )
+
             if model_id := kwargs.get("modelId"):
                 _set_span_attribute(span, SpanAttributes.LLM_MODEL_NAME, model_id)
 
@@ -117,7 +150,7 @@ def instrumented_response(*args: Any, **kwargs: Any) -> Dict[str, Any]:
                 content = ""
 
             if content:
-                _set_span_attribute(span, MessageAttributes.MESSAGE_CONTENT, content)
+                _set_span_attribute(span, SpanAttributes.OUTPUT_VALUE, content)
 
             return response  # type: ignore
 
diff --git a/python/instrumentation/openinference-instrumentation-bedrock/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-bedrock/tests/test_instrumentor.py
index 209c39c5d..655383f59 100644
--- a/python/instrumentation/openinference-instrumentation-bedrock/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-bedrock/tests/test_instrumentor.py
@@ -8,6 +8,9 @@
 import pytest
 from botocore.response import StreamingBody
 from openinference.instrumentation.bedrock import BedrockInstrumentor
+from openinference.semconv.trace import (
+    OpenInferenceSpanKindValues,
+)
 from opentelemetry import trace as trace_api
 from opentelemetry.sdk import trace as trace_sdk
 from opentelemetry.sdk.resources import Resource
@@ -78,5 +81,54 @@ def test_invoke_client(in_memory_span_exporter: InMemorySpanExporter) -> None:
     assert span.status.is_ok
     attributes = dict(span.attributes or dict())
     assert attributes["llm.model_name"] == "anthropic.claude-v2"
-    assert attributes["llm.prompts"] == "Human: hello there? Assistant:"
-    assert attributes["message.content"] == " Hello!"
+    assert attributes["input.value"] == "Human: hello there? Assistant:"
+    assert attributes["output.value"] == " Hello!"
+    assert attributes["llm.token_count.prompt"] == 12
+    assert attributes["llm.token_count.completion"] == 6
+    assert attributes["llm.token_count.total"] == 18
+    assert attributes["openinference.span.kind"] == OpenInferenceSpanKindValues.LLM.value
+
+
+def test_invoke_client_with_missing_tokens(in_memory_span_exporter: InMemorySpanExporter) -> None:
+    output = b'{"completion":" Hello!","stop_reason":"stop_sequence","stop":"\\n\\nHuman:"}'
+    streaming_body = StreamingBody(io.BytesIO(output), len(output))
+    mock_response = {
+        "ResponseMetadata": {
+            "RequestId": "xxxxxxxx-yyyy-zzzz-1234-abcdefghijklmno",
+            "HTTPStatusCode": 200,
+            "HTTPHeaders": {
+                "date": "Sun, 21 Jan 2024 20:00:00 GMT",
+                "content-type": "application/json",
+                "content-length": "74",
+                "connection": "keep-alive",
+                "x-amzn-requestid": "xxxxxxxx-yyyy-zzzz-1234-abcdefghijklmno",
+                "x-amzn-bedrock-invocation-latency": "425",
+                "x-amzn-bedrock-output-token-count": "6",
+            },
+            "RetryAttempts": 0,
+        },
+        "contentType": "application/json",
+        "body": streaming_body,
+    }
+    session = boto3.session.Session()
+    client = session.client("bedrock-runtime", region_name="us-east-1")
+
+    # instead of mocking the HTTP response, we mock the boto client method directly to avoid
+    # complexities with mocking auth
+    client._unwrapped_invoke_model = MagicMock(return_value=mock_response)
+    client.invoke_model(
+        modelId="anthropic.claude-v2",
+        body=b'{"prompt": "Human: hello there? Assistant:", "max_tokens_to_sample": 1024}',
+    )
+    spans = in_memory_span_exporter.get_finished_spans()
+    assert len(spans) == 1
+    span = spans[0]
+    assert span.status.is_ok
+    attributes = dict(span.attributes or dict())
+    assert attributes["llm.model_name"] == "anthropic.claude-v2"
+    assert attributes["input.value"] == "Human: hello there? Assistant:"
+    assert attributes["output.value"] == " Hello!"
+    assert "llm.token_count.prompt" not in attributes
+    assert attributes["llm.token_count.completion"] == 6
+    assert "llm.token_count.total" not in attributes
+    assert attributes["openinference.span.kind"] == OpenInferenceSpanKindValues.LLM.value
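
A note on the token-count logic in the patch: Bedrock reports token usage via the
x-amzn-bedrock-input-token-count and x-amzn-bedrock-output-token-count HTTP response
headers rather than the response body, and the total is computed client-side. It is
deliberately omitted when either header is missing so that a partial sum is never
reported; the second test exercises exactly that case. Below is a minimal standalone
sketch of that behavior, assuming a botocore-style response dict; the helper name
_extract_token_counts is illustrative and not part of this patch:

    from typing import Any, Dict, Optional, Tuple

    def _extract_token_counts(
        response: Dict[str, Any],
    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        # Token counts live in the HTTP headers under ResponseMetadata.
        headers = (response.get("ResponseMetadata") or {}).get("HTTPHeaders") or {}
        raw_prompt = headers.get("x-amzn-bedrock-input-token-count")
        raw_completion = headers.get("x-amzn-bedrock-output-token-count")
        prompt = int(raw_prompt) if raw_prompt else None
        completion = int(raw_completion) if raw_completion else None
        # Mirror the patch: only derive a total when both counts are present,
        # so a missing header never produces a misleading partial total.
        total = prompt + completion if prompt and completion else None
        return prompt, completion, total

    # With only the output-token header present (the scenario covered by
    # test_invoke_client_with_missing_tokens), prompt and total stay unset:
    assert _extract_token_counts(
        {"ResponseMetadata": {"HTTPHeaders": {"x-amzn-bedrock-output-token-count": "6"}}}
    ) == (None, 6, None)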