From ae7c760772ab256ae1e13d30f626e9f954d4df0e Mon Sep 17 00:00:00 2001 From: Viktor Lofgren Date: Fri, 5 Apr 2024 13:30:49 +0200 Subject: [PATCH] (index) Clean up new index query code --- .../model/compiled/CompiledQueryLong.java | 8 + .../model/compiled/CqDataLong.java | 4 + .../aggregate/CompiledQueryAggregates.java | 1 + .../index/index/QueryBranchWalker.java | 74 ++++++--- .../marginalia/index/index/StatefulIndex.java | 153 +++++++++--------- .../marginalia/index/model/SearchTerms.java | 81 ++-------- .../index/query/filter/QueryFilterAllOf.java | 18 ++- .../index/query/filter/QueryFilterAnyOf.java | 41 ++++- .../array/buffer/LongQueryBuffer.java | 6 - 9 files changed, 208 insertions(+), 178 deletions(-) diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java index 639778dc..94fa0e8b 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java @@ -39,4 +39,12 @@ public class CompiledQueryLong implements Iterable { public Iterator iterator() { return stream().iterator(); } + + public long[] copyData() { + return data.copyData(); + } + + public boolean isEmpty() { + return data.size() == 0; + } } diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java index 8049631e..24f76b13 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java @@ -24,4 +24,8 @@ public class CqDataLong { public int size() { return data.length; } + + public long[] copyData() { + return Arrays.copyOf(data, data.length); + } } diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregates.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregates.java index 209acbee..9c4abe72 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregates.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregates.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.function.*; +/** Contains methods for aggregating across a CompiledQuery or CompiledQueryLong */ public class CompiledQueryAggregates { /** Compiled query aggregate that for a single boolean that treats or-branches as logical OR, * and and-branches as logical AND operations. Will return true if there exists a path through diff --git a/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java b/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java index a465bd86..34b04f0a 100644 --- a/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java +++ b/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java @@ -1,13 +1,18 @@ package nu.marginalia.index.index; import it.unimi.dsi.fastutil.longs.LongArrayList; +import it.unimi.dsi.fastutil.longs.LongArraySet; import it.unimi.dsi.fastutil.longs.LongSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; -class QueryBranchWalker { +/** Helper class for index query construction */ +public class QueryBranchWalker { + private static final Logger logger = LoggerFactory.getLogger(QueryBranchWalker.class); public final long[] priorityOrder; public final List paths; public final long termId; @@ -22,56 +27,81 @@ class QueryBranchWalker { return priorityOrder.length == 0; } + /** Group the provided paths by the lowest termId they contain per the provided priorityOrder, + * into a list of QueryBranchWalkers. This can be performed iteratively on the resultant QBW:s + * to traverse the tree via the next() method. + *

+ * The paths can be extracted through the {@link nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates CompiledQueryAggregates} + * queriesAggregate method. + */ public static List create(long[] priorityOrder, List paths) { + if (paths.isEmpty()) + return List.of(); List ret = new ArrayList<>(); List remainingPaths = new LinkedList<>(paths); - remainingPaths.removeIf(LongSet::isEmpty); + List pathsForPrio = new ArrayList<>(); + for (int i = 0; i < priorityOrder.length; i++) { - long prio = priorityOrder[i]; + long termId = priorityOrder[i]; var it = remainingPaths.iterator(); - List pathsForPrio = new ArrayList<>(); while (it.hasNext()) { var path = it.next(); - if (path.contains(prio)) { - path.remove(prio); + if (path.contains(termId)) { + // Remove the current termId from the path + path.remove(termId); + + // Add it to the set of paths associated with the termId pathsForPrio.add(path); + + // Remove it from consideration it.remove(); } } if (!pathsForPrio.isEmpty()) { - LongArrayList remainingPrios = new LongArrayList(pathsForPrio.size()); - - for (var p : priorityOrder) { - for (var path : pathsForPrio) { - if (path.contains(p)) { - remainingPrios.add(p); - break; - } - } - } - - ret.add(new QueryBranchWalker(remainingPrios.elements(), pathsForPrio, prio)); + long[] newPrios = keepRelevantPriorities(priorityOrder, pathsForPrio); + ret.add(new QueryBranchWalker(newPrios, new ArrayList<>(pathsForPrio), termId)); + pathsForPrio.clear(); } } + // This happens if the priorityOrder array doesn't contain all items in the paths, + // in practice only when an index doesn't contain all the search terms, so we can just + // skip those paths if (!remainingPaths.isEmpty()) { - System.out.println("Dropping: " + remainingPaths); + logger.info("Dropping: {}", remainingPaths); } return ret; } - public List next() { - if (atEnd()) - return List.of(); + /** From the provided priorityOrder array, keep the elements that are present in any set in paths */ + private static long[] keepRelevantPriorities(long[] priorityOrder, List paths) { + LongArrayList remainingPrios = new LongArrayList(paths.size()); + // these sets are typically very small so array set is a good choice + LongSet allElements = new LongArraySet(priorityOrder.length); + for (var path : paths) { + allElements.addAll(path); + } + + for (var p : priorityOrder) { + if (allElements.contains(p)) + remainingPrios.add(p); + } + + return remainingPrios.elements(); + } + + /** Convenience method that applies the create() method + * to the priority order and paths associated with this instance */ + public List next() { return create(priorityOrder, paths); } diff --git a/code/index/java/nu/marginalia/index/index/StatefulIndex.java b/code/index/java/nu/marginalia/index/index/StatefulIndex.java index 0f55c0c8..273da2d0 100644 --- a/code/index/java/nu/marginalia/index/index/StatefulIndex.java +++ b/code/index/java/nu/marginalia/index/index/StatefulIndex.java @@ -4,7 +4,6 @@ import com.google.inject.Inject; import com.google.inject.Singleton; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import it.unimi.dsi.fastutil.longs.LongSet; -import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates; import nu.marginalia.index.query.filter.QueryFilterAllOf; import nu.marginalia.index.query.filter.QueryFilterAnyOf; @@ -25,9 +24,7 @@ import java.util.*; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.function.LongFunction; import java.util.function.Predicate; -import java.util.stream.Collectors; /** This class delegates SearchIndexReader and deals with the stateful nature of the index, * i.e. it may be possible to reconstruct the index and load a new set of data. @@ -95,7 +92,6 @@ public class StatefulIndex { logger.error("Uncaught exception", ex); } finally { - lock.unlock(); } @@ -113,62 +109,6 @@ public class StatefulIndex { return combinedIndexReader != null && combinedIndexReader.isLoaded(); } - private Predicate containsOnly(long[] permitted) { - LongSet permittedTerms = new LongOpenHashSet(permitted); - return permittedTerms::containsAll; - } - - private List createBuilders(CompiledQueryLong query, - LongFunction builderFactory, - long[] termPriority) { - List paths = CompiledQueryAggregates.queriesAggregate(query); - - // Remove any paths that do not contain all prioritized terms, as this means - // the term is missing from the index and can never be found - paths.removeIf(containsOnly(termPriority).negate()); - - List helpers = QueryBranchWalker.create(termPriority, paths); - List builders = new ArrayList<>(); - - for (var helper : helpers) { - var builder = builderFactory.apply(helper.termId); - - builders.add(builder); - - if (helper.atEnd()) - continue; - - var filters = helper.next().stream() - .map(this::createFilter) - .toList(); - - builder.addInclusionFilterAny(filters); - } - - return builders; - } - - private QueryFilterStepIf createFilter(QueryBranchWalker helper) { - var selfCondition = combinedIndexReader.hasWordFull(helper.termId); - if (helper.atEnd()) - return selfCondition; - - var nextSteps = helper.next(); - var nextFilters = nextSteps.stream() - .map(this::createFilter) - .map(filter -> new QueryFilterAllOf(List.of(selfCondition, filter))) - .collect(Collectors.toList()); - - if (nextFilters.isEmpty()) - return selfCondition; - - if (nextFilters.size() == 1) - return nextFilters.getFirst(); - - - return new QueryFilterAnyOf(nextFilters); - } - public List createQueries(SearchTerms terms, QueryParams params) { if (!isLoaded()) { @@ -176,29 +116,99 @@ public class StatefulIndex { return Collections.emptyList(); } - final long[] orderedIncludes = terms.sortedDistinctIncludes(this::compareKeywords); - final long[] orderedIncludesPrio = terms.sortedDistinctIncludes(this::compareKeywordsPrio); - List queryHeads = new ArrayList<>(10); - queryHeads.addAll(createBuilders(terms.compiledQuery(), combinedIndexReader::findFullWord, orderedIncludes)); - queryHeads.addAll(createBuilders(terms.compiledQuery(), combinedIndexReader::findPriorityWord, orderedIncludesPrio)); + final long[] termPriority = terms.sortedDistinctIncludes(this::compareKeywords); + List paths = CompiledQueryAggregates.queriesAggregate(terms.compiledQuery()); - List queries = new ArrayList<>(10); + // Remove any paths that do not contain all prioritized terms, as this means + // the term is missing from the index and can never be found + paths.removeIf(containsAll(termPriority).negate()); + List helpers = QueryBranchWalker.create(termPriority, paths); + + for (var helper : helpers) { + for (var builder : List.of( + combinedIndexReader.findPriorityWord(helper.termId), + combinedIndexReader.findFullWord(helper.termId) + )) + { + queryHeads.add(builder); + + if (helper.atEnd()) + continue; + + List filterSteps = new ArrayList<>(); + for (var step : helper.next()) { + filterSteps.add(createFilter(step, 0)); + } + builder.addInclusionFilterAny(filterSteps); + } + } + + List ret = new ArrayList<>(10); + + // Add additional conditions to the query heads for (var query : queryHeads) { + // Advice terms are a special case, mandatory but not ranked, and exempt from re-writing + for (long term : terms.advice()) { + query = query.alsoFull(term); + } + for (long term : terms.excludes()) { query = query.notFull(term); } // Run these filter steps last, as they'll worst-case cause as many page faults as there are // items in the buffer - queries.add(query.addInclusionFilter(combinedIndexReader.filterForParams(params)).build()); + ret.add(query.addInclusionFilter(combinedIndexReader.filterForParams(params)).build()); } - return queries; + return ret; + } + + /** Recursively create a filter step based on the QBW and its children */ + private QueryFilterStepIf createFilter(QueryBranchWalker walker, int depth) { + final QueryFilterStepIf ownFilterCondition = ownFilterCondition(walker, depth); + + var childSteps = walker.next(); + + if (childSteps.isEmpty()) + return ownFilterCondition; + + List combinedFilters = new ArrayList<>(); + + for (var step : childSteps) { + // Recursion will be limited to a fairly shallow stack depth due to how the queries are constructed. + var childFilter = createFilter(step, depth+1); + combinedFilters.add(new QueryFilterAllOf(ownFilterCondition, childFilter)); + } + + if (combinedFilters.size() == 1) + return combinedFilters.getFirst(); + else + return new QueryFilterAnyOf(combinedFilters); + } + + /** Create a filter condition based on the termId associated with the QBW */ + private QueryFilterStepIf ownFilterCondition(QueryBranchWalker walker, int depth) { + if (depth < 2) { + // At shallow depths we prioritize terms that appear in the priority index, + // to increase the odds we find "good" results before the sand runs out + return new QueryFilterAnyOf( + combinedIndexReader.hasWordPrio(walker.termId), + combinedIndexReader.hasWordFull(walker.termId) + ); + } else { + return combinedIndexReader.hasWordFull(walker.termId); + } + } + + private Predicate containsAll(long[] permitted) { + LongSet permittedTerms = new LongOpenHashSet(permitted); + return permittedTerms::containsAll; } private int compareKeywords(long a, long b) { @@ -208,13 +218,6 @@ public class StatefulIndex { ); } - private int compareKeywordsPrio(long a, long b) { - return Long.compare( - combinedIndexReader.numHitsPrio(a), - combinedIndexReader.numHitsPrio(b) - ); - } - /** Return an array of encoded document metadata longs corresponding to the * document identifiers provided; with metadata for termId. The input array * docs[] *must* be sorted. diff --git a/code/index/java/nu/marginalia/index/model/SearchTerms.java b/code/index/java/nu/marginalia/index/model/SearchTerms.java index 307e4179..8115c109 100644 --- a/code/index/java/nu/marginalia/index/model/SearchTerms.java +++ b/code/index/java/nu/marginalia/index/model/SearchTerms.java @@ -3,54 +3,35 @@ package nu.marginalia.index.model; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongComparator; import it.unimi.dsi.fastutil.longs.LongList; -import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; import nu.marginalia.api.searchquery.model.query.SearchQuery; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import static nu.marginalia.index.model.SearchTermsUtil.getWordId; public final class SearchTerms { - private final LongList includes; + private final LongList advice; private final LongList excludes; private final LongList priority; private final List coherences; private final CompiledQueryLong compiledQueryIds; - public SearchTerms( - LongList includes, - LongList excludes, - LongList priority, - List coherences, - CompiledQueryLong compiledQueryIds - ) { - this.includes = includes; - this.excludes = excludes; - this.priority = priority; - this.coherences = coherences; + public SearchTerms(SearchQuery query, + CompiledQueryLong compiledQueryIds) + { + this.excludes = new LongArrayList(); + this.priority = new LongArrayList(); + this.coherences = new ArrayList<>(); + this.advice = new LongArrayList(); this.compiledQueryIds = compiledQueryIds; - } - public SearchTerms(SearchQuery query, CompiledQueryLong compiledQueryIds) { - this(new LongArrayList(), - new LongArrayList(), - new LongArrayList(), - new ArrayList<>(), - compiledQueryIds); - - for (var word : query.searchTermsInclude) { - includes.add(getWordId(word)); - } for (var word : query.searchTermsAdvice) { - // This looks like a bug, but it's not - includes.add(getWordId(word)); + advice.add(getWordId(word)); } - for (var coherence : query.searchTermCoherences) { LongList parts = new LongArrayList(coherence.size()); @@ -64,36 +45,29 @@ public final class SearchTerms { for (var word : query.searchTermsExclude) { excludes.add(getWordId(word)); } + for (var word : query.searchTermsPriority) { priority.add(getWordId(word)); } } public boolean isEmpty() { - return includes.isEmpty(); + return compiledQueryIds.isEmpty(); } public long[] sortedDistinctIncludes(LongComparator comparator) { - if (includes.isEmpty()) - return includes.toLongArray(); - - LongList list = new LongArrayList(new LongOpenHashSet(includes)); + LongList list = new LongArrayList(compiledQueryIds.copyData()); list.sort(comparator); return list.toLongArray(); } - public int size() { - return includes.size() + excludes.size() + priority.size(); - } - - public LongList includes() { - return includes; - } public LongList excludes() { return excludes; } - + public LongList advice() { + return advice; + } public LongList priority() { return priority; } @@ -104,29 +78,4 @@ public final class SearchTerms { public CompiledQueryLong compiledQuery() { return compiledQueryIds; } - @Override - public boolean equals(Object obj) { - if (obj == this) return true; - if (obj == null || obj.getClass() != this.getClass()) return false; - var that = (SearchTerms) obj; - return Objects.equals(this.includes, that.includes) && - Objects.equals(this.excludes, that.excludes) && - Objects.equals(this.priority, that.priority) && - Objects.equals(this.coherences, that.coherences); - } - - @Override - public int hashCode() { - return Objects.hash(includes, excludes, priority, coherences); - } - - @Override - public String toString() { - return "SearchTerms[" + - "includes=" + includes + ", " + - "excludes=" + excludes + ", " + - "priority=" + priority + ", " + - "coherences=" + coherences + ']'; - } - } diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java index 8c20fe98..e9725179 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java @@ -2,14 +2,28 @@ package nu.marginalia.index.query.filter; import nu.marginalia.array.buffer.LongQueryBuffer; +import java.util.ArrayList; import java.util.List; import java.util.StringJoiner; public class QueryFilterAllOf implements QueryFilterStepIf { - private final List steps; + private final List steps; public QueryFilterAllOf(List steps) { - this.steps = steps; + this.steps = new ArrayList<>(steps.size()); + + for (var step : steps) { + if (step instanceof QueryFilterAllOf allOf) { + this.steps.addAll(allOf.steps); + } + else { + this.steps.add(step); + } + } + } + + public QueryFilterAllOf(QueryFilterStepIf... steps) { + this(List.of(steps)); } public double cost() { diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java index 2d177645..bea62194 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java @@ -2,14 +2,27 @@ package nu.marginalia.index.query.filter; import nu.marginalia.array.buffer.LongQueryBuffer; +import java.util.ArrayList; import java.util.List; import java.util.StringJoiner; public class QueryFilterAnyOf implements QueryFilterStepIf { - private final List steps; + private final List steps; public QueryFilterAnyOf(List steps) { - this.steps = steps; + this.steps = new ArrayList<>(steps.size()); + + for (var step : steps) { + if (step instanceof QueryFilterAnyOf anyOf) { + this.steps.addAll(anyOf.steps); + } else { + this.steps.add(step); + } + } + } + + public QueryFilterAnyOf(QueryFilterStepIf... steps) { + this(List.of(steps)); } public double cost() { @@ -30,23 +43,37 @@ public class QueryFilterAnyOf implements QueryFilterStepIf { if (steps.isEmpty()) return; + if (steps.size() == 1) { + steps.getFirst().apply(buffer); + return; + } + int start = 0; - int end = buffer.end; + final int endOfValidData = buffer.end; // End of valid data range + + // The filters act as a partitioning function, where anything before buffer.end + // is "in", and is guaranteed to be sorted; and anything after buffer.end is "out" + // but no sorting guaranteed is provided. + + // To provide a conditional filter, we re-sort the "out" range, slice it and apply filtering to the slice for (var step : steps) { - var slice = buffer.slice(start, end); + var slice = buffer.slice(start, endOfValidData); slice.data.quickSort(0, slice.size()); step.apply(slice); start += slice.end; } - buffer.data.quickSort(0, start); - - // Special finalization + // After we're done, read and write pointers should be 0 and "end" should be the length of valid data, + // normally done through buffer.finalizeFiltering(); but that won't work here buffer.reset(); buffer.end = start; + + // After all filters have been applied, we must re-sort all the retained data + // to uphold the sortedness contract + buffer.data.quickSort(0, buffer.end); } public String describe() { diff --git a/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java b/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java index d5b44389..a0312d36 100644 --- a/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java +++ b/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java @@ -133,12 +133,6 @@ public class LongQueryBuffer { write = 0; } - public void finalizeFiltering(int pos) { - end = write; - read = pos; - write = pos; - } - /** Retain only unique values in the buffer, and update the end pointer to the new length. *

* The buffer is assumed to be sorted up until the end pointer.