Skip to content

Commit

Permalink
Split the tests
Browse files Browse the repository at this point in the history
Signed-off-by: TrungBui59 <[email protected]>
  • Loading branch information
TrungBui59 committed Nov 6, 2023
1 parent 0213140 commit 6a1e3ca
Showing 1 changed file with 75 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
Expand All @@ -14,6 +13,8 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.apache.http.client.ClientProtocolException;
Expand Down Expand Up @@ -63,19 +64,23 @@ public void setUp() {
MockitoAnnotations.openMocks(this);
}

@Test
public void invokeRemoteModel_GetMethodSuccessPath() throws ClientProtocolException, IOException {
public void invokeRemoteModel_SuccessPath(String httpMethod) {
try {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("get")
.method(httpMethod)
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector
.builder().name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
Expand All @@ -87,29 +92,54 @@ public void invokeRemoteModel_GetMethodSuccessPath() throws ClientProtocolExcept
fail("An exception was thrown: " + e.getMessage());
}
}

public void invokeRemoteModel_HeaderTest(Map<String, String> header) throws ClientProtocolException, IOException {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("GET")
.headers(header)
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Connector connector = spy(HttpConnector
.builder().name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build());
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(
MLInput.builder().algorithm(FunctionName.REMOTE)
.inputDataset(inputDataSet)
.build()
);
// verify(connector, Mockito.times(1)).getDecryptedHeaders();
Assert.assertEquals(header, executor.getConnector().getDecryptedHeaders());
}
@Test
public void invokeRemoteModel_POSTMethodSuccessPath() throws ClientProtocolException, IOException {
try {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
verify(executor, Mockito.times(1)).invokeRemoteModel(any(), any(), any(), any());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
} catch (Exception e) {
fail("An exception was thrown: " + e.getMessage());
}
public void invokeRemoteModel_GETMethodSuccessPath() {
invokeRemoteModel_SuccessPath("GET");
}
@Test
public void invokeRemoteModel_POSTMethodSuccessPath() {
invokeRemoteModel_SuccessPath("POST");
}

@Test
public void invokeRemoteModel_HeaderNull() throws ClientProtocolException, IOException {
invokeRemoteModel_HeaderTest(null);
}

@Test
public void invokeRemoteModel_HeaderNotNull() throws ClientProtocolException, IOException {
Map<String, String> header = new HashMap<>();
header.put("Content-Type", "application/json");
invokeRemoteModel_HeaderTest(header);
}
@Test
public void invokeRemoteModel_POSTMethodErrorPath() {
Expand All @@ -122,7 +152,11 @@ public void invokeRemoteModel_POSTMethodErrorPath() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}
Expand All @@ -138,7 +172,12 @@ public void invokeRemoteModel_GETMethodErrorPath() {
.url("wrong url")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}
Expand All @@ -153,7 +192,12 @@ public void invokeRemoteModel_WrongHttpMethod() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}
Expand Down Expand Up @@ -240,4 +284,5 @@ public void executePredict_TextDocsInput() throws IOException {
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
}

}

0 comments on commit 6a1e3ca

Please sign in to comment.