diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java new file mode 100644 index 0000000000..c23e810090 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +public class MLAgentDeleteAction extends ActionType { + public static final MLAgentDeleteAction INSTANCE = new MLAgentDeleteAction(); + public static final String NAME = "cluster:admin/opensearch/ml/agents/delete"; + + private MLAgentDeleteAction() { super(NAME, DeleteResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java new file mode 100644 index 0000000000..ddc568fc60 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLAgentDeleteRequest extends ActionRequest { + @Getter + String agentId; + + @Builder + public MLAgentDeleteRequest(String agentId) { + this.agentId = agentId; + } + + public MLAgentDeleteRequest(StreamInput input) throws IOException { + super(input); + this.agentId = input.readString(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + output.writeString(agentId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.agentId == null) { + exception = addValidationError("ML agent id can't be null", exception); + } + + return exception; + } + + public static MLAgentDeleteRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLAgentDeleteRequest) { + return (MLAgentDeleteRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLAgentDeleteRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLAgentDeleteRequest", e); + } + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java new file mode 100644 index 0000000000..2a61035ce8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.opensearch.action.ActionType; + +public class MLAgentGetAction extends ActionType { + public static final MLAgentGetAction INSTANCE = new MLAgentGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/agents/get"; + + private MLAgentGetAction() { super(NAME, MLAgentGetResponse::new);} + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java new file mode 100644 index 0000000000..4880a07abf --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +public class MLAgentGetRequest extends ActionRequest { + + String agentId; + + @Builder + public MLAgentGetRequest(String agentId) { + this.agentId = agentId; + } + + public MLAgentGetRequest(StreamInput in) throws IOException { + super(in); + this.agentId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.agentId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.agentId == null) { + exception = addValidationError("ML agent id can't be null", exception); + } + + return exception; + } + + public static MLAgentGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLAgentGetRequest) { + return (MLAgentGetRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLAgentGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLAgentGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java new file mode 100644 index 0000000000..a437ef0ed8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Builder; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.agent.MLAgent; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class MLAgentGetResponse extends ActionResponse implements ToXContentObject { + MLAgent mlAgent; + + @Builder + public MLAgentGetResponse(MLAgent mlAgent) { + this.mlAgent = mlAgent; + } + + public MLAgentGetResponse(StreamInput in) throws IOException { + super(in); + mlAgent = MLAgent.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException{ + mlAgent.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlAgent.toXContent(xContentBuilder, params); + } + + public static MLAgentGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLAgentGetResponse) { + return (MLAgentGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLAgentGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLAgentGetResponse", e); + } + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java new file mode 100644 index 0000000000..7cc9e66793 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLAgentDeleteActionTest { + @Test + public void testMLAgentDeleteActionInstance() { + assertNotNull(MLAgentDeleteAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/agents/delete", MLAgentDeleteAction.NAME); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java new file mode 100644 index 0000000000..135271ec47 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLAgentDeleteRequestTest { + String agentId; + + @Test + public void constructor_AgentId() { + agentId = "test-abc"; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + assertEquals(mLAgentDeleteRequest.agentId,agentId); + } + + @Test + public void writeTo() throws IOException { + agentId = "test-hij"; + + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + BytesStreamOutput output = new BytesStreamOutput(); + mLAgentDeleteRequest.writeTo(output); + + MLAgentDeleteRequest mLAgentDeleteRequest1 = new MLAgentDeleteRequest(output.bytes().streamInput()); + + assertEquals(mLAgentDeleteRequest.agentId, mLAgentDeleteRequest1.agentId); + assertEquals(agentId, mLAgentDeleteRequest1.agentId); + } + + @Test + public void validate_Success() { + agentId = "not-null"; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + + assertEquals(null, mLAgentDeleteRequest.validate()); + } + + @Test + public void validate_Failure() { + agentId = null; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + assertEquals(null,mLAgentDeleteRequest.agentId); + + ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); + mLAgentDeleteRequest.validate().equals(exception) ; + } + + @Test + public void fromActionRequest() throws IOException { + agentId = "test-lmn"; + MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + assertEquals(mLAgentDeleteRequest.fromActionRequest(mLAgentDeleteRequest), mLAgentDeleteRequest); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java new file mode 100644 index 0000000000..cba838fb02 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLAgentGetActionTest { + + @Test + public void testMLAgentGetActionInstance() { + assertNotNull(MLAgentGetAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/agents/get", MLAgentGetAction.NAME); + } + + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java new file mode 100644 index 0000000000..6a04f5a965 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; + +import java.io.IOException; +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLAgentGetRequestTest { + String agentId; + + @Test + public void constructor_AgentId() { + agentId = "test-abc"; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + assertEquals(mLAgentGetRequest.getAgentId(),agentId); + } + + @Test + public void writeTo() throws IOException { + agentId = "test-hij"; + + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + BytesStreamOutput output = new BytesStreamOutput(); + mLAgentGetRequest.writeTo(output); + + MLAgentGetRequest mLAgentGetRequest1 = new MLAgentGetRequest(output.bytes().streamInput()); + + assertEquals(mLAgentGetRequest1.getAgentId(), mLAgentGetRequest.getAgentId()); + assertEquals(mLAgentGetRequest1.getAgentId(), agentId); + } + + @Test + public void validate_Success() { + agentId = "not-null"; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + + assertEquals(null, mLAgentGetRequest.validate()); + } + + @Test + public void validate_Failure() { + agentId = null; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + assertEquals(null,mLAgentGetRequest.agentId); + + ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); + mLAgentGetRequest.validate().equals(exception) ; + } + @Test + public void fromActionRequest() throws IOException { + agentId = "test-lmn"; + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + assertEquals(mLAgentGetRequest.fromActionRequest(mLAgentGetRequest), mLAgentGetRequest); + } +} + + diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java new file mode 100644 index 0000000000..7d733a4308 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.*; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; + +import java.io.*; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class MLAgentGetResponseTest { + + MLAgent mlAgent; + + @Test + public void Create_MLAgentResponse_With_StreamInput() throws IOException { + // Create a BytesStreamOutput to simulate the StreamOutput + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + + //create a test agent using input + bytesStreamOutput.writeString("Test Agent"); + bytesStreamOutput.writeString("flow"); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeBoolean(false); + bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); + bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); + bytesStreamOutput.writeString("test"); + + StreamInput testInputStream = bytesStreamOutput.bytes().streamInput(); + + MLAgentGetResponse mlAgentGetResponse = new MLAgentGetResponse(testInputStream); + MLAgent testMlAgent = mlAgentGetResponse.mlAgent; + assertEquals("flow",testMlAgent.getType()); + assertEquals("Test Agent",testMlAgent.getName()); + assertEquals("test",testMlAgent.getAppType()); + } + + @Test + public void mLAgentGetResponse_Builder() throws IOException { + + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + + assertEquals(mlAgentGetResponse.mlAgent, mlAgent); + } + @Test + public void writeTo() throws IOException { + //create ml agent using MLAgent and mlAgentGetResponse + mlAgent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + //use write out for both agents + BytesStreamOutput output = new BytesStreamOutput(); + mlAgent.writeTo(output); + mlAgentGetResponse.writeTo(output); + MLAgent agent1 = mlAgentGetResponse.mlAgent; + + assertEquals(mlAgent.getAppType(), agent1.getAppType()); + assertEquals(mlAgent.getDescription(), agent1.getDescription()); + assertEquals(mlAgent.getCreatedTime(), agent1.getCreatedTime()); + assertEquals(mlAgent.getName(), agent1.getName()); + assertEquals(mlAgent.getParameters(), agent1.getParameters()); + assertEquals(mlAgent.getType(), agent1.getType()); + } + + @Test + public void toXContent() throws IOException { + mlAgent = new MLAgent("mock", "flow", "test", null, null, null, null, null, null, "test"); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + ToXContent.Params params = EMPTY_PARAMS; + XContentBuilder getResponseXContentBuilder = mlAgentGetResponse.toXContent(builder, params); + assertEquals(getResponseXContentBuilder, mlAgent.toXContent(builder, params)); + } + + @Test + public void FromActionResponse() throws IOException { + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() + .mlAgent(mlAgent) + .build(); + assertEquals(mlAgentGetResponse.fromActionResponse(mlAgentGetResponse), mlAgentGetResponse); + + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java new file mode 100644 index 0000000000..0376fdbba9 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteAgentTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + @Inject + public DeleteAgentTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(MLAgentDeleteAction.NAME, transportService, actionFilters, MLAgentDeleteRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLAgentDeleteRequest mlAgentDeleteRequest = MLAgentDeleteRequest.fromActionRequest(request); + String agentId = mlAgentDeleteRequest.getAgentId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + DeleteRequest deleteRequest = new DeleteRequest(ML_AGENT_INDEX, agentId); + client.delete(deleteRequest, new ActionListener() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + log.debug("Completed Delete Agent Request, agent id:{} deleted", agentId); + wrappedListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete ML Agent " + agentId, e); + wrappedListener.onFailure(e); + } + }); + } catch (Exception e) { + log.error("Failed to delete ml agent " + agentId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java new file mode 100644 index 0000000000..59d17651a3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.ml.common.transport.agent.MLAgentGetResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetAgentTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + @Inject + public GetAgentTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(MLAgentGetAction.NAME, transportService, actionFilters, MLAgentGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLAgentGetRequest mlAgentGetRequest = MLAgentGetRequest.fromActionRequest(request); + String agentId = mlAgentGetRequest.getAgentId(); + GetRequest getRequest = new GetRequest(ML_AGENT_INDEX).id(agentId); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + log.debug("Completed Get Agent Request, id:{}", agentId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + actionListener.onResponse(MLAgentGetResponse.builder().mlAgent(mlAgent).build()); + } catch (Exception e) { + log.error("Failed to parse ml agent" + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + log.error("Failed to get agent index", e); + actionListener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML agent " + agentId, e); + actionListener.onFailure(e); + } + }), context::restore)); + } catch (Exception e) { + log.error("Failed to get ML agent " + agentId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 31197cd6be..bbfd5fdfbc 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -36,6 +36,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.ml.action.agents.DeleteAgentTransportAction; +import org.opensearch.ml.action.agents.GetAgentTransportAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; @@ -90,6 +92,8 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; @@ -161,12 +165,14 @@ import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; import org.opensearch.ml.rest.RestMLDeleteModelGroupAction; import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; import org.opensearch.ml.rest.RestMLExecuteAction; +import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; import org.opensearch.ml.rest.RestMLGetModelAction; import org.opensearch.ml.rest.RestMLGetModelGroupAction; @@ -331,6 +337,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), + new ActionHandler<>(MLAgentGetAction.INSTANCE, GetAgentTransportAction.class), + new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class), new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class) @@ -589,6 +597,8 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); + RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(); + RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(); RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction(); @@ -631,6 +641,8 @@ public List getRestHandlers( restSearchInteractionsAction, restGetConversationAction, restGetInteractionAction, + restMLGetAgentAction, + restMLDeleteAgentAction, restMemoryUpdateConversationAction, restMemoryUpdateInteractionAction, restMemoryGetTracesAction diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java new file mode 100644 index 0000000000..c8a667055e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteAgentAction.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to delete ML Agent. + */ +public class RestMLDeleteAgentAction extends BaseRestHandler { + private static final String ML_DELETE_AGENT_ACTION = "ml_delete_agent_action"; + + public void RestMLDeleteAgentAction() {} + + @Override + public String getName() { + return ML_DELETE_AGENT_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/agents/{%s}", ML_BASE_URI, PARAMETER_AGENT_ID))); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String agentId = request.param(PARAMETER_AGENT_ID); + + MLAgentDeleteRequest mlAgentDeleteRequest = new MLAgentDeleteRequest(agentId); + return channel -> client.execute(MLAgentDeleteAction.INSTANCE, mlAgentDeleteRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java new file mode 100644 index 0000000000..efed1d84c3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetAgentAction extends BaseRestHandler { + private static final String ML_GET_Agent_ACTION = "ml_get_agent_action"; + + /** + * Constructor + */ + public RestMLGetAgentAction() {} + + @Override + public String getName() { + return ML_GET_Agent_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/agents/{%s}", ML_BASE_URI, PARAMETER_AGENT_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLAgentGetRequest mlAgentGetRequest = getRequest(request); + return channel -> client.execute(MLAgentGetAction.INSTANCE, mlAgentGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLAgentGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLAgentGetRequest + */ + @VisibleForTesting + MLAgentGetRequest getRequest(RestRequest request) throws IOException { + String agentId = getParameterId(request, PARAMETER_AGENT_ID); + + return new MLAgentGetRequest(agentId); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 98f5f87d22..3a2f9daae4 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -50,6 +50,7 @@ public class RestActionUtils { public static final String PARAMETER_ASYNC = "async"; public static final String PARAMETER_RETURN_CONTENT = "return_content"; public static final String PARAMETER_MODEL_ID = "model_id"; + public static final String PARAMETER_AGENT_ID = "agent_id"; public static final String PARAMETER_TASK_ID = "task_id"; public static final String PARAMETER_CONNECTOR_ID = "connector_id"; public static final String PARAMETER_DEPLOY_MODEL = "deploy"; diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java new file mode 100644 index 0000000000..212112841a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.agents; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class DeleteAgentTransportActionTests { + + @Mock + private Client client; + @Mock + ThreadPool threadPool; + @Mock + private NamedXContentRegistry xContentRegistry; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @InjectMocks + private DeleteAgentTransportAction deleteAgentTransportAction; + + ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + deleteAgentTransportAction = new DeleteAgentTransportAction(transportService, actionFilters, client, xContentRegistry); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + @Test + public void testConstructor() { + // Verify that the dependencies were correctly injected + assertEquals(deleteAgentTransportAction.client, client); + assertEquals(deleteAgentTransportAction.xContentRegistry, xContentRegistry); + } + + @Test + public void testDoExecute_Success() { + String agentId = "test-agent-id"; + DeleteResponse deleteResponse = mock(DeleteResponse.class); + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + @Test + public void testDoExecute_Failure() { + String agentId = "test-non-existed-agent-id"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + NullPointerException NullPointerException = new NullPointerException("Failed to delete ML Agent " + agentId); + listener.onFailure(NullPointerException); + return null; + }).when(client).delete(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to delete ML Agent " + agentId, argumentCaptor.getValue().getMessage()); + + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java new file mode 100644 index 0000000000..07f406ac07 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -0,0 +1,283 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.agents; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.ml.common.transport.agent.MLAgentGetResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetAgentTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + @Mock + ThreadPool threadPool; + @Mock + private NamedXContentRegistry xContentRegistry; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @InjectMocks + private GetAgentTransportAction getAgentTransportAction; + + ThreadContext threadContext; + MLAgent mlAgent; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + getAgentTransportAction = new GetAgentTransportAction(transportService, actionFilters, client, xContentRegistry); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + } + + @Test + public void testDoExecute_Failure_Get_Agent() { + String agentId = "test-agent-id-no-existed"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + + Task task = mock(Task.class); + + Exception exceptionToThrow = new Exception("Failed to get ML agent " + agentId); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exceptionToThrow); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get ML agent " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_Failure_IndexNotFound() { + String agentId = "test-agent-id-IndexNotFound"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + + Task task = mock(Task.class); + + Exception exceptionToThrow = new IndexNotFoundException("Failed to get agent index " + agentId); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exceptionToThrow); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get agent index", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_Failure_OpenSearchStatus() throws IOException { + String agentId = "test-agent-id-OpenSearchStatus"; + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + + Task task = mock(Task.class); + + Exception exceptionToThrow = new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exceptionToThrow); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find agent with the provided agent id: " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_RuntimeException() { + String agentId = "test-agent-id-RuntimeException"; + Task task = mock(Task.class); + ActionListener actionListener = mock(ActionListener.class); + + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Failed to get ML agent " + agentId)); + return null; + }).when(client).get(any(), any()); + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get ML agent " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetTask_NullResponse() { + String agentId = "test-agent-id-NullResponse"; + Task task = mock(Task.class); + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + getAgentTransportAction.doExecute(task, getRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find agent with the provided agent id: " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_Failure_Context_Exception() { + String agentId = "test-agent-id"; + + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + Task task = mock(Task.class); + GetAgentTransportAction getAgentTransportActionNullContext = new GetAgentTransportAction( + transportService, + actionFilters, + client, + xContentRegistry + ); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenThrow(new RuntimeException()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(client).get(any(), any()); + try { + getAgentTransportActionNullContext.doExecute(task, getRequest, actionListener); + } catch (Exception e) { + assertEquals(e.getClass(), RuntimeException.class); + } + } + + @Test + public void testDoExecute_NoAgentId() throws IOException { + GetResponse getResponse = prepareMLAgent(null); + String agentId = "test-agent-id"; + + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest request = new MLAgentGetRequest(agentId); + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + try { + getAgentTransportAction.doExecute(task, request, actionListener); + } catch (Exception e) { + assertEquals(e.getClass(), IllegalArgumentException.class); + } + } + + @Test + public void testDoExecute_Success() throws IOException { + + String agentId = "test-agent-id"; + GetResponse getResponse = prepareMLAgent(agentId); + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest request = new MLAgentGetRequest(agentId); + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getAgentTransportAction.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLAgentGetResponse.class)); + } + + public GetResponse prepareMLAgent(String agentId) throws IOException { + + mlAgent = new MLAgent( + "test", + "test", + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test" + ); + + XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", agentId, 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java new file mode 100644 index 0000000000..19849294f8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteAgentActionTests.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLDeleteAgentActionTests extends OpenSearchTestCase { + private RestMLDeleteAgentAction restMLDeleteAgentAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLDeleteAgentAction = new RestMLDeleteAgentAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLAgentDeleteAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteAgentAction mLDeleteAgentAction = new RestMLDeleteAgentAction(); + assertNotNull(mLDeleteAgentAction); + } + + public void testGetName() { + String actionName = restMLDeleteAgentAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_agent_action", actionName); + } + + public void testRoutes() { + List routes = restMLDeleteAgentAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/agents/{agent_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLDeleteAgentAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLAgentDeleteRequest.class); + verify(client, times(1)).execute(eq(MLAgentDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + String agentId = argumentCaptor.getValue().getAgentId(); + assertEquals(agentId, "agent_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_AGENT_ID, "agent_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java new file mode 100644 index 0000000000..7b2f4eaae8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetAgentActionTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLGetAgentActionTests extends OpenSearchTestCase { + private RestMLGetAgentAction restMLGetAgentAction; + NodeClient client; + private ThreadPool threadPool; + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLGetAgentAction = new RestMLGetAgentAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLAgentGetAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetAgentAction mLGetAgentAction = new RestMLGetAgentAction(); + assertNotNull(mLGetAgentAction); + } + + public void testGetName() { + String actionName = restMLGetAgentAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_agent_action", actionName); + } + + public void testRoutes() { + List routes = restMLGetAgentAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/agents/{agent_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetAgentAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLAgentGetRequest.class); + verify(client, times(1)).execute(eq(MLAgentGetAction.INSTANCE), argumentCaptor.capture(), any()); + String agentId = argumentCaptor.getValue().getAgentId(); + assertEquals(agentId, "agent_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_AGENT_ID, "agent_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } + +}