forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding support dataset for question answer model
Signed-off-by: TrungBui59 <[email protected]>
- Loading branch information
1 parent
7cc9399
commit bc75610
Showing
4 changed files
with
79 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,5 +9,6 @@ public enum MLInputDataType { | |
SEARCH_QUERY, | ||
DATA_FRAME, | ||
TEXT_DOCS, | ||
REMOTE | ||
REMOTE, | ||
QUESTION_ANSWER | ||
} |
49 changes: 49 additions & 0 deletions
49
common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnswerInputDataSet.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package org.opensearch.ml.common.dataset; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.annotation.InputDataSet; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.FieldDefaults; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@InputDataSet(MLInputDataType.QUESTION_ANSWER) | ||
public class QuestionAnswerInputDataSet extends MLInputDataset { | ||
String contextDocs; | ||
List<String> questionsList; | ||
|
||
@Builder(toBuilder = true) | ||
public QuestionAnswerInputDataSet(String contextDocs, List<String> questionsList) { | ||
super(MLInputDataType.QUESTION_ANSWER); | ||
this.contextDocs = contextDocs; | ||
this.questionsList = questionsList; | ||
} | ||
|
||
public QuestionAnswerInputDataSet(StreamInput in) throws IOException { | ||
super(MLInputDataType.QUESTION_ANSWER); | ||
this.contextDocs = in.readString(); | ||
int size = in.readInt(); | ||
this.questionsList = new ArrayList<String>(size); | ||
for (int i = 0; i < size; i++) { | ||
questionsList.add(i, in.readString()); | ||
} | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.contextDocs); | ||
out.writeInt(this.questionsList.size()); | ||
for (String question : this.questionsList) { | ||
out.writeString(question); | ||
} | ||
} | ||
} |
25 changes: 25 additions & 0 deletions
25
common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package org.opensearch.ml.common.input.nlp; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
@org.opensearch.ml.common.annotation.MLInput(functionNames=FunctionName.QUESTION_ANSWER) | ||
public class QuestionAnswerMLInput extends MLInput { | ||
public QuestionAnswerMLInput(FunctionName algorithm, MLInputDataset inputDataset) { | ||
super(algorithm, null, inputDataset); | ||
} | ||
|
||
public QuestionAnswerMLInput(StreamInput input) throws IOException { | ||
super(input); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput output) throws IOException { | ||
super.writeTo(output); | ||
} | ||
} |