From ab2ef4655c16e662c499b0302f4a0b28892f6b6c Mon Sep 17 00:00:00 2001 From: Kiko Castillo Date: Mon, 6 May 2024 16:06:44 -0700 Subject: [PATCH] fix: Add missing context variables to list (#438) --- .../openinference/instrumentation/__init__.py | 3 ++ .../tests/test_context_managers.py | 30 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/python/openinference-instrumentation/src/openinference/instrumentation/__init__.py b/python/openinference-instrumentation/src/openinference/instrumentation/__init__.py index 6d459f580..78318f8f5 100644 --- a/python/openinference-instrumentation/src/openinference/instrumentation/__init__.py +++ b/python/openinference-instrumentation/src/openinference/instrumentation/__init__.py @@ -19,6 +19,9 @@ SpanAttributes.USER_ID, SpanAttributes.METADATA, SpanAttributes.TAG_TAGS, + SpanAttributes.LLM_PROMPT_TEMPLATE, + SpanAttributes.LLM_PROMPT_TEMPLATE_VERSION, + SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES, ) diff --git a/python/openinference-instrumentation/tests/test_context_managers.py b/python/openinference-instrumentation/tests/test_context_managers.py index cff14a8ba..30792d3cb 100644 --- a/python/openinference-instrumentation/tests/test_context_managers.py +++ b/python/openinference-instrumentation/tests/test_context_managers.py @@ -3,6 +3,7 @@ import pytest from openinference.instrumentation import ( + get_attributes_from_context, suppress_tracing, using_attributes, using_metadata, @@ -14,6 +15,7 @@ from openinference.semconv.trace import SpanAttributes from opentelemetry.context import ( _SUPPRESS_INSTRUMENTATION_KEY, + get_current, get_value, ) @@ -198,6 +200,34 @@ def f() -> None: assert get_value(SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES) is None +def test_get_attributes_from_context( + session_id: str, + user_id: str, + metadata: Dict[str, Any], + tags: List[str], + prompt_template: str, + prompt_template_version: str, + prompt_template_variables: Dict[str, Any], +) -> None: + with using_attributes( + session_id=session_id, + user_id=user_id, + metadata=metadata, + tags=tags, + prompt_template=prompt_template, + prompt_template_version=prompt_template_version, + prompt_template_variables=prompt_template_variables, + ): + ctx = get_current() + context_vars = {attr[0]: attr[1] for attr in get_attributes_from_context()} + assert len(ctx) == len(context_vars) + for key, value in ctx.items(): + assert context_vars.pop(key, None) == value, f"Missing context variable {key}" + + context_vars = {attr[0]: attr[1] for attr in get_attributes_from_context()} + assert context_vars == {} + + @pytest.fixture def session_id() -> str: return "test-session"