Skip to content

Commit

Permalink
Adding test for execption with GET for invokeModel
Browse files Browse the repository at this point in the history
Signed-off-by: TrungBui59 <[email protected]>
  • Loading branch information
TrungBui59 committed Nov 22, 2023
1 parent 18af4a8 commit ffbfee0
Showing 1 changed file with 141 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import org.apache.http.HttpEntity;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
Expand All @@ -25,9 +29,11 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
Expand All @@ -48,38 +54,128 @@ public class HttpJsonConnectorExecutorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Mock
@Andrea Mock
ScriptService scriptService;

@Mock
@Andrea Mock
CloseableHttpClient httpClient;

@Mock
@Andrea Mock
CloseableHttpResponse response;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
}

public void invokeRemoteModelSuccessPath(String httpMethod) {
try {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method(httpMethod)
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Connector connector = HttpConnector
.builder().name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
verify(executor, Mockito.times(1)).invokeRemoteModel(any(), any(), any(), any());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
} catch (Exception e) {
fail("An exception was thrown: " + e.getMessage());
}
}

public void invokeRemoteModelHeaderTest(Map<String, String> header) throws ClientProtocolException, IOException {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("GET")
.headers(header)
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
Map<String, String> credential = new HashMap<>();
credential.put("key", "test_key_value");
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
HttpConnector connector = spy(HttpConnector.builder()
.name("test_connector_name")
.description("this is a test connector")
.version("1")
.protocol("http")
.credential(credential)
.actions(Arrays.asList(predictAction))
.backendRoles(Arrays.asList("role1", "role2"))
.accessMode(AccessMode.PUBLIC)
.build());
Function<String, String> decryptFunction = 👎 -> (n);
connector.decrypt(decryptFunction);
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(
MLInput.builder().algorithm(FunctionName.REMOTE)
.inputDataset(inputDataSet)
.build()
);
verify(connector, Mockito.times(1)).getDecryptedHeaders();
Assert.assertEquals(header, executor.getConnector().getDecryptedHeaders());
}
@Test
public void invokeRemoteModelGetMethodSuccessPath() {
invokeRemoteModel_SuccessPath("GET");
}
@Test
public void invokeRemoteModel_POSTMethodErrorPath() {
public void invokeRemoteModelPostMethodSuccessPath() {
invokeRemoteModel_SuccessPath("POST");
}

@Test
public void invokeRemoteModelHeaderNull() throws ClientProtocolException, IOException {
invokeRemoteModel_HeaderTest(null);
}

public void invokeRemoteModelHeaderNotNull() throws ClientProtocolException, IOException {
Map<String, String> headers = new HashMap<>();
headers.put("api_key", "${credential.key}");
headers.put("Content-type", "application/json");
invokeRemoteModel_HeaderTest(headers);
}

@Test
public void invokeRemoteModelPostMethodErrorPath() {
exceptionRule.expect(MLException.class);
exceptionRule.expectMessage("Failed to create http request for remote model");

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("post")
.url("wrong url")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}

@Test
public void invokeRemoteModel_GETMethodErrorPath() {
public void invokeRemoteModelGetMethodErrorPath() {
exceptionRule.expect(MLException.class);
exceptionRule.expectMessage("Failed to create http request for remote model");

Expand All @@ -89,49 +185,45 @@ public void invokeRemoteModel_GETMethodErrorPath() {
.url("wrong url")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}

@Test
public void invokeRemoteModel_WrongHttpMethod() {
public void invokeRemoteModelWrongHttpMethod() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("unsupported http method");
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("wrong_method")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("wrong_method")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}

@Test
public void executePredict_RemoteInferenceInput() throws IOException {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
public void executePredictRemoteInferenceInput() throws IOException {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
Expand All @@ -153,14 +245,13 @@ public void executePredict_RemoteInferenceInput() throws IOException {
}

@Test
public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
public void executePredictTextDocsInputNoPreprocessFunction() throws IOException {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Expand Down Expand Up @@ -218,7 +309,7 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException {
}

@Test
public void executePredict_TextDocsInput() throws IOException {
public void executePredictTextDocsInput() throws IOException {
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
when(scriptService.compile(any(), any()))
Expand Down Expand Up @@ -294,4 +385,4 @@ public void executePredict_TextDocsInput() throws IOException {
modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()
);
}
}
}

0 comments on commit ffbfee0

Please sign in to comment.