From 941a5c26cd9c5aed6e162ac56b2711a2391b1c99 Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Tue, 7 Jan 2025 10:05:40 -0800 Subject: [PATCH] Use iterator that can advance, add random value unit test Signed-off-by: bowenlan-amzn --- .../index/mapper/NumberFieldMapper.java | 3 +- .../search/query/BitmapIndexQuery.java | 58 ++++++++++--------- .../query/BitmapDocValuesQueryTests.java | 6 -- .../search/query/BitmapIndexQueryTests.java | 49 ++++++++++++++-- 4 files changed, 77 insertions(+), 39 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java index b0f13059dd7af..425c5c5efe50b 100644 --- a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java @@ -80,6 +80,7 @@ import org.opensearch.search.DocValueFormat; import org.opensearch.search.lookup.SearchLookup; import org.opensearch.search.query.BitmapDocValuesQuery; +import org.opensearch.search.query.BitmapIndexQuery; import java.io.IOException; import java.math.BigInteger; @@ -97,7 +98,6 @@ import java.util.function.Function; import java.util.function.Supplier; -import org.opensearch.search.query.BitmapIndexQuery; import org.roaringbitmap.RoaringBitmap; /** @@ -1555,6 +1555,7 @@ public Scorer get(long leadCost) throws IOException { final BytesRef encoded = new BytesRef(new byte[Integer.BYTES]); Query query = new PointInSetQuery(field, 1, Integer.BYTES, new PointInSetQuery.Stream() { final Iterator iterator = bitmap.iterator(); + @Override public BytesRef next() { int value; diff --git a/server/src/main/java/org/opensearch/search/query/BitmapIndexQuery.java b/server/src/main/java/org/opensearch/search/query/BitmapIndexQuery.java index f403764bbe734..9aa355541f291 100644 --- a/server/src/main/java/org/opensearch/search/query/BitmapIndexQuery.java +++ b/server/src/main/java/org/opensearch/search/query/BitmapIndexQuery.java @@ -29,12 +29,13 @@ import org.apache.lucene.util.BytesRefIterator; import org.apache.lucene.util.DocIdSetBuilder; import org.apache.lucene.util.RamUsageEstimator; -import org.roaringbitmap.RoaringBitmap; import java.io.IOException; -import java.util.Iterator; import java.util.Objects; +import org.roaringbitmap.PeekableIntIterator; +import org.roaringbitmap.RoaringBitmap; + /** * A query that matches all documents that contain a set of integer numbers represented by bitmap * @@ -50,12 +51,19 @@ public BitmapIndexQuery(String field, RoaringBitmap bitmap) { this.field = field; } - private static BytesRefIterator bitmapEncodedIterator(RoaringBitmap bitmap) { - return new BytesRefIterator() { - private final Iterator iterator = bitmap.iterator(); + interface BitmapIterator extends BytesRefIterator { + // wrap IntIterator.next() + BytesRef next(); + + // expose PeekableIntIterator.advanceIfNeeded, advance as long as the next value is smaller than target + void advance(byte[] target); + } + + private static BitmapIterator bitmapEncodedIterator(RoaringBitmap bitmap) { + return new BitmapIterator() { + private final PeekableIntIterator iterator = bitmap.getIntIterator(); private final BytesRef encoded = new BytesRef(new byte[Integer.BYTES]); - @Override public BytesRef next() { int value; if (iterator.hasNext()) { @@ -66,6 +74,10 @@ public BytesRef next() { IntPoint.encodeDimension(value, encoded.bytes, 0); return encoded; } + + public void advance(byte[] target) { + iterator.advanceIfNeeded(IntPoint.decodeDimension(target, 0)); + } }; } @@ -85,8 +97,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { final Weight weight = this; LeafReader reader = context.reader(); - // get point value - // only works for one dimension + // get the point value which should be one dimension, since bitmap saves integers PointValues values = reader.getPointValues(field); if (values == null) { return null; @@ -118,7 +129,7 @@ public long cost() { @Override public boolean isCacheable(LeafReaderContext ctx) { - // This query depend only on segment-immutable structure points + // This query depend only on segment-immutable structure — points return true; } }; @@ -126,13 +137,12 @@ public boolean isCacheable(LeafReaderContext ctx) { private class MergePointVisitor implements PointValues.IntersectVisitor { private final DocIdSetBuilder result; - private final BytesRefIterator iterator; + private final BitmapIterator iterator; private BytesRef nextQueryPoint; private final ArrayUtil.ByteArrayComparator comparator; private DocIdSetBuilder.BulkAdder adder; - public MergePointVisitor(DocIdSetBuilder result) - throws IOException { + public MergePointVisitor(DocIdSetBuilder result) throws IOException { this.result = result; this.comparator = ArrayUtil.getUnsignedComparator(Integer.BYTES); this.iterator = bitmapEncodedIterator(bitmap); @@ -175,11 +185,8 @@ private boolean matches(byte[] packedValue) { return true; } else if (cmp < 0) { // Query point is before index point, so we move to next query point - try { - nextQueryPoint = iterator.next(); - } catch (IOException e) { - throw new RuntimeException(e); - } + iterator.advance(packedValue); + nextQueryPoint = iterator.next(); } else { // Query point is after index point, so we don't collect and we return: break; @@ -191,19 +198,14 @@ private boolean matches(byte[] packedValue) { @Override public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { while (nextQueryPoint != null) { - int cmpMin = - comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, minPackedValue, 0); + int cmpMin = comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, minPackedValue, 0); if (cmpMin < 0) { // query point is before the start of this cell - try { - nextQueryPoint = iterator.next(); - } catch (IOException e) { - throw new RuntimeException(e); - } + iterator.advance(minPackedValue); + nextQueryPoint = iterator.next(); continue; } - int cmpMax = - comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, maxPackedValue, 0); + int cmpMax = comparator.compare(nextQueryPoint.bytes, nextQueryPoint.offset, maxPackedValue, 0); if (cmpMax > 0) { // query point is after the end of this cell return PointValues.Relation.CELL_OUTSIDE_QUERY; @@ -260,7 +262,7 @@ public int hashCode() { @Override public long ramBytesUsed() { - return RamUsageEstimator.shallowSizeOfInstance(BitmapIndexQuery.class) + RamUsageEstimator.sizeOfObject(field) - + RamUsageEstimator.sizeOfObject(bitmap); + return RamUsageEstimator.shallowSizeOfInstance(BitmapIndexQuery.class) + RamUsageEstimator.sizeOfObject(field) + RamUsageEstimator + .sizeOfObject(bitmap); } } diff --git a/server/src/test/java/org/opensearch/search/query/BitmapDocValuesQueryTests.java b/server/src/test/java/org/opensearch/search/query/BitmapDocValuesQueryTests.java index d3e43e5f63979..d103b335588bc 100644 --- a/server/src/test/java/org/opensearch/search/query/BitmapDocValuesQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/BitmapDocValuesQueryTests.java @@ -12,14 +12,9 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.IntField; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.DocValues; import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.opensearch.test.OpenSearchTestCase; @@ -28,7 +23,6 @@ import java.io.IOException; import java.util.HashSet; -import java.util.LinkedList; import java.util.List; import java.util.Set; diff --git a/server/src/test/java/org/opensearch/search/query/BitmapIndexQueryTests.java b/server/src/test/java/org/opensearch/search/query/BitmapIndexQueryTests.java index 13e29f78a1d5f..b5f7e35e92520 100644 --- a/server/src/test/java/org/opensearch/search/query/BitmapIndexQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/BitmapIndexQueryTests.java @@ -23,17 +23,22 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; +import org.opensearch.common.Randomness; +import org.opensearch.test.OpenSearchTestCase; import org.junit.After; import org.junit.Before; -import org.opensearch.test.OpenSearchTestCase; -import org.roaringbitmap.RoaringBitmap; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Random; import java.util.Set; +import org.roaringbitmap.RoaringBitmap; + public class BitmapIndexQueryTests extends OpenSearchTestCase { private Directory dir; private IndexWriter w; @@ -86,11 +91,11 @@ public void testScore() throws IOException { assertEquals(expected, actual); } + // use doc values to get the actual value of the matching docs + // cannot directly check the docId because test can randomize segment numbers static List getMatchingValues(Weight weight, IndexReader reader) throws IOException { List actual = new LinkedList<>(); for (LeafReaderContext leaf : reader.leaves()) { - // use doc values to get the actual value of the matching docs and assert - // cannot directly check the docId because test can randomize segment numbers SortedNumericDocValues dv = DocValues.getSortedNumeric(leaf.reader(), "product_id"); Scorer scorer = weight.scorer(leaf); DocIdSetIterator disi = scorer.iterator(); @@ -138,4 +143,40 @@ public void testScoreMutilValues() throws IOException { assertEquals(expected, actual); } + public void testRandomDocumentsAndQueries() throws IOException { + Random random = Randomness.get(); + int valueRange = 10_000; // the range of query values should be within indexed values + + for (int i = 0; i < valueRange + 1; i++) { + Document d = new Document(); + d.add(new IntField("product_id", i, Field.Store.NO)); + w.addDocument(d); + } + + w.commit(); + reader = DirectoryReader.open(w); + searcher = newSearcher(reader); + + // Generate random values for bitmap query + Set queryValues = new HashSet<>(); + int numberOfValues = 5; + for (int i = 0; i < numberOfValues; i++) { + int value = random.nextInt(valueRange) + 1; + queryValues.add(value); + } + RoaringBitmap bitmap = new RoaringBitmap(); + bitmap.add(queryValues.stream().mapToInt(Integer::intValue).toArray()); + + BitmapIndexQuery query = new BitmapIndexQuery("product_id", bitmap); + Weight weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + + Set actualSet = new HashSet<>(getMatchingValues(weight, searcher.getIndexReader())); + + List expected = new ArrayList<>(queryValues); + Collections.sort(expected); + List actual = new ArrayList<>(actualSet); + Collections.sort(actual); + assertEquals(expected, actual); + } + }