diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 174872df1ebd..f94dab426881 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -78,6 +78,8 @@ Optimizations * GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov) +* GITHUB#12372: Reduce allocation during HNSW construction (Jonathan Ellis) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index e2a57a303c61..c3e7a04215ab 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -68,6 +68,8 @@ public final class HnswGraphBuilder { private final RandomAccessVectorValues vectors; private final SplittableRandom random; private final HnswGraphSearcher graphSearcher; + private final NeighborQueue entryCandidates; // for upper levels of graph search + private final NeighborQueue beamCandidates; // for levels of graph where we add the node final OnHeapHnswGraph hnsw; @@ -149,6 +151,8 @@ private HnswGraphBuilder( new FixedBitSet(this.vectors.size())); // in scratch we store candidates in reverse order: worse candidates are first scratch = new NeighborArray(Math.max(beamWidth, M + 1), false); + entryCandidates = new NeighborQueue(1, false); + beamCandidates = new NeighborQueue(beamWidth, false); this.initializedNodes = new HashSet<>(); } @@ -250,7 +254,6 @@ public OnHeapHnswGraph getGraph() { /** Inserts a doc with vector value to the graph */ public void addGraphNode(int node, T value) throws IOException { - NeighborQueue candidates; final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; @@ -269,13 +272,19 @@ public void addGraphNode(int node, T value) throws IOException { } // for levels > nodeLevel search with topk = 1 + NeighborQueue candidates = entryCandidates; for (int level = curMaxLevel; level > nodeLevel; level--) { - candidates = graphSearcher.searchLevel(value, 1, level, eps, vectors, hnsw); + candidates.clear(); + graphSearcher.searchLevel( + candidates, value, 1, level, eps, vectors, hnsw, null, Integer.MAX_VALUE); eps = new int[] {candidates.pop()}; } // for levels <= nodeLevel search with topk = beamWidth, and add connections + candidates = beamCandidates; for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { - candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectors, hnsw); + candidates.clear(); + graphSearcher.searchLevel( + candidates, value, beamWidth, level, eps, vectors, hnsw, null, Integer.MAX_VALUE); eps = candidates.nodes(); hnsw.addNode(level, node); addDiverseNeighbors(level, node, candidates); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 5bc718169466..1cd5183a993b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -205,10 +205,12 @@ private static NeighborQueue search( return new NeighborQueue(1, true); } NeighborQueue results; + results = new NeighborQueue(1, false); int[] eps = new int[] {graph.entryNode()}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); + results.clear(); + graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); @@ -219,8 +221,9 @@ private static NeighborQueue search( } eps[0] = results.pop(); } - results = - graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results = new NeighborQueue(topK, false); + graphSearcher.searchLevel( + results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -248,10 +251,19 @@ public NeighborQueue searchLevel( RandomAccessVectorValues vectors, HnswGraph graph) throws IOException { - return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + NeighborQueue results = new NeighborQueue(topK, false); + searchLevel(results, query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + return results; } - private NeighborQueue searchLevel( + /** + * Add the closest neighbors found to a priority queue (heap). These are returned in REVERSE + * proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest + * score/comparison value, will be at the top of the heap, while the closest neighbor will be the + * last to be popped. + */ + void searchLevel( + NeighborQueue results, T query, int topK, int level, @@ -261,8 +273,9 @@ private NeighborQueue searchLevel( Bits acceptOrds, int visitedLimit) throws IOException { + assert results.isMinHeap(); + int size = graph.size(); - NeighborQueue results = new NeighborQueue(topK, false); prepareScratchState(vectors.size()); int numVisited = 0; @@ -323,7 +336,6 @@ private NeighborQueue searchLevel( results.pop(); } results.setVisitedCount(numVisited); - return results; } private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException { @@ -365,10 +377,12 @@ int graphNextNeighbor(HnswGraph graph) throws IOException { } /** - * This class allow {@link OnHeapHnswGraph} to be searched in a thread-safe manner. + * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner by avoiding + * the unsafe methods (seek and nextNeighbor, which maintain state in the graph object) and + * instead maintaining the state in the searcher object. * - *

Note the class itself is NOT thread safe, but since each search will create one new graph - * searcher the search method is thread safe. + *

Note the class itself is NOT thread safe, but since each search will create a new Searcher, + * the search methods using this class are thread safe. */ private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java index 582467cd9768..95a20590d678 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java @@ -175,6 +175,10 @@ public void markIncomplete() { this.incomplete = true; } + boolean isMinHeap() { + return order == Order.MIN_HEAP; + } + @Override public String toString() { return "Neighbors[" + heap.size() + "]";