Skip to content

Commit

Permalink
improve method names
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Jan 19, 2025
1 parent 9346ff5 commit e4b2a15
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _cast_value(
]


def get_span_kind(kind: OpenInferenceSpanKind) -> Dict[str, AttributeValue]:
def get_span_kind(kind: OpenInferenceSpanKind, /) -> Dict[str, AttributeValue]:
normalized_kind = _normalize_openinference_span_kind(kind)
return {
OPENINFERENCE_SPAN_KIND: normalized_kind.value,
Expand All @@ -406,6 +406,7 @@ def get_span_kind(kind: OpenInferenceSpanKind) -> Dict[str, AttributeValue]:

def get_input_value_and_mime_type(
value: Any,
*,
mime_type: Optional[OpenInferenceMimeType] = None,
) -> Dict[str, AttributeValue]:
normalized_mime_type: Optional[OpenInferenceMimeTypeValues] = None
Expand All @@ -428,6 +429,7 @@ def get_input_value_and_mime_type(

def get_output_value_and_mime_type(
value: Any,
*,
mime_type: Optional[OpenInferenceMimeType] = None,
) -> Dict[str, AttributeValue]:
normalized_mime_type: Optional[OpenInferenceMimeTypeValues] = None
Expand Down Expand Up @@ -546,14 +548,14 @@ def set_input(
value: Any,
mime_type: Optional[OpenInferenceMimeType] = None,
) -> None:
self.set_attributes(get_input_value_and_mime_type(value, mime_type))
self.set_attributes(get_input_value_and_mime_type(value, mime_type=mime_type))

def set_output(
self,
value: Any,
mime_type: Optional[OpenInferenceMimeType] = None,
) -> None:
self.set_attributes(get_output_value_and_mime_type(value, mime_type))
self.set_attributes(get_output_value_and_mime_type(value, mime_type=mime_type))


class AgentSpan(OpenInferenceSpan):
Expand Down Expand Up @@ -1054,16 +1056,16 @@ def _chain_context(
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Iterator[_ChainContext]:
span_name = name or wrapped.__name__
span_name = name or _infer_span_name(class_instance=instance, callable=wrapped)
bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
bound_args.apply_defaults()
arguments = bound_args.arguments

if len(arguments) == 1:
argument = next(iter(arguments.values()))
input_attributes = get_input_value_and_mime_type(value=argument)
input_attributes = get_input_value_and_mime_type(argument)
else:
input_attributes = get_input_value_and_mime_type(value=arguments)
input_attributes = get_input_value_and_mime_type(arguments)

with tracer.start_as_current_span(
span_name,
Expand Down Expand Up @@ -1098,20 +1100,14 @@ def _tool_context(
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Iterator[_ToolContext]:
span_name = name or wrapped.__name__
span_name = name or _infer_span_name(class_instance=instance, callable=wrapped)
bound_args = inspect.signature(wrapped).bind(*args, **kwargs)
bound_args.apply_defaults()
arguments = bound_args.arguments
input_attributes = get_input_value_and_mime_type(value=arguments)
tool_description: Optional[str] = description
if (
not tool_description
and (docstring := wrapped.__doc__) is not None
and (stripped_docstring := docstring.strip())
):
tool_description = stripped_docstring
input_attributes = get_input_value_and_mime_type(arguments)
tool_description = description or _infer_tool_description_from_docstring(wrapped.__doc__)
tool_attributes = get_tool_attributes(
name=name or wrapped.__name__,
name=span_name,
description=tool_description,
parameters={},
)
Expand All @@ -1128,6 +1124,22 @@ def _tool_context(
span.set_status(Status(StatusCode.OK))


def _infer_span_name(*, class_instance: Any, callable: Callable[..., Any]) -> str:
is_method = class_instance is not None
if is_method:
class_name = class_instance.__class__.__name__
method_name = callable.__name__
return f"{class_name}.{method_name}"
function_name = callable.__name__
return function_name


def _infer_tool_description_from_docstring(docstring: Optional[str]) -> Optional[str]:
if docstring is not None and (stripped_docstring := docstring.strip()):
return stripped_docstring
return None


# span attributes
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def chain_with_error(input: str) -> str:
assert attributes.pop(INPUT_VALUE) == "input"
assert not attributes

def test_class_method(
def test_method(
self,
in_memory_span_exporter: InMemorySpanExporter,
tracer: OITracer,
Expand All @@ -664,7 +664,7 @@ def decorated_chain_method(self, input1: str, input2: str) -> str:
spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "decorated_chain_method"
assert span.name == "ChainRunner.decorated_chain_method"
assert span.status.is_ok
assert not span.events
attributes = dict(span.attributes or {})
Expand Down Expand Up @@ -855,6 +855,41 @@ def decorated_tool(input1: str, input2: int) -> None:
assert json.loads(tool_parameters) == {}
assert not attributes

def test_class_tool_with_call_method(
self,
in_memory_span_exporter: InMemorySpanExporter,
tracer: OITracer,
) -> None:
class ClassTool:
@tracer.tool
def __call__(self, input: str) -> None:
"""
tool-description
"""
pass

callable_instance = ClassTool()
callable_instance("input")

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "ClassTool.__call__"
assert span.status.is_ok
assert not span.events
attributes = dict(span.attributes or {})
assert attributes.pop(OPENINFERENCE_SPAN_KIND) == TOOL
assert attributes.pop(INPUT_MIME_TYPE) == JSON
assert isinstance(input_value := attributes.pop(INPUT_VALUE), str)
assert json.loads(input_value) == {"input": "input"}
assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT
assert attributes.pop(OUTPUT_VALUE) == "None"
assert attributes.pop(TOOL_NAME) == "ClassTool.__call__"
assert attributes.pop(TOOL_DESCRIPTION) == "tool-description"
assert isinstance(tool_parameters := attributes.pop(TOOL_PARAMETERS), str)
assert json.loads(tool_parameters) == {}
assert not attributes

async def test_async_tool(
self,
in_memory_span_exporter: InMemorySpanExporter,
Expand Down

0 comments on commit e4b2a15

Please sign in to comment.