(index) Split ngram and regular keyword bm25 calculation and add ngram score as a bonus

This commit is contained in:
Viktor Lofgren 2024-04-17 14:04:35 +02:00
parent 579295a673
commit f52457213e
4 changed files with 52 additions and 4 deletions

View File

@ -3,11 +3,17 @@ package nu.marginalia.api.searchquery.model.results;
import lombok.ToString;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import java.util.BitSet;
@ToString
public class ResultRankingContext {
private final int docCount;
public final ResultRankingParameters params;
public final BitSet regularMask;
public final BitSet ngramsMask;
/** CqDataInt associated with frequency information of the terms in the query
* in the full index. The dataset is indexed by the compiled query. */
public final CqDataInt fullCounts;
@ -18,11 +24,18 @@ public class ResultRankingContext {
public ResultRankingContext(int docCount,
ResultRankingParameters params,
BitSet ngramsMask,
CqDataInt fullCounts,
CqDataInt prioCounts)
{
this.docCount = docCount;
this.params = params;
this.ngramsMask = ngramsMask;
this.regularMask = new BitSet(ngramsMask.length());
this.regularMask.xor(ngramsMask);
this.fullCounts = fullCounts;
this.priorityCounts = prioCounts;
}

View File

@ -9,6 +9,7 @@ import io.prometheus.client.Histogram;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.SneakyThrows;
import nu.marginalia.api.searchquery.*;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
@ -204,7 +205,9 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
return new SearchResultSet(List.of());
}
ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.compiledQueryIds);
ResultRankingContext rankingContext = createRankingContext(params.rankingParams,
params.compiledQuery,
params.compiledQueryIds);
var queryExecution = new QueryExecution(rankingContext, params.fetchSize);
@ -415,20 +418,28 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
}
/**
 * Builds the ranking context for a compiled query.
 *
 * For every position in {@code compiledQueryIds} the full and priority term
 * frequencies are looked up in the index, and the position is flagged in the
 * n-gram mask when the corresponding term string contains an underscore.
 *
 * @param rankingParams    ranking parameters passed through to the context
 * @param compiledQuery    the compiled query in string form
 * @param compiledQueryIds the compiled query in term id form
 * @return a context holding the total document count, the n-gram mask and
 *         both sets of term frequencies
 */
private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams,
                                                  CompiledQuery<String> compiledQuery,
                                                  CompiledQueryLong compiledQueryIds)
{
    final int[] fullFrequencies = new int[compiledQueryIds.size()];
    final int[] prioFrequencies = new int[compiledQueryIds.size()];

    final BitSet ngramTerms = new BitSet(compiledQuery.size());

    int pos = 0;
    while (pos < compiledQueryIds.size()) {
        final long termId = compiledQueryIds.at(pos);

        fullFrequencies[pos] = index.getTermFrequency(termId);
        prioFrequencies[pos] = index.getTermFrequencyPrio(termId);

        // A term whose string form contains an underscore is flagged as an n-gram
        final boolean hasUnderscore = compiledQuery.at(pos).contains("_");
        ngramTerms.set(pos, hasUnderscore);

        pos++;
    }

    final int totalDocs = index.getTotalDocCount();

    return new ResultRankingContext(totalDocs,
            rankingParams,
            ngramTerms,
            new CqDataInt(fullFrequencies),
            new CqDataInt(prioFrequencies));
}

View File

@ -76,7 +76,8 @@ public class ResultValuator {
double bestTcf = rankingParams.tcfWeight * termCoherenceFactor.calculate(wordMeta);
double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(new Bm25FullGraphVisitor(rankingParams.fullParams, wordMeta.data, length, ctx));
double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(Bm25FullGraphVisitor.forRegular(rankingParams.fullParams, wordMeta.data, length, ctx));
double bestBM25N = 0.25 * rankingParams.bm25FullWeight * wordMeta.root.visit(Bm25FullGraphVisitor.forNgrams(rankingParams.fullParams, wordMeta.data, length, ctx));
double bestBM25P = rankingParams.bm25PrioWeight * wordMeta.root.visit(new Bm25PrioGraphVisitor(rankingParams.prioParams, wordMeta.data, ctx));
double overallPartPositive = Math.max(0, overallPart);
@ -84,7 +85,7 @@ public class ResultValuator {
// Renormalize to 0...15, where 0 is the best possible score;
// this is a historical artifact of the original ranking function
return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + overallPartPositive, overallPartNegative);
return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + bestBM25N + overallPartPositive, overallPartNegative);
}
private double calculateQualityPenalty(int size, int quality, ResultRankingParameters rankingParams) {

View File

@ -7,6 +7,7 @@ import nu.marginalia.api.searchquery.model.results.Bm25Parameters;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.model.idx.WordMetadata;
import java.util.BitSet;
import java.util.List;
public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
@ -19,15 +20,33 @@ public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
private final int docCount;
private final int length;
public Bm25FullGraphVisitor(Bm25Parameters bm25Parameters,
private final BitSet mask;
/**
 * Creates a visitor that only considers the leaves selected by {@code mask}.
 * Instances are obtained through the static factory methods of this class.
 *
 * @param bm25Parameters BM25 parameters stored in the visitor
 * @param wordMetaData   per-term word metadata, indexed by leaf index
 * @param length         length value stored in the visitor
 * @param mask           leaf indices to include; a leaf whose bit is not set
 *                       contributes 0
 * @param ctx            ranking context supplying the document count and the
 *                       full-index term frequencies
 */
private Bm25FullGraphVisitor(Bm25Parameters bm25Parameters,
                             CqDataLong wordMetaData,
                             int length,
                             BitSet mask,
                             ResultRankingContext ctx) {
    // Values handed in directly by the caller
    this.bm25Parameters = bm25Parameters;
    this.wordMetaData = wordMetaData;
    this.length = length;
    this.mask = mask;

    // Values taken from the ranking context
    this.frequencies = ctx.fullCounts;
    this.docCount = ctx.termFreqDocCount();
}
/**
 * Creates a visitor restricted to the terms selected by {@code ctx.regularMask}.
 *
 * @param bm25Parameters BM25 parameters for the visitor
 * @param wordMetaData   per-term word metadata
 * @param length         length value for the visitor
 * @param ctx            ranking context providing the mask and frequency data
 * @return a new visitor using the context's regular-term mask
 */
public static Bm25FullGraphVisitor forRegular(Bm25Parameters bm25Parameters,
                                              CqDataLong wordMetaData,
                                              int length,
                                              ResultRankingContext ctx) {
    final BitSet regularTerms = ctx.regularMask;

    return new Bm25FullGraphVisitor(bm25Parameters, wordMetaData, length, regularTerms, ctx);
}
/**
 * Creates a visitor restricted to the terms selected by {@code ctx.ngramsMask}.
 *
 * @param bm25Parameters BM25 parameters for the visitor
 * @param wordMetaData   per-term word metadata
 * @param length         length value for the visitor
 * @param ctx            ranking context providing the mask and frequency data
 * @return a new visitor using the context's n-gram mask
 */
public static Bm25FullGraphVisitor forNgrams(Bm25Parameters bm25Parameters,
                                             CqDataLong wordMetaData,
                                             int length,
                                             ResultRankingContext ctx) {
    final BitSet ngramTerms = ctx.ngramsMask;

    return new Bm25FullGraphVisitor(bm25Parameters, wordMetaData, length, ngramTerms, ctx);
}
@Override
@ -50,6 +69,10 @@ public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
@Override
public double onLeaf(int idx) {
if (!mask.get(idx)) {
return 0;
}
double count = Long.bitCount(WordMetadata.decodePositions(wordMetaData.get(idx)));
int freq = frequencies.get(idx);