diff --git a/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java b/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java index b6bc42c..5c6a90b 100644 --- a/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java +++ b/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java @@ -1,8 +1,12 @@ package querqy.embeddings; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import querqy.solr.utils.JsonUtil; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.MalformedURLException; @@ -17,6 +21,8 @@ public class ChorusEmbeddingModel implements EmbeddingModel { private static final String CONTENT_TYPE_JSON = "application/json"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private URL url; private boolean normalize = true; @@ -65,14 +71,23 @@ public Embedding getEmbedding(final String text) { os.write(input, 0, input.length); } - embedding = Embedding.of((List) JsonUtil.readJson(con.getInputStream(), Map.class).get("embedding")); + embedding = parseEmbeddingFromResponse(con.getInputStream()); embeddingsCache.putEmbedding(cacheKey, embedding); return embedding; } catch (final IOException e) { throw new RuntimeException(e); } + } + public Embedding parseEmbeddingFromResponse(InputStream is) { + try { + JsonNode responseTree = objectMapper.readTree(is); + List embedding = objectMapper.convertValue(responseTree.path("embedding"), new TypeReference<>() {}); + return Embedding.of(embedding); + } catch (IOException e) { + throw new RuntimeException(e); + } } protected String toJsonString(final String text) { @@ -86,4 +101,5 @@ protected String toJsonString(final String text) { ))); } + } diff --git a/src/main/java/querqy/embeddings/EmbeddingsRewriter.java b/src/main/java/querqy/embeddings/EmbeddingsRewriter.java index fbe4cd0..de3558d 100644 --- a/src/main/java/querqy/embeddings/EmbeddingsRewriter.java +++ b/src/main/java/querqy/embeddings/EmbeddingsRewriter.java @@ -10,7 +10,6 @@ import querqy.model.Node; import querqy.model.QuerqyQuery; import querqy.model.Query; -import querqy.model.StringRawQuery; import querqy.model.Term; import querqy.rewrite.QueryRewriter; import querqy.rewrite.RewriterOutput; @@ -77,65 +76,23 @@ public RewriterOutput rewrite(final ExpandedQuery query, } protected ExpandedQuery applyEmbedding(final Embedding embedding, final ExpandedQuery inputQuery) { - + KnnVectorQuery knnVectorQuery = new KnnVectorQuery(vectorField, embedding.asVector(), topK); + LuceneRawQuery luceneRawQuery = new LuceneRawQuery(null, Clause.Occur.MUST,true, knnVectorQuery); switch (queryMode) { case BOOST: - inputQuery.addBoostUpQuery(new BoostQuery(new StringRawQuery(null, makeEmbeddingQueryString(embedding), - Clause.Occur.SHOULD, true), boost)); + inputQuery.addBoostUpQuery(new BoostQuery(luceneRawQuery, boost)); break; case MAIN_QUERY: - // this is a workaround to avoid changing Querqy's query object model for now: - // as we cant set a StringRawQuery as the userQuery, we use a match all for that, add a vector query - // as a filter query (retrieve only knn) and a boost query (rank by distance) - //inputQuery.setUserQuery(new MatchAllQuery()); - inputQuery.setUserQuery(new LuceneRawQuery(null, Clause.Occur.MUST, - true, new KnnVectorQuery(vectorField, embedding.asVector(), topK))); + inputQuery.setUserQuery(luceneRawQuery); break; default: throw new IllegalStateException("Unknown query mode: " + queryMode); - } return inputQuery; } - protected String makeEmbeddingQueryString(final Embedding embedding) { - return "{!func}sum(100,query({!knn f=" + vectorField + " topK=" + topK + " v='[" + embedding.asCommaSeparatedString() + "]'}))"; - } - - protected String embeddingToString(final float[] embedding) { - final StringBuilder sb = new StringBuilder(embedding.length * 16); - for (int i = 0; i < embedding.length; i++) { - if (i > 0) { - sb.append(", "); - } - sb.append(embedding[i]); - } - return sb.toString(); - } - - protected ExpandedQuery applyVectorQuery(final String embeddingQueryString, final ExpandedQuery inputQuery) { - - - final StringRawQuery embeddingsQuery = new StringRawQuery(null, embeddingQueryString, Clause.Occur.SHOULD, true); - switch (queryMode) { - case BOOST: - inputQuery.addBoostUpQuery(new BoostQuery(embeddingsQuery, boost)); - break; - case MAIN_QUERY: - // this is a workaround to avoid changing Querqy's query object model for now: - // as we cant set a StringRawQuery as the userQuery, we use a match all for that, add a vector query - // as a filter query (retrieve only knn) and a boost query (rank by distance) - inputQuery.setUserQuery(new StringRawQuery(null, embeddingQueryString, Clause.Occur.MUST, true)); - break; - default: - throw new IllegalStateException("Unknown query mode: " + queryMode); - - } - - return inputQuery; - } /** * Traverse the query graph, collect all the terms and join them into a string */ diff --git a/src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java b/src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java new file mode 100644 index 0000000..f87e0fe --- /dev/null +++ b/src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java @@ -0,0 +1,19 @@ +package querqy.solr.embeddings; + +import org.junit.Assert; +import org.junit.Test; +import querqy.embeddings.ChorusEmbeddingModel; +import querqy.embeddings.Embedding; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; + +public class ChorusEmbeddingModelTest { + + @Test + public void testParseJson() { + String embeddingJson = "{ \"embedding\": [0.3, 1, 5] }"; + Embedding e = new ChorusEmbeddingModel().parseEmbeddingFromResponse(new ByteArrayInputStream(embeddingJson.getBytes(StandardCharsets.UTF_8))); + Assert.assertArrayEquals(e.asVector(), new float[] { 0.3f, 1f, 5f}, 0f); + } +}