diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 641cb2f336..1df40d1f88 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -5,8 +5,20 @@ package org.opensearch.ml.engine.algorithms.remote; -import com.google.common.collect.ImmutableMap; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + import org.apache.http.HttpEntity; +import org.apache.http.client.ClientProtocolException; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; @@ -16,8 +28,10 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -30,29 +44,21 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; -import java.io.IOException; -import java.util.Arrays; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableMap; public class HttpJsonConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Mock + @Andrea Mock ScriptService scriptService; - @Mock + @Andrea Mock CloseableHttpClient httpClient; - @Mock + @Andrea Mock CloseableHttpResponse response; @Before @@ -60,24 +66,114 @@ public void setUp() { MockitoAnnotations.openMocks(this); } + public void invokeRemoteModelSuccessPath(String httpMethod) { + try { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method(httpMethod) + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); + Connector connector = HttpConnector + .builder().name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + verify(executor, Mockito.times(1)).invokeRemoteModel(any(), any(), any(), any()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + } catch (Exception e) { + fail("An exception was thrown: " + e.getMessage()); + } + } + + public void invokeRemoteModelHeaderTest(Map header) throws ClientProtocolException, IOException { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("GET") + .headers(header) + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Map credential = new HashMap<>(); + credential.put("key", "test_key_value"); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); + HttpConnector connector = spy(HttpConnector.builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .credential(credential) + .actions(Arrays.asList(predictAction)) + .backendRoles(Arrays.asList("role1", "role2")) + .accessMode(AccessMode.PUBLIC) + .build()); + Function decryptFunction = 👎 -> (n); + connector.decrypt(decryptFunction); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict( + MLInput.builder().algorithm(FunctionName.REMOTE) + .inputDataset(inputDataSet) + .build() + ); + verify(connector, Mockito.times(1)).getDecryptedHeaders(); + Assert.assertEquals(header, executor.getConnector().getDecryptedHeaders()); + } + @Test + public void invokeRemoteModelGetMethodSuccessPath() { + invokeRemoteModel_SuccessPath("GET"); + } + @Test + public void invokeRemoteModelPostMethodSuccessPath() { + invokeRemoteModel_SuccessPath("POST"); + } + @Test - public void invokeRemoteModel_POSTMethodErrorPath() { + public void invokeRemoteModelHeaderNull() throws ClientProtocolException, IOException { + invokeRemoteModel_HeaderTest(null); + } + + public void invokeRemoteModelHeaderNotNull() throws ClientProtocolException, IOException { + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + headers.put("Content-type", "application/json"); + invokeRemoteModel_HeaderTest(headers); + } + + @Test + public void invokeRemoteModelPostMethodErrorPath() { exceptionRule.expect(MLException.class); exceptionRule.expectMessage("Failed to create http request for remote model"); ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("post") - .url("wrong url") + .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void invokeRemoteModel_GETMethodErrorPath() { + public void invokeRemoteModelGetMethodErrorPath() { exceptionRule.expect(MLException.class); exceptionRule.expectMessage("Failed to create http request for remote model"); @@ -87,13 +183,18 @@ public void invokeRemoteModel_GETMethodErrorPath() { .url("wrong url") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void invokeRemoteModel_WrongHttpMethod() { + public void invokeRemoteModelWrongHttpMethod() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("unsupported http method"); ConnectorAction predictAction = ConnectorAction.builder() @@ -102,13 +203,18 @@ public void invokeRemoteModel_WrongHttpMethod() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void executePredict_RemoteInferenceInput() throws IOException { + public void executePredictRemoteInferenceInput() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") @@ -130,7 +236,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { + public void executePredictTextDocsInputNoPreprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") @@ -152,7 +258,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti } @Test - public void executePredict_TextDocsInput() throws IOException { + public void executePredictTextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; when(scriptService.compile(any(), any())) @@ -189,4 +295,4 @@ public void executePredict_TextDocsInput() throws IOException { Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } -} +} \ No newline at end of file