Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

检索API v1.1 SDK更新 #722

Merged
merged 4 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions go/appbuilder/dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
package appbuilder

import (
"bytes"
"fmt"
// "bytes"
// "fmt"
"os"
"testing"
)
Expand Down Expand Up @@ -69,41 +69,41 @@ func TestDatasetError(t *testing.T) {
dataset, _ = NewDataset(config)

}
func TestDataset(t *testing.T) {
t.Parallel() // 并发运行
// 创建缓冲区来存储日志
var logBuffer bytes.Buffer

// 定义一个日志函数,将日志写入缓冲区
log := func(format string, args ...interface{}) {
fmt.Fprintf(&logBuffer, format+"\n", args...)
}

// 测试逻辑
config, err := NewSDKConfig("", os.Getenv(SecretKeyV3))
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new http client config failed: %v", err)
}
dataset, _ := NewDataset(config)
datasetID, err := dataset.Create("测试集合")
if err != nil {
datasetID = os.Getenv(SecretKeyV3)
}

_, err = dataset.ListDocument(datasetID, 1, 10, "")
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("list document failed: %v", err)
}
log("Listed documents for dataset ID: %s", datasetID)

// 如果测试失败,则输出缓冲区中的日志
if t.Failed() {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
fmt.Println(logBuffer.String())
} else { // else 紧跟在右大括号后面
// 测试通过,打印文件名和测试函数名
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}
// func TestDataset(t *testing.T) {
// t.Parallel() // 并发运行
// // 创建缓冲区来存储日志
// var logBuffer bytes.Buffer

// // 定义一个日志函数,将日志写入缓冲区
// log := func(format string, args ...interface{}) {
// fmt.Fprintf(&logBuffer, format+"\n", args...)
// }

// // 测试逻辑
// config, err := NewSDKConfig("", os.Getenv(SecretKeyV3))
// if err != nil {
// t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
// t.Fatalf("new http client config failed: %v", err)
// }
// dataset, _ := NewDataset(config)
// datasetID, err := dataset.Create("测试集合")
// if err != nil {
// datasetID = os.Getenv(SecretKeyV3)
// }

// _, err = dataset.ListDocument(datasetID, 1, 10, "")
// if err != nil {
// t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
// t.Fatalf("list document failed: %v", err)
// }
// log("Listed documents for dataset ID: %s", datasetID)

// // 如果测试失败,则输出缓冲区中的日志
// if t.Failed() {
// t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
// fmt.Println(logBuffer.String())
// } else { // else 紧跟在右大括号后面
// // 测试通过,打印文件名和测试函数名
// t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
// }
// }
6 changes: 6 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,12 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
// 检查 RankScoreThreshold 是否为 nil,如果是,则设置默认值
if req.RankScoreThreshold == nil {
defaultThreshold := 0.4
req.RankScoreThreshold = &defaultThreshold
}

request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/knowledgebases/query")
Expand Down
64 changes: 44 additions & 20 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@

package appbuilder

type QueryType string

const (
Fulltext QueryType = "fulltext"
Semantic QueryType = "semantic"
Hybrid QueryType = "hybrid"
)

const (
ContentTypeRawText = "raw_text"
ContentTypeQA = "qa"
Expand Down Expand Up @@ -277,6 +285,18 @@ type ElasticSearchRetrieveConfig struct {
Top int `json:"top"`
}

type VectorDBRetrieveConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Threshold float64 `json:"threshold"`
Top int `json:"top"`
}

type SmallToBigConfig struct {
Name string `json:"name"`
Type string `json:"type"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Expand All @@ -291,13 +311,14 @@ type QueryPipelineConfig struct {
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *string `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *QueryType `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
RankScoreThreshold *float64 `json:"rank_score_threshold,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
}

type RowLine struct {
Expand All @@ -314,19 +335,22 @@ type ChunkLocation struct {
}

type Chunk struct {
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]any `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]any `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
NeighbourChunks []Chunk `json:"neighbour_chunks"`
OriginalChunkId string `json:"original_chunk_id"`
OriginalChunkOffset int `json:"original_chunk_offset"`
}

type QueryKnowledgeBaseResponse struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,9 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I

public QueryKnowledgeBaseResponse queryKnowledgeBase(QueryKnowledgeBaseRequest request)
throws IOException, AppBuilderServerException {
if (request.getRank_score_threshold() == null) {
request.setRank_score_threshold(0.4f);
}
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;

String jsonBody = JsonUtils.serialize(request);
Expand All @@ -733,13 +736,36 @@ public QueryKnowledgeBaseResponse queryKnowledgeBase(QueryKnowledgeBaseRequest r
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, Integer top, Integer skip,
String[] knowledgebaseIDs, QueryKnowledgeBaseRequest.MetadataFilters filters,
QueryKnowledgeBaseRequest.QueryPipelineConfig pipelineConfig)
throws IOException, AppBuilderServerException {

float rank_score_threshold = 0.4f;

String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, rank_score_threshold, top, skip, knowledgebaseIDs, filters, pipelineConfig);
String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
postRequest.setHeader("Content-Type", "application/json");
HttpResponse<QueryKnowledgeBaseResponse> response = httpClient.execute(postRequest,
QueryKnowledgeBaseResponse.class);
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, Float rank_score_threshold, Integer top, Integer skip,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不行,这个会导致升级的原代码没发用。建议重新加一个方法,原方法用默认值

String[] knowledgebaseIDs, QueryKnowledgeBaseRequest.MetadataFilters filters,
QueryKnowledgeBaseRequest.QueryPipelineConfig pipelineConfig)
throws IOException, AppBuilderServerException {
if (rank_score_threshold == null) {
rank_score_threshold = 0.4f;
}

String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, top, skip, knowledgebaseIDs, filters, pipelineConfig);
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, rank_score_threshold, top, skip, knowledgebaseIDs, filters, pipelineConfig);
String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
public class QueryKnowledgeBaseRequest {
private String query;
private String type;
private Float rank_score_threshold;
private Integer top;
private Integer skip;
private String[] knowledgebase_ids;
private MetadataFilters metadata_filters;
private QueryPipelineConfig pipeline_config;

public QueryKnowledgeBaseRequest(String query, String type, Integer top, Integer skip,
public QueryKnowledgeBaseRequest(String query, String type, Float rank_score_threshold, Integer top, Integer skip,
String[] knowledgebase_ids, MetadataFilters metadata_filters,
QueryPipelineConfig pipeline_config) {
this.query = query;
this.type = type;
this.rank_score_threshold = rank_score_threshold;
this.top = top;
this.skip = skip;
this.knowledgebase_ids = knowledgebase_ids;
Expand All @@ -39,6 +41,14 @@ public void setType(String type) {
this.type = type;
}

public Float getRank_score_threshold() {
return rank_score_threshold;
}

public void setRank_score_threshold(Float rank_score_threshold) {
this.rank_score_threshold = rank_score_threshold;
}

public Integer getTop() {
return top;
}
Expand Down Expand Up @@ -217,6 +227,46 @@ public void setTop(Integer top) {
}
}

public static class VectorDBRetrieveConfig {
private String name;
private String type;
private Double threshold;
private Integer top;

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getType() {
return type;
}

public void setType(String type) {
this.type = type;
}

public Double getThreshold() {
return threshold;
}

public void setThreshold(Double threshold) {
this.threshold = threshold;
}

public Integer getTop() {
return top;
}

public void setTop(Integer top) {
this.top = top;
}
}


public static class RankingConfig {
private String name;
private String type;
Expand Down Expand Up @@ -265,6 +315,28 @@ public void setTop(Integer top) {
}
}

public static class SmallToBigConfig {
private String name;
private String type;

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getType() {
return type;
}

public void setType(String type) {
this.type = type;
}

}

public static class QueryPipelineConfig {
private String id;
private List<Object> pipeline;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public static class Chunk {
private float rank_score;
private ChunkLocation locations;
private List<Chunk> children;
private List<Chunk> neighbour_chunks;
private String original_chunk_id;
private Integer original_chunk_offset;

public String getChunk_id() { return chunk_id; }

Expand Down Expand Up @@ -96,6 +99,18 @@ public static class Chunk {
public List<Chunk> getChildren() { return children; }

public void setChildren(List<Chunk> children) { this.children = children; }

public List<Chunk> getNeighbour_chunks() { return neighbour_chunks; }

public void setNeighbour_chunks(List<Chunk> neighbour_chunks) { this.neighbour_chunks = neighbour_chunks; }

public String getOriginal_chunk_id() { return original_chunk_id; }

public void setOriginal_chunk_id(String original_chunk_id) { this.original_chunk_id = original_chunk_id; }

public Integer getOriginal_chunk_offset() { return original_chunk_offset; }

public void setOriginal_chunk_offset(Integer original_chunk_offset) { this.original_chunk_offset = original_chunk_offset; }
}

public static class ChunkLocation {
Expand Down
Loading
Loading