From 88bd5b4949ba9b46ece5f55eac45c5e610376cda Mon Sep 17 00:00:00 2001 From: TrungBui59 Date: Wed, 22 Nov 2023 17:38:31 -0500 Subject: [PATCH] Fixing style and test Signed-off-by: TrungBui59 --- .../remote/HttpJsonConnectorExecutorTest.java | 214 ++++-------------- 1 file changed, 45 insertions(+), 169 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index d248b89f24..e91110b74e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -5,20 +5,16 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; import org.apache.http.HttpEntity; -import org.apache.http.client.ClientProtocolException; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; @@ -29,11 +25,9 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; import org.opensearch.ingest.TestTemplateService; -import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -43,7 +37,6 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.script.ScriptService; @@ -54,13 +47,13 @@ public class HttpJsonConnectorExecutorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Andrea Mock + @Mock ScriptService scriptService; - @Andrea Mock + @Mock CloseableHttpClient httpClient; - @Andrea Mock + @Mock CloseableHttpResponse response; @Before @@ -68,162 +61,44 @@ public void setUp() { MockitoAnnotations.openMocks(this); } - public void invokeRemoteModelSuccessPath(String httpMethod) { - try { - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method(httpMethod) - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .build(); - when(httpClient.execute(any())).thenReturn(response); - HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); - when(response.getEntity()).thenReturn(entity); - Connector connector = HttpConnector - .builder().name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); - when(executor.getHttpClient()).thenReturn(httpClient); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - verify(executor, Mockito.times(1)).invokeRemoteModel(any(), any(), any(), any()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); - Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - } catch (Exception e) { - fail("An exception was thrown: " + e.getMessage()); - } - } - - public void invokeRemoteModelHeaderTest(Map header) throws ClientProtocolException, IOException { - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("GET") - .headers(header) - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .build(); - Map credential = new HashMap<>(); - credential.put("key", "test_key_value"); - when(httpClient.execute(any())).thenReturn(response); - HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); - when(response.getEntity()).thenReturn(entity); - HttpConnector connector = spy(HttpConnector.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .credential(credential) - .actions(Arrays.asList(predictAction)) - .backendRoles(Arrays.asList("role1", "role2")) - .accessMode(AccessMode.PUBLIC) - .build()); - Function decryptFunction = 👎 -> (n); - connector.decrypt(decryptFunction); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); - when(executor.getHttpClient()).thenReturn(httpClient); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict( - MLInput.builder().algorithm(FunctionName.REMOTE) - .inputDataset(inputDataSet) - .build() - ); - verify(connector, Mockito.times(1)).getDecryptedHeaders(); - Assert.assertEquals(header, executor.getConnector().getDecryptedHeaders()); - } - @Test - public void invokeRemoteModelGetMethodSuccessPath() { - invokeRemoteModel_SuccessPath("GET"); - } - @Test - public void invokeRemoteModelPostMethodSuccessPath() { - invokeRemoteModel_SuccessPath("POST"); - } - - @Test - public void invokeRemoteModelHeaderNull() throws ClientProtocolException, IOException { - invokeRemoteModel_HeaderTest(null); - } - - public void invokeRemoteModelHeaderNotNull() throws ClientProtocolException, IOException { - Map headers = new HashMap<>(); - headers.put("api_key", "${credential.key}"); - headers.put("Content-type", "application/json"); - invokeRemoteModel_HeaderTest(headers); - } - - @Test - public void invokeRemoteModelPostMethodErrorPath() { - exceptionRule.expect(MLException.class); - exceptionRule.expectMessage("Failed to create http request for remote model"); - - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("post") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector.builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(null, null, null, null); - } - - @Test - public void invokeRemoteModelGetMethodErrorPath() { - exceptionRule.expect(MLException.class); - exceptionRule.expectMessage("Failed to create http request for remote model"); - - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("get") - .url("wrong url") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector.builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(null, null, null, null); - } - @Test - public void invokeRemoteModelWrongHttpMethod() { + public void invokeRemoteModel_WrongHttpMethod() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("unsupported http method"); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("wrong_method") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector.builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("wrong_method") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @Test - public void executePredictRemoteInferenceInput() throws IOException { - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + public void executePredict_RemoteInferenceInput() throws IOException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); @@ -245,13 +120,14 @@ public void executePredictRemoteInferenceInput() throws IOException { } @Test - public void executePredictTextDocsInputNoPreprocessFunction() throws IOException { - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .build(); + public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); @@ -309,7 +185,7 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException { } @Test - public void executePredictTextDocsInput() throws IOException { + public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; when(scriptService.compile(any(), any())) @@ -385,4 +261,4 @@ public void executePredictTextDocsInput() throws IOException { modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData() ); } -} \ No newline at end of file +}