From bc75610ec32b24f23ba2b777c2b48de701983459 Mon Sep 17 00:00:00 2001
From: TrungBui59
Date: Mon, 27 Nov 2023 11:07:30 -0500
Subject: [PATCH] Adding support dataset for question answer model

Signed-off-by: TrungBui59
---
Review notes (free text between the "---" marker and the first "diff --git"
line; it is not part of the commit message):
 * FunctionName and MLInputDataType each gain a QUESTION_ANSWER constant, and
   FunctionName.isDLModel now also returns true for QUESTION_ANSWER.
 * QuestionAnswerInputDataSet holds one context string and a list of question
   strings. writeTo calls super.writeTo first, then writes contextDocs, the
   number of questions as an int, and each question in list order; the
   StreamInput constructor reads contextDocs, the count and the questions back
   in that same order.
 * questionsList is declared as List<String>: with a raw List the
   "for (String question : this.questionsList)" loop in writeTo does not compile.
 * writeTo dereferences questionsList without a null check, so a data set built
   with a null list fails during serialization.
 * QuestionAnswerMLInput is a thin MLInput subclass: its two-argument
   constructor passes null as the second super argument, and both stream
   methods only delegate to super.
 * The From:/Signed-off-by: lines carry no e-mail address in this copy; restore
   it from the original commit if the receiving tooling requires one.

 .../opensearch/ml/common/FunctionName.java    |  4 +-
 .../ml/common/dataset/MLInputDataType.java    |  3 +-
 .../dataset/QuestionAnswerInputDataSet.java   | 49 +++++++++++++++++++
 .../input/nlp/QuestionAnswerMLInput.java      | 25 ++++++++++
 4 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnswerInputDataSet.java
 create mode 100644 common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java

diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java
index 96c6a58235..3544e6f1c7 100644
--- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java
+++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java
@@ -20,6 +20,7 @@ public enum FunctionName {
     SPARSE_ENCODING,
     SPARSE_TOKENIZE,
     METRICS_CORRELATION,
+    QUESTION_ANSWER,
     REMOTE;
 
     public static FunctionName from(String value) {
@@ -35,7 +36,8 @@ public static FunctionName from(String value) {
      * @return true for deep learning model.
      */
     public static boolean isDLModel(FunctionName functionName) {
-        if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
+        if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING ||
+                functionName == SPARSE_TOKENIZE || functionName == QUESTION_ANSWER) {
             return true;
         }
         return false;
diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java
index 46cdb161bd..6e95e221ff 100644
--- a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java
+++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java
@@ -9,5 +9,6 @@ public enum MLInputDataType {
     SEARCH_QUERY,
     DATA_FRAME,
     TEXT_DOCS,
-    REMOTE
+    REMOTE,
+    QUESTION_ANSWER
 }
diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnswerInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnswerInputDataSet.java
new file mode 100644
index 0000000000..9359e3e8e8
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnswerInputDataSet.java
@@ -0,0 +1,49 @@
+package org.opensearch.ml.common.dataset;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.ml.common.annotation.InputDataSet;
+
+import lombok.AccessLevel;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.experimental.FieldDefaults;
+
+@Getter
+@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
+@InputDataSet(MLInputDataType.QUESTION_ANSWER)
+public class QuestionAnswerInputDataSet extends MLInputDataset {
+    String contextDocs;
+    List<String> questionsList;
+
+    @Builder(toBuilder = true)
+    public QuestionAnswerInputDataSet(String contextDocs, List<String> questionsList) {
+        super(MLInputDataType.QUESTION_ANSWER);
+        this.contextDocs = contextDocs;
+        this.questionsList = questionsList;
+    }
+
+    public QuestionAnswerInputDataSet(StreamInput in) throws IOException {
+        super(MLInputDataType.QUESTION_ANSWER);
+        this.contextDocs = in.readString();
+        int size = in.readInt();
+        this.questionsList = new ArrayList<>(size);
+        for (int i = 0; i < size; i++) {
+            questionsList.add(i, in.readString());
+        }
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeString(this.contextDocs);
+        out.writeInt(this.questionsList.size());
+        for (String question : this.questionsList) {
+            out.writeString(question);
+        }
+    }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java
new file mode 100644
index 0000000000..d9c553855f
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java
@@ -0,0 +1,25 @@
+package org.opensearch.ml.common.input.nlp;
+
+import java.io.IOException;
+
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.dataset.MLInputDataset;
+import org.opensearch.ml.common.input.MLInput;
+
+@org.opensearch.ml.common.annotation.MLInput(functionNames=FunctionName.QUESTION_ANSWER)
+public class QuestionAnswerMLInput extends MLInput {
+    public QuestionAnswerMLInput(FunctionName algorithm, MLInputDataset inputDataset) {
+        super(algorithm, null, inputDataset);
+    }
+
+    public QuestionAnswerMLInput(StreamInput input) throws IOException {
+        super(input);
+    }
+
+    @Override
+    public void writeTo(StreamOutput output) throws IOException {
+        super.writeTo(output);
+    }
+}