Skip to content

Commit

Permalink
avoid reflection for directly returning tools discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Dec 19, 2024
1 parent 4f4f1ad commit 8f3cddc
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@
*
* @return whether to return the result directly
*/
boolean returnDirectly() default false;
boolean directReturn() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.service.TypeUtils.isResultRawString;
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
import static dev.langchain4j.service.output.JsonSchemas.jsonSchemaFrom;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
Expand Down Expand Up @@ -145,6 +146,8 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
// TODO give user ability to provide custom OutputParser
Type returnType = method.getGenericReturnType();
boolean isReturnTypeRaw = typeHasRawClass(returnType, Result.class);
boolean isResultRawString = isReturnTypeRaw && isResultRawString(returnType);
boolean directReturnFromTool = isResultRawString;

boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType);

Expand Down Expand Up @@ -241,7 +244,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
List<ToolExecution> toolExecutions = new ArrayList<>();
List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();

while (true) {

if (executionsLeft-- == 0) {
Expand All @@ -262,11 +265,10 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
break;
}

// only return directly if the returntype is Result<String>
// only return directly if the return type is Result<String>
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());
String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
toolExecutionRequests.add(toolExecutionRequest);
toolExecutions.add(ToolExecution.builder()
.request(toolExecutionRequest)
.result(toolExecutionResult)
Expand All @@ -287,8 +289,8 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
messages = context.chatMemory(memoryId).messages();
}
// it's possible that an ai message only has 1 tool request, but then the subsequent ai message within the while loop has a different tool request, so only if all toolrequests are return direct do we return directly
boolean shouldReturnDirectly = isReturnTypeRaw && isResultRawString(returnType) && allToolsReturnDirectly(toolExecutionRequests, toolExecutors);
if (shouldReturnDirectly) {
directReturnFromTool = directReturnFromTool && allToolsReturnDirectly(aiMessage.toolExecutionRequests(), toolExecutors);
if (directReturnFromTool) {
return new Result<T>(tokenUsageAccumulator, Collections.emptyList(), response.finishReason(), toolExecutions);
}

Expand All @@ -312,37 +314,8 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
}
}

private boolean isToolReturnDirectly(String toolName, Map<String, ToolExecutor> toolExecutors) {
ToolExecutor executor = toolExecutors.get(toolName);
try {
Method method = executor.getClass().getMethod("isReturnDirectly");
return (boolean) method.invoke(executor);
} catch (Exception e) {
return false;
}
}

private boolean allToolsReturnDirectly(List<ToolExecutionRequest> requests, Map<String, ToolExecutor> toolExecutors) {
if (requests == null || requests.isEmpty()) {
return false;
}

for (ToolExecutionRequest request : requests) {
if (!isToolReturnDirectly(request.name(), toolExecutors)) {
return false;
}
}
return true;
}

private boolean isResultRawString(Type returnType) {
if (!(returnType instanceof ParameterizedType paramType)) {
return false;
}
return Arrays.stream(paramType.getActualTypeArguments())
.findFirst()
.map(String.class::equals)
.orElse(false);
return requests.stream().map(r -> toolExecutors.get(r.name())).allMatch(tExec -> tExec != null && tExec.isDirectReturn());
}

private boolean canAdaptTokenStreamTo(Type returnType) {
Expand Down
11 changes: 11 additions & 0 deletions langchain4j/src/main/java/dev/langchain4j/service/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.lang.reflect.TypeVariable;
import java.lang.reflect.WildcardType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -35,6 +36,16 @@ public static boolean typeHasRawClass(Type type, Class<?> rawClass) {
return rawClass.equals(getRawClass(type));
}

public static boolean isResultRawString(Type returnType) {
if (!(returnType instanceof ParameterizedType paramType)) {
return false;
}
return Arrays.stream(paramType.getActualTypeArguments())
.findFirst()
.map(String.class::equals)
.orElse(false);
}

public static Class<?> resolveFirstGenericParameterClass(Type returnType) {
Type[] typeArguments = getTypeArguments(returnType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ public DefaultToolExecutor(Object object, ToolExecutionRequest toolExecutionRequ
this.method = findMethod(object, toolExecutionRequest);
}

public boolean isReturnDirectly() {
@Override
public boolean isDirectReturn() {
Tool toolAnnotation = method.getAnnotation(Tool.class);
return toolAnnotation != null && toolAnnotation.returnDirectly();
return toolAnnotation != null && toolAnnotation.directReturn();
}

Method findMethod(Object object, ToolExecutionRequest toolExecutionRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,14 @@ public interface ToolExecutor {
* @return The result of the tool execution.
*/
String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId);

/**
* Returns true if the result of the tool invocation can be returned directly as it is,
* without any further processing from the LLM.
*
* @return True if the tool invocation result can be directly returned as it is.
*/
default boolean isDirectReturn() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ interface AssistantResultString {

static class ToolWithPrimitiveParametersReturnDirectly {

@Tool(returnDirectly = true)
@Tool(directReturn = true)
int add(int a, int b) {
return a + b;
}
Expand Down Expand Up @@ -940,7 +940,7 @@ protected void should_execute_tools_all_return_direct_true_multiple_called_at_a_


static class FirstToolReturnDirectFalse {
@Tool(returnDirectly = false)
@Tool(directReturn = false)
int add(int a, int b) {
return a + b;
}
Expand All @@ -953,7 +953,7 @@ int add(int a, int b) {
}

static class FirstToolReturnDirectTrue {
@Tool(returnDirectly = true)
@Tool(directReturn = true)
int add(int a, int b) {
return a + b;
}
Expand All @@ -966,7 +966,7 @@ int add(int a, int b) {
}

static class SecondToolReturnDirectTrue {
@Tool(returnDirectly = true)
@Tool(directReturn = true)
int multiply(int a, int b) {
return a * b;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ public int addOne(int num) {

private static class TestToolReturnDirectly {

@Tool(returnDirectly = true)
@Tool(directReturn = true)
public int addOne(int num) {
return num + 1;
}
Expand Down Expand Up @@ -322,7 +322,7 @@ public void test_get_return_directly_true() {

DefaultToolExecutor toolExecutor = new DefaultToolExecutor(new TestToolReturnDirectly(), request);

assertThat(toolExecutor.isReturnDirectly()).isTrue();
assertThat(toolExecutor.isDirectReturn()).isTrue();
}

@Test
Expand All @@ -335,7 +335,7 @@ public void test_get_return_directly_false() {

DefaultToolExecutor toolExecutor = new DefaultToolExecutor(new TestTool(), request);

assertThat(toolExecutor.isReturnDirectly()).isFalse();
assertThat(toolExecutor.isDirectReturn()).isFalse();
}

@Test
Expand Down

0 comments on commit 8f3cddc

Please sign in to comment.