Skip to content

Commit

Permalink
Adding support dataset for question answer model
Browse files Browse the repository at this point in the history
Signed-off-by: TrungBui59 <[email protected]>
  • Loading branch information
TrungBui59 committed Nov 27, 2023
1 parent 7cc9399 commit bc75610
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public enum FunctionName {
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
QUESTION_ANSWER,
REMOTE;

public static FunctionName from(String value) {
Expand All @@ -35,7 +36,8 @@ public static FunctionName from(String value) {
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING ||
functionName == SPARSE_TOKENIZE || functionName == QUESTION_ANSWER) {
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME,
TEXT_DOCS,
REMOTE
REMOTE,
QUESTION_ANSWER
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package org.opensearch.ml.common.dataset;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@InputDataSet(MLInputDataType.QUESTION_ANSWER)
public class QuestionAnswerInputDataSet extends MLInputDataset {
String contextDocs;
List<String> questionsList;

@Builder(toBuilder = true)
public QuestionAnswerInputDataSet(String contextDocs, List<String> questionsList) {
super(MLInputDataType.QUESTION_ANSWER);
this.contextDocs = contextDocs;
this.questionsList = questionsList;
}

public QuestionAnswerInputDataSet(StreamInput in) throws IOException {
super(MLInputDataType.QUESTION_ANSWER);
this.contextDocs = in.readString();
int size = in.readInt();
this.questionsList = new ArrayList<String>(size);
for (int i = 0; i < size; i++) {
questionsList.add(i, in.readString());
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.contextDocs);
out.writeInt(this.questionsList.size());
for (String question : this.questionsList) {
out.writeString(question);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.opensearch.ml.common.input.nlp;

import java.io.IOException;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;

@org.opensearch.ml.common.annotation.MLInput(functionNames=FunctionName.QUESTION_ANSWER)
public class QuestionAnswerMLInput extends MLInput {
public QuestionAnswerMLInput(FunctionName algorithm, MLInputDataset inputDataset) {
super(algorithm, null, inputDataset);
}

public QuestionAnswerMLInput(StreamInput input) throws IOException {
super(input);
}

@Override
public void writeTo(StreamOutput output) throws IOException {
super.writeTo(output);
}
}

0 comments on commit bc75610

Please sign in to comment.