fix: Add token counts to bedrock instrumentation (#270)
anticorrelator authored Mar 12, 2024
1 parent 1491f41 commit 29b0ac9
Showing 2 changed files with 90 additions and 5 deletions.
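
A minimal usage sketch (editorial, not part of the commit) of what this change surfaces: once a Bedrock runtime client is instrumented, the span for each invoke_model call should carry the new token-count attributes read from the Bedrock response headers. The setup mirrors the repository's tests; the model ID and prompt are taken from those tests, the `instrument(tracer_provider=...)` wiring is the usual OpenTelemetry pattern, and a real call assumes valid AWS credentials.

import json

import boto3
from openinference.instrumentation.bedrock import BedrockInstrumentor
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

# Wire spans to an in-memory exporter so the attributes can be inspected.
exporter = InMemorySpanExporter()
tracer_provider = trace_sdk.TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
BedrockInstrumentor().instrument(tracer_provider=tracer_provider)

client = boto3.session.Session().client("bedrock-runtime", region_name="us-east-1")
client.invoke_model(
    modelId="anthropic.claude-v2",
    body=json.dumps({"prompt": "Human: hello there? Assistant:", "max_tokens_to_sample": 256}),
)

span = exporter.get_finished_spans()[0]
print(span.attributes["llm.token_count.prompt"])      # from x-amzn-bedrock-input-token-count
print(span.attributes["llm.token_count.completion"])  # from x-amzn-bedrock-output-token-count
print(span.attributes["llm.token_count.total"])       # sum, set only when both counts are present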
@@ -9,7 +9,10 @@
 from botocore.response import StreamingBody
 from openinference.instrumentation.bedrock.package import _instruments
 from openinference.instrumentation.bedrock.version import __version__
-from openinference.semconv.trace import MessageAttributes, SpanAttributes
+from openinference.semconv.trace import (
+    OpenInferenceSpanKindValues,
+    SpanAttributes,
+)
 from opentelemetry import context as context_api
 from opentelemetry import trace as trace_api
 from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
@@ -83,7 +86,14 @@ def _invocation_wrapper(wrapped_client: InstrumentedClient) -> Callable[..., Any]
 
     @wraps(wrapped_client.invoke_model)
     def instrumented_response(*args: Any, **kwargs: Any) -> Dict[str, Any]:
+        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
+            return wrapped_client._unwrapped_invoke_model(*args, **kwargs)  # type: ignore
+
         with tracer.start_as_current_span("bedrock.invoke_model") as span:
+            span.set_attribute(
+                SpanAttributes.OPENINFERENCE_SPAN_KIND,
+                OpenInferenceSpanKindValues.LLM.value,
+            )
             response = wrapped_client._unwrapped_invoke_model(*args, **kwargs)
             response["body"] = BufferedStreamingBody(
                 response["body"]._raw_stream, response["body"]._content_length
@@ -96,11 +106,34 @@ def instrumented_response(*args: Any, **kwargs: Any) -> Dict[str, Any]:
             prompt = request_body.pop("prompt")
             invocation_parameters = json.dumps(request_body)
 
-            _set_span_attribute(span, SpanAttributes.LLM_PROMPTS, prompt)
+            _set_span_attribute(span, SpanAttributes.INPUT_VALUE, prompt)
             _set_span_attribute(
                 span, SpanAttributes.LLM_INVOCATION_PARAMETERS, invocation_parameters
             )
 
+            if metadata := response.get("ResponseMetadata"):
+                if headers := metadata.get("HTTPHeaders"):
+                    if input_token_count := headers.get("x-amzn-bedrock-input-token-count"):
+                        input_token_count = int(input_token_count)
+                        _set_span_attribute(
+                            span, SpanAttributes.LLM_TOKEN_COUNT_PROMPT, input_token_count
+                        )
+                    if response_token_count := headers.get("x-amzn-bedrock-output-token-count"):
+                        response_token_count = int(response_token_count)
+                        _set_span_attribute(
+                            span,
+                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+                            response_token_count,
+                        )
+                    if total_token_count := (
+                        input_token_count + response_token_count
+                        if input_token_count and response_token_count
+                        else None
+                    ):
+                        _set_span_attribute(
+                            span, SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_token_count
+                        )
+
             if model_id := kwargs.get("modelId"):
                 _set_span_attribute(span, SpanAttributes.LLM_MODEL_NAME, model_id)
 
@@ -117,7 +150,7 @@ def instrumented_response(*args: Any, **kwargs: Any) -> Dict[str, Any]:
                 content = ""
 
             if content:
-                _set_span_attribute(span, MessageAttributes.MESSAGE_CONTENT, content)
+                _set_span_attribute(span, SpanAttributes.OUTPUT_VALUE, content)
 
             return response  # type: ignore
 
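The token-count block above relies on nested walrus-operator lookups: `headers.get(...)` returns None for a missing header, the walrus still binds the name, and the body is simply skipped, so the total is only computed when both counts arrived. A standalone sketch of that pattern (hypothetical header dict, not the instrumentation code itself):

headers = {"x-amzn-bedrock-output-token-count": "6"}  # input-count header absent

# A missing key makes headers.get(...) return None; the walrus assignment
# still binds the name, so the total computation below never raises.
if input_token_count := headers.get("x-amzn-bedrock-input-token-count"):
    input_token_count = int(input_token_count)
if response_token_count := headers.get("x-amzn-bedrock-output-token-count"):
    response_token_count = int(response_token_count)

total_token_count = (
    input_token_count + response_token_count
    if input_token_count and response_token_count
    else None
)
print(input_token_count, response_token_count, total_token_count)  # prints: None 6 None

This is exactly the behavior the second test below exercises: with only the output-token header present, the completion count is recorded while the prompt and total counts are omitted.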
@@ -8,6 +8,9 @@
 import pytest
 from botocore.response import StreamingBody
 from openinference.instrumentation.bedrock import BedrockInstrumentor
+from openinference.semconv.trace import (
+    OpenInferenceSpanKindValues,
+)
 from opentelemetry import trace as trace_api
 from opentelemetry.sdk import trace as trace_sdk
 from opentelemetry.sdk.resources import Resource
@@ -78,5 +81,54 @@ def test_invoke_client(in_memory_span_exporter: InMemorySpanExporter) -> None:
     assert span.status.is_ok
     attributes = dict(span.attributes or dict())
     assert attributes["llm.model_name"] == "anthropic.claude-v2"
-    assert attributes["llm.prompts"] == "Human: hello there? Assistant:"
-    assert attributes["message.content"] == " Hello!"
+    assert attributes["input.value"] == "Human: hello there? Assistant:"
+    assert attributes["output.value"] == " Hello!"
+    assert attributes["llm.token_count.prompt"] == 12
+    assert attributes["llm.token_count.completion"] == 6
+    assert attributes["llm.token_count.total"] == 18
+    assert attributes["openinference.span.kind"] == OpenInferenceSpanKindValues.LLM.value
+
+
+def test_invoke_client_with_missing_tokens(in_memory_span_exporter: InMemorySpanExporter) -> None:
+    output = b'{"completion":" Hello!","stop_reason":"stop_sequence","stop":"\\n\\nHuman:"}'
+    streaming_body = StreamingBody(io.BytesIO(output), len(output))
+    mock_response = {
+        "ResponseMetadata": {
+            "RequestId": "xxxxxxxx-yyyy-zzzz-1234-abcdefghijklmno",
+            "HTTPStatusCode": 200,
+            "HTTPHeaders": {
+                "date": "Sun, 21 Jan 2024 20:00:00 GMT",
+                "content-type": "application/json",
+                "content-length": "74",
+                "connection": "keep-alive",
+                "x-amzn-requestid": "xxxxxxxx-yyyy-zzzz-1234-abcdefghijklmno",
+                "x-amzn-bedrock-invocation-latency": "425",
+                "x-amzn-bedrock-output-token-count": "6",
+            },
+            "RetryAttempts": 0,
+        },
+        "contentType": "application/json",
+        "body": streaming_body,
+    }
+    session = boto3.session.Session()
+    client = session.client("bedrock-runtime", region_name="us-east-1")
+
+    # instead of mocking the HTTP response, we mock the boto client method directly to avoid
+    # complexities with mocking auth
+    client._unwrapped_invoke_model = MagicMock(return_value=mock_response)
+    client.invoke_model(
+        modelId="anthropic.claude-v2",
+        body=b'{"prompt": "Human: hello there? Assistant:", "max_tokens_to_sample": 1024}',
+    )
+    spans = in_memory_span_exporter.get_finished_spans()
+    assert len(spans) == 1
+    span = spans[0]
+    assert span.status.is_ok
+    attributes = dict(span.attributes or dict())
+    assert attributes["llm.model_name"] == "anthropic.claude-v2"
+    assert attributes["input.value"] == "Human: hello there? Assistant:"
+    assert attributes["output.value"] == " Hello!"
+    assert "llm.token_count.prompt" not in attributes
+    assert attributes["llm.token_count.completion"] == 6
+    assert "llm.token_count.total" not in attributes
+    assert attributes["openinference.span.kind"] == OpenInferenceSpanKindValues.LLM.value
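
One hedged note on the suppression guard added at the top of `instrumented_response`: when `_SUPPRESS_INSTRUMENTATION_KEY` is set in the OpenTelemetry context, the wrapper short-circuits to the unwrapped call and records no span. A sketch of how a caller could exercise that path; the key is a private OpenTelemetry detail, so this is illustrative rather than a supported API, and it assumes an instrumented client and credentials as in the sketch near the top.

import boto3
from opentelemetry import context as context_api
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY

# Assumes BedrockInstrumentor().instrument(...) has already run; the wrapper
# then sees the suppression key and returns the unwrapped invoke_model result
# without starting a "bedrock.invoke_model" span.
client = boto3.session.Session().client("bedrock-runtime", region_name="us-east-1")
token = context_api.attach(context_api.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True))
try:
    client.invoke_model(
        modelId="anthropic.claude-v2",
        body=b'{"prompt": "Human: hello there? Assistant:", "max_tokens_to_sample": 256}',
    )
finally:
    context_api.detach(token)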
