diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/ResultRankingContext.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/ResultRankingContext.java index 9052345a..01c017f0 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/ResultRankingContext.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/ResultRankingContext.java @@ -3,11 +3,17 @@ package nu.marginalia.api.searchquery.model.results; import lombok.ToString; import nu.marginalia.api.searchquery.model.compiled.CqDataInt; +import java.util.BitSet; + @ToString public class ResultRankingContext { private final int docCount; public final ResultRankingParameters params; + + public final BitSet regularMask; + public final BitSet ngramsMask; + /** CqDataInt associated with frequency information of the terms in the query * in the full index. The dataset is indexed by the compiled query. */ public final CqDataInt fullCounts; @@ -18,11 +24,18 @@ public class ResultRankingContext { public ResultRankingContext(int docCount, ResultRankingParameters params, + BitSet ngramsMask, CqDataInt fullCounts, CqDataInt prioCounts) { this.docCount = docCount; this.params = params; + + this.ngramsMask = ngramsMask; + + this.regularMask = new BitSet(ngramsMask.length()); + this.regularMask.xor(ngramsMask); + this.fullCounts = fullCounts; this.priorityCounts = prioCounts; } diff --git a/code/index/java/nu/marginalia/index/IndexGrpcService.java b/code/index/java/nu/marginalia/index/IndexGrpcService.java index 3eb2f5d7..50fb1eb8 100644 --- a/code/index/java/nu/marginalia/index/IndexGrpcService.java +++ b/code/index/java/nu/marginalia/index/IndexGrpcService.java @@ -9,6 +9,7 @@ import io.prometheus.client.Histogram; import it.unimi.dsi.fastutil.longs.LongArrayList; import lombok.SneakyThrows; import nu.marginalia.api.searchquery.*; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; import nu.marginalia.api.searchquery.model.compiled.CqDataInt; import nu.marginalia.api.searchquery.model.query.SearchSpecification; @@ -204,7 +205,9 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { return new SearchResultSet(List.of()); } - ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.compiledQueryIds); + ResultRankingContext rankingContext = createRankingContext(params.rankingParams, + params.compiledQuery, + params.compiledQueryIds); var queryExecution = new QueryExecution(rankingContext, params.fetchSize); @@ -415,20 +418,28 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { } private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams, + CompiledQuery compiledQuery, CompiledQueryLong compiledQueryIds) { int[] full = new int[compiledQueryIds.size()]; int[] prio = new int[compiledQueryIds.size()]; + BitSet ngramsMask = new BitSet(compiledQuery.size()); + for (int idx = 0; idx < compiledQueryIds.size(); idx++) { long id = compiledQueryIds.at(idx); full[idx] = index.getTermFrequency(id); prio[idx] = index.getTermFrequencyPrio(id); + + if (compiledQuery.at(idx).contains("_")) { + ngramsMask.set(idx); + } } return new ResultRankingContext(index.getTotalDocCount(), rankingParams, + ngramsMask, new CqDataInt(full), new CqDataInt(prio)); } diff --git a/code/index/java/nu/marginalia/ranking/results/ResultValuator.java b/code/index/java/nu/marginalia/ranking/results/ResultValuator.java index 4d257349..d233651b 100644 --- a/code/index/java/nu/marginalia/ranking/results/ResultValuator.java +++ b/code/index/java/nu/marginalia/ranking/results/ResultValuator.java @@ -76,7 +76,8 @@ public class ResultValuator { double bestTcf = rankingParams.tcfWeight * termCoherenceFactor.calculate(wordMeta); - double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(new Bm25FullGraphVisitor(rankingParams.fullParams, wordMeta.data, length, ctx)); + double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(Bm25FullGraphVisitor.forRegular(rankingParams.fullParams, wordMeta.data, length, ctx)); + double bestBM25N = 0.25 * rankingParams.bm25FullWeight * wordMeta.root.visit(Bm25FullGraphVisitor.forNgrams(rankingParams.fullParams, wordMeta.data, length, ctx)); double bestBM25P = rankingParams.bm25PrioWeight * wordMeta.root.visit(new Bm25PrioGraphVisitor(rankingParams.prioParams, wordMeta.data, ctx)); double overallPartPositive = Math.max(0, overallPart); @@ -84,7 +85,7 @@ public class ResultValuator { // Renormalize to 0...15, where 0 is the best possible score; // this is a historical artifact of the original ranking function - return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + overallPartPositive, overallPartNegative); + return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + bestBM25N + overallPartPositive, overallPartNegative); } private double calculateQualityPenalty(int size, int quality, ResultRankingParameters rankingParams) { diff --git a/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java b/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java index 9c46261d..4105ed6b 100644 --- a/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java +++ b/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java @@ -7,6 +7,7 @@ import nu.marginalia.api.searchquery.model.results.Bm25Parameters; import nu.marginalia.api.searchquery.model.results.ResultRankingContext; import nu.marginalia.model.idx.WordMetadata; +import java.util.BitSet; import java.util.List; public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor { @@ -19,15 +20,33 @@ public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor { private final int docCount; private final int length; - public Bm25FullGraphVisitor(Bm25Parameters bm25Parameters, + private final BitSet mask; + + private Bm25FullGraphVisitor(Bm25Parameters bm25Parameters, CqDataLong wordMetaData, int length, + BitSet mask, ResultRankingContext ctx) { this.length = length; this.bm25Parameters = bm25Parameters; this.docCount = ctx.termFreqDocCount(); this.wordMetaData = wordMetaData; this.frequencies = ctx.fullCounts; + this.mask = mask; + } + + public static Bm25FullGraphVisitor forRegular(Bm25Parameters bm25Parameters, + CqDataLong wordMetaData, + int length, + ResultRankingContext ctx) { + return new Bm25FullGraphVisitor(bm25Parameters, wordMetaData, length, ctx.regularMask, ctx); + } + + public static Bm25FullGraphVisitor forNgrams(Bm25Parameters bm25Parameters, + CqDataLong wordMetaData, + int length, + ResultRankingContext ctx) { + return new Bm25FullGraphVisitor(bm25Parameters, wordMetaData, length, ctx.ngramsMask, ctx); } @Override @@ -50,6 +69,10 @@ public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor { @Override public double onLeaf(int idx) { + if (!mask.get(idx)) { + return 0; + } + double count = Long.bitCount(WordMetadata.decodePositions(wordMetaData.get(idx))); int freq = frequencies.get(idx);