diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 772fcf964d..d248b89f24 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -5,16 +5,20 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; import org.apache.http.HttpEntity; -import org.apache.http.ProtocolVersion; -import org.apache.http.StatusLine; +import org.apache.http.client.ClientProtocolException; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; @@ -25,9 +29,11 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -48,13 +54,13 @@ public class HttpJsonConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Mock + @Andrea Mock ScriptService scriptService; - @Mock + @Andrea Mock CloseableHttpClient httpClient; - @Mock + @Andrea Mock CloseableHttpResponse response; @Before @@ -62,24 +68,114 @@ public void setUp() { MockitoAnnotations.openMocks(this); } + public void invokeRemoteModelSuccessPath(String httpMethod) { + try { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method(httpMethod) + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); + Connector connector = HttpConnector + .builder().name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + verify(executor, Mockito.times(1)).invokeRemoteModel(any(), any(), any(), any()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + } catch (Exception e) { + fail("An exception was thrown: " + e.getMessage()); + } + } + + public void invokeRemoteModelHeaderTest(Map header) throws ClientProtocolException, IOException { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("GET") + .headers(header) + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Map credential = new HashMap<>(); + credential.put("key", "test_key_value"); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); + HttpConnector connector = spy(HttpConnector.builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .credential(credential) + .actions(Arrays.asList(predictAction)) + .backendRoles(Arrays.asList("role1", "role2")) + .accessMode(AccessMode.PUBLIC) + .build()); + Function decryptFunction = 👎 -> (n); + connector.decrypt(decryptFunction); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict( + MLInput.builder().algorithm(FunctionName.REMOTE) + .inputDataset(inputDataSet) + .build() + ); + verify(connector, Mockito.times(1)).getDecryptedHeaders(); + Assert.assertEquals(header, executor.getConnector().getDecryptedHeaders()); + } + @Test + public void invokeRemoteModelGetMethodSuccessPath() { + invokeRemoteModel_SuccessPath("GET"); + } @Test - public void invokeRemoteModel_POSTMethodErrorPath() { + public void invokeRemoteModelPostMethodSuccessPath() { + invokeRemoteModel_SuccessPath("POST"); + } + + @Test + public void invokeRemoteModelHeaderNull() throws ClientProtocolException, IOException { + invokeRemoteModel_HeaderTest(null); + } + + public void invokeRemoteModelHeaderNotNull() throws ClientProtocolException, IOException { + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + headers.put("Content-type", "application/json"); + invokeRemoteModel_HeaderTest(headers); + } + + @Test + public void invokeRemoteModelPostMethodErrorPath() { exceptionRule.expect(MLException.class); exceptionRule.expectMessage("Failed to create http request for remote model"); ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("post") - .url("wrong url") + .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void invokeRemoteModel_GETMethodErrorPath() { + public void invokeRemoteModelGetMethodErrorPath() { exceptionRule.expect(MLException.class); exceptionRule.expectMessage("Failed to create http request for remote model"); @@ -89,49 +185,45 @@ public void invokeRemoteModel_GETMethodErrorPath() { .url("wrong url") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void invokeRemoteModel_WrongHttpMethod() { + public void invokeRemoteModelWrongHttpMethod() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("unsupported http method"); - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("wrong_method") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("wrong_method") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void executePredict_RemoteInferenceInput() throws IOException { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); + public void executePredictRemoteInferenceInput() throws IOException { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); @@ -153,14 +245,13 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .build(); + public void executePredictTextDocsInputNoPreprocessFunction() throws IOException { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); @@ -218,7 +309,7 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException { } @Test - public void executePredict_TextDocsInput() throws IOException { + public void executePredictTextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; when(scriptService.compile(any(), any())) @@ -294,4 +385,4 @@ public void executePredict_TextDocsInput() throws IOException { modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData() ); } -} +} \ No newline at end of file