mirror of
https://github.com/MarginaliaSearch/MarginaliaSearch.git
synced 2025-02-23 21:18:58 +00:00
(index) Split ngram and regular keyword bm25 calculation and add ngram score as a bonus
This commit is contained in:
parent
579295a673
commit
f52457213e
@ -3,11 +3,17 @@ package nu.marginalia.api.searchquery.model.results;
|
|||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
|
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
|
||||||
|
|
||||||
|
import java.util.BitSet;
|
||||||
|
|
||||||
@ToString
|
@ToString
|
||||||
public class ResultRankingContext {
|
public class ResultRankingContext {
|
||||||
private final int docCount;
|
private final int docCount;
|
||||||
public final ResultRankingParameters params;
|
public final ResultRankingParameters params;
|
||||||
|
|
||||||
|
|
||||||
|
public final BitSet regularMask;
|
||||||
|
public final BitSet ngramsMask;
|
||||||
|
|
||||||
/** CqDataInt associated with frequency information of the terms in the query
|
/** CqDataInt associated with frequency information of the terms in the query
|
||||||
* in the full index. The dataset is indexed by the compiled query. */
|
* in the full index. The dataset is indexed by the compiled query. */
|
||||||
public final CqDataInt fullCounts;
|
public final CqDataInt fullCounts;
|
||||||
@ -18,11 +24,18 @@ public class ResultRankingContext {
|
|||||||
|
|
||||||
public ResultRankingContext(int docCount,
|
public ResultRankingContext(int docCount,
|
||||||
ResultRankingParameters params,
|
ResultRankingParameters params,
|
||||||
|
BitSet ngramsMask,
|
||||||
CqDataInt fullCounts,
|
CqDataInt fullCounts,
|
||||||
CqDataInt prioCounts)
|
CqDataInt prioCounts)
|
||||||
{
|
{
|
||||||
this.docCount = docCount;
|
this.docCount = docCount;
|
||||||
this.params = params;
|
this.params = params;
|
||||||
|
|
||||||
|
this.ngramsMask = ngramsMask;
|
||||||
|
|
||||||
|
this.regularMask = new BitSet(ngramsMask.length());
|
||||||
|
this.regularMask.xor(ngramsMask);
|
||||||
|
|
||||||
this.fullCounts = fullCounts;
|
this.fullCounts = fullCounts;
|
||||||
this.priorityCounts = prioCounts;
|
this.priorityCounts = prioCounts;
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import io.prometheus.client.Histogram;
|
|||||||
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
import nu.marginalia.api.searchquery.*;
|
import nu.marginalia.api.searchquery.*;
|
||||||
|
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
|
||||||
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
|
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
|
||||||
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
|
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
|
||||||
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
|
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
|
||||||
@ -204,7 +205,9 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
|
|||||||
return new SearchResultSet(List.of());
|
return new SearchResultSet(List.of());
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.compiledQueryIds);
|
ResultRankingContext rankingContext = createRankingContext(params.rankingParams,
|
||||||
|
params.compiledQuery,
|
||||||
|
params.compiledQueryIds);
|
||||||
|
|
||||||
var queryExecution = new QueryExecution(rankingContext, params.fetchSize);
|
var queryExecution = new QueryExecution(rankingContext, params.fetchSize);
|
||||||
|
|
||||||
@ -415,20 +418,28 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams,
|
private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams,
|
||||||
|
CompiledQuery<String> compiledQuery,
|
||||||
CompiledQueryLong compiledQueryIds)
|
CompiledQueryLong compiledQueryIds)
|
||||||
{
|
{
|
||||||
|
|
||||||
int[] full = new int[compiledQueryIds.size()];
|
int[] full = new int[compiledQueryIds.size()];
|
||||||
int[] prio = new int[compiledQueryIds.size()];
|
int[] prio = new int[compiledQueryIds.size()];
|
||||||
|
|
||||||
|
BitSet ngramsMask = new BitSet(compiledQuery.size());
|
||||||
|
|
||||||
for (int idx = 0; idx < compiledQueryIds.size(); idx++) {
|
for (int idx = 0; idx < compiledQueryIds.size(); idx++) {
|
||||||
long id = compiledQueryIds.at(idx);
|
long id = compiledQueryIds.at(idx);
|
||||||
full[idx] = index.getTermFrequency(id);
|
full[idx] = index.getTermFrequency(id);
|
||||||
prio[idx] = index.getTermFrequencyPrio(id);
|
prio[idx] = index.getTermFrequencyPrio(id);
|
||||||
|
|
||||||
|
if (compiledQuery.at(idx).contains("_")) {
|
||||||
|
ngramsMask.set(idx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return new ResultRankingContext(index.getTotalDocCount(),
|
return new ResultRankingContext(index.getTotalDocCount(),
|
||||||
rankingParams,
|
rankingParams,
|
||||||
|
ngramsMask,
|
||||||
new CqDataInt(full),
|
new CqDataInt(full),
|
||||||
new CqDataInt(prio));
|
new CqDataInt(prio));
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,8 @@ public class ResultValuator {
|
|||||||
|
|
||||||
double bestTcf = rankingParams.tcfWeight * termCoherenceFactor.calculate(wordMeta);
|
double bestTcf = rankingParams.tcfWeight * termCoherenceFactor.calculate(wordMeta);
|
||||||
|
|
||||||
double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(new Bm25FullGraphVisitor(rankingParams.fullParams, wordMeta.data, length, ctx));
|
double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(Bm25FullGraphVisitor.forRegular(rankingParams.fullParams, wordMeta.data, length, ctx));
|
||||||
|
double bestBM25N = 0.25 * rankingParams.bm25FullWeight * wordMeta.root.visit(Bm25FullGraphVisitor.forNgrams(rankingParams.fullParams, wordMeta.data, length, ctx));
|
||||||
double bestBM25P = rankingParams.bm25PrioWeight * wordMeta.root.visit(new Bm25PrioGraphVisitor(rankingParams.prioParams, wordMeta.data, ctx));
|
double bestBM25P = rankingParams.bm25PrioWeight * wordMeta.root.visit(new Bm25PrioGraphVisitor(rankingParams.prioParams, wordMeta.data, ctx));
|
||||||
|
|
||||||
double overallPartPositive = Math.max(0, overallPart);
|
double overallPartPositive = Math.max(0, overallPart);
|
||||||
@ -84,7 +85,7 @@ public class ResultValuator {
|
|||||||
|
|
||||||
// Renormalize to 0...15, where 0 is the best possible score;
|
// Renormalize to 0...15, where 0 is the best possible score;
|
||||||
// this is a historical artifact of the original ranking function
|
// this is a historical artifact of the original ranking function
|
||||||
return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + overallPartPositive, overallPartNegative);
|
return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + bestBM25N + overallPartPositive, overallPartNegative);
|
||||||
}
|
}
|
||||||
|
|
||||||
private double calculateQualityPenalty(int size, int quality, ResultRankingParameters rankingParams) {
|
private double calculateQualityPenalty(int size, int quality, ResultRankingParameters rankingParams) {
|
||||||
|
@ -7,6 +7,7 @@ import nu.marginalia.api.searchquery.model.results.Bm25Parameters;
|
|||||||
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
|
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
|
||||||
import nu.marginalia.model.idx.WordMetadata;
|
import nu.marginalia.model.idx.WordMetadata;
|
||||||
|
|
||||||
|
import java.util.BitSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
|
public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
|
||||||
@ -19,15 +20,33 @@ public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
|
|||||||
private final int docCount;
|
private final int docCount;
|
||||||
private final int length;
|
private final int length;
|
||||||
|
|
||||||
/** Leaf indices that take part in the score; onLeaf returns 0 for any index not set here. */
private final BitSet mask;

/**
 * Instances are obtained through {@link #forRegular} or {@link #forNgrams},
 * which choose the mask to apply.
 *
 * @param bm25Parameters BM25 parameters
 * @param wordMetaData   per-term word metadata, indexed like the compiled query
 * @param length         length value stored for use in scoring
 * @param mask           leaf indices to score
 * @param ctx            ranking context supplying the document count and full term counts
 */
private Bm25FullGraphVisitor(Bm25Parameters bm25Parameters,
                             CqDataLong wordMetaData,
                             int length,
                             BitSet mask,
                             ResultRankingContext ctx) {
    this.bm25Parameters = bm25Parameters;
    this.wordMetaData = wordMetaData;
    this.length = length;
    this.mask = mask;

    this.docCount = ctx.termFreqDocCount();
    this.frequencies = ctx.fullCounts;
}

/** Creates a visitor restricted to the leaves in {@code ctx.regularMask}. */
public static Bm25FullGraphVisitor forRegular(Bm25Parameters bm25Parameters,
                                              CqDataLong wordMetaData,
                                              int length,
                                              ResultRankingContext ctx) {
    final BitSet regularTerms = ctx.regularMask;
    return new Bm25FullGraphVisitor(bm25Parameters, wordMetaData, length, regularTerms, ctx);
}

/** Creates a visitor restricted to the leaves in {@code ctx.ngramsMask}. */
public static Bm25FullGraphVisitor forNgrams(Bm25Parameters bm25Parameters,
                                             CqDataLong wordMetaData,
                                             int length,
                                             ResultRankingContext ctx) {
    final BitSet ngramTerms = ctx.ngramsMask;
    return new Bm25FullGraphVisitor(bm25Parameters, wordMetaData, length, ngramTerms, ctx);
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -50,6 +69,10 @@ public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double onLeaf(int idx) {
|
public double onLeaf(int idx) {
|
||||||
|
if (!mask.get(idx)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
double count = Long.bitCount(WordMetadata.decodePositions(wordMetaData.get(idx)));
|
double count = Long.bitCount(WordMetadata.decodePositions(wordMetaData.get(idx)));
|
||||||
|
|
||||||
int freq = frequencies.get(idx);
|
int freq = frequencies.get(idx);
|
||||||
|
Loading…
Reference in New Issue
Block a user