Skip to content

Commit

Permalink
adding mlmodeltool and agent tool with tests (opensearch-project#1768)
Browse files Browse the repository at this point in the history
* adding mlmodeltool and agent tool with tests

Signed-off-by: Dhrubo Saha <[email protected]>

* updating tests

Signed-off-by: Dhrubo Saha <[email protected]>

* removed connector

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Dec 16, 2023
1 parent bba62af commit bab9439
Show file tree
Hide file tree
Showing 9 changed files with 730 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public enum FunctionName {
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
REMOTE;
REMOTE,
AGENT;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.execute.agent;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.AGENT})
public class AgentMLInput extends MLInput {
public static final String AGENT_ID_FIELD = "agent_id";
public static final String PARAMETERS_FIELD = "parameters";

@Getter @Setter
private String agentId;

@Builder(builderMethodName = "AgentMLInputBuilder")
public AgentMLInput(String agentId, FunctionName functionName, MLInputDataset inputDataset) {
this.agentId = agentId;
this.algorithm = functionName;
this.inputDataset = inputDataset;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(agentId);
}

public AgentMLInput(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
}

public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case AGENT_ID_FIELD:
agentId = parser.text();
break;
case PARAMETERS_FIELD:
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
break;
default:
parser.skipChildren();
break;
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatch
this.user = user;
}

public MLPredictionTaskRequest(String modelId, MLInput mlInput) {
this(modelId, mlInput, true, null);
}

public MLPredictionTaskRequest(String modelId, MLInput mlInput, User user) {
this(modelId, mlInput, true, user);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.execute.agent;

import org.junit.Test;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class AgentMLInputTests {

@Test
public void testConstructorWithAgentIdFunctionNameAndDataset() {
// Arrange
String agentId = "testAgentId";
FunctionName functionName = FunctionName.AGENT; // Assuming FunctionName is an enum or similar
MLInputDataset dataset = mock(MLInputDataset.class); // Mock the MLInputDataset

// Act
AgentMLInput input = new AgentMLInput(agentId, functionName, dataset);

// Assert
assertEquals(agentId, input.getAgentId());
assertEquals(functionName, input.getAlgorithm());
assertEquals(dataset, input.getInputDataset());
}

@Test
public void testWriteTo() throws IOException {
// Arrange
String agentId = "testAgentId";
AgentMLInput input = new AgentMLInput(agentId, FunctionName.AGENT, null);
StreamOutput out = mock(StreamOutput.class);

// Act
input.writeTo(out);

// Assert
verify(out).writeString(agentId);
}

@Test
public void testConstructorWithStreamInput() throws IOException {
// Arrange
String agentId = "testAgentId";
StreamInput in = mock(StreamInput.class);
when(in.readString()).thenReturn(agentId);

// Act
AgentMLInput input = new AgentMLInput(in);

// Assert
assertEquals(agentId, input.getAgentId());
}

@Test
public void testConstructorWithXContentParser() throws IOException {
// Arrange
XContentParser parser = mock(XContentParser.class);

// Simulate parser behavior for START_OBJECT token
when(parser.currentToken()).thenReturn(XContentParser.Token.START_OBJECT);
when(parser.nextToken()).thenReturn(XContentParser.Token.FIELD_NAME)
.thenReturn(XContentParser.Token.VALUE_STRING)
.thenReturn(XContentParser.Token.FIELD_NAME) // For PARAMETERS_FIELD
.thenReturn(XContentParser.Token.START_OBJECT) // Start of PARAMETERS_FIELD map
.thenReturn(XContentParser.Token.FIELD_NAME) // Key in PARAMETERS_FIELD map
.thenReturn(XContentParser.Token.VALUE_STRING) // Value in PARAMETERS_FIELD map
.thenReturn(XContentParser.Token.END_OBJECT) // End of PARAMETERS_FIELD map
.thenReturn(XContentParser.Token.END_OBJECT); // End of the main object

// Simulate parser behavior for agent_id
when(parser.currentName()).thenReturn("agent_id")
.thenReturn("parameters")
.thenReturn("paramKey");
when(parser.text()).thenReturn("testAgentId")
.thenReturn("paramValue");

// Simulate parser behavior for parameters
Map<String, Object> paramMap = new HashMap<>();
paramMap.put("paramKey", "paramValue");
when(parser.map()).thenReturn(paramMap);

// Act
AgentMLInput input = new AgentMLInput(parser, FunctionName.AGENT);

// Assert
assertEquals("testAgentId", input.getAgentId());
assertNotNull(input.getInputDataset());
assertTrue(input.getInputDataset() instanceof RemoteInferenceInputDataSet);
// Additional assertions for RemoteInferenceInputDataSet
RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) input.getInputDataset();
assertEquals("paramValue", dataset.getParameters().get("paramKey"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.ml.common.dataframe.ColumnType;
Expand Down Expand Up @@ -53,9 +54,11 @@ public void setUp() {

@Test
public void writeTo_Success() throws IOException {
User user = User.parse("admin|role-1|all_access");

MLPredictionTaskRequest request = MLPredictionTaskRequest.builder()
.mlInput(mlInput)
.user(user)
.build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
request.writeTo(bytesStreamOutput);
Expand All @@ -73,13 +76,18 @@ public void writeTo_Success() throws IOException {
assertEquals(1, dataFrame.getRow(0).size());
assertEquals(2.00, dataFrame.getRow(0).getValue(0).getValue());

User userExpect = request.getUser();
assertEquals(user.getName(), userExpect.getName());

assertNull(request.getModelId());
}

@Test
public void validate_Success() {
User user = User.parse("admin|role-1|all_access");
MLPredictionTaskRequest request = MLPredictionTaskRequest.builder()
.mlInput(mlInput)
.user(user)
.build();

assertNull(request.validate());
Expand Down Expand Up @@ -133,8 +141,10 @@ public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_SearchQuery
}

private void fromActionRequest_Success_WithNonMLPredictionTaskRequest(MLInput mlInput) {
User user = User.parse("admin|role-1|all_access");
MLPredictionTaskRequest request = MLPredictionTaskRequest.builder()
.mlInput(mlInput)
.user(user)
.build();
ActionRequest actionRequest = new ActionRequest() {
@Override
Expand All @@ -151,6 +161,7 @@ public void writeTo(StreamOutput out) throws IOException {
assertNotSame(result, request);
assertEquals(request.getMlInput().getAlgorithm(), result.getMlInput().getAlgorithm());
assertEquals(request.getMlInput().getInputDataset().getInputDataType(), result.getMlInput().getInputDataset().getInputDataType());
assertEquals(request.getUser().getName(), request.getUser().getName());
}

@Test(expected = UncheckedIOException.class)
Expand All @@ -168,4 +179,4 @@ public void writeTo(StreamOutput out) throws IOException {
};
MLPredictionTaskRequest.fromActionRequest(actionRequest);
}
}
}
Loading

0 comments on commit bab9439

Please sign in to comment.